Move from gocql/gocql to scylladb/gocql

This commit is contained in:
Dmitry Kropachev
2024-06-16 08:40:32 -04:00
committed by Sylwia Szunejko
parent ab80d70106
commit 207ba8723e
6 changed files with 41 additions and 42 deletions

View File

@@ -12,6 +12,7 @@ import (
"os"
"path"
"regexp"
"sort"
"strings"
"github.com/gocql/gocql"
@@ -77,7 +78,6 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) {
New("keyspace.tmpl").
Funcs(template.FuncMap{"camelize": camelize}).
Funcs(template.FuncMap{"mapScyllaToGoType": mapScyllaToGoType}).
Funcs(template.FuncMap{"typeToString": typeToString}).
Parse(keyspaceTmpl)
if err != nil {
log.Fatalln("unable to parse models template:", err)
@@ -99,25 +99,28 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) {
}
orphanedTypes := make(map[string]struct{})
for userTypeName := range md.UserTypes {
for userTypeName := range md.Types {
if !usedInTables(userTypeName, md.Tables) {
orphanedTypes[userTypeName] = struct{}{}
}
}
for typeName := range orphanedTypes {
delete(md.UserTypes, typeName)
delete(md.Types, typeName)
}
imports := make([]string, 0)
for _, t := range md.Tables {
// Ensure ordered columns are sorted alphabetically
sort.Strings(t.OrderedColumns)
for _, c := range t.Columns {
if (c.Validator == "timestamp" || c.Validator == "date" || c.Validator == "duration" || c.Validator == "time") && !existsInSlice(imports, "time") {
if (c.Type == "timestamp" || c.Type == "date" || c.Type == "duration" || c.Type == "time") && !existsInSlice(imports, "time") {
imports = append(imports, "time")
}
if c.Validator == "decimal" && !existsInSlice(imports, "gopkg.in/inf.v0") {
if c.Type == "decimal" && !existsInSlice(imports, "gopkg.in/inf.v0") {
imports = append(imports, "gopkg.in/inf.v0")
}
if c.Validator == "duration" && !existsInSlice(imports, "github.com/gocql/gocql") {
if c.Type == "duration" && !existsInSlice(imports, "github.com/gocql/gocql") {
imports = append(imports, "github.com/gocql/gocql")
}
}
@@ -127,7 +130,7 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) {
data := map[string]interface{}{
"PackageName": *flagPkgname,
"Tables": md.Tables,
"UserTypes": md.UserTypes,
"UserTypes": md.Types,
"Imports": imports,
}
@@ -173,10 +176,10 @@ var userTypes = regexp.MustCompile(`(?:<|\s)(\w+)[>,]`) // match all types conta
func usedInTables(typeName string, tables map[string]*gocql.TableMetadata) bool {
for _, table := range tables {
for _, column := range table.Columns {
if typeName == column.Validator {
if typeName == column.Type {
return true
}
matches := userTypes.FindAllStringSubmatch(column.Validator, -1)
matches := userTypes.FindAllStringSubmatch(column.Type, -1)
for _, s := range matches {
if s[1] == typeName {
return true