Replace Unsafe with Strict mechanism

Previously by default the presence of a missing field in
a udt would result in an error reported. The Unsafe
mechanism could be used to ignore these fields.

This PR changes the default behavior to ignoring missing
fields and only reporting an error if Strict mode is used.
This approach is in line with the gocql.
This commit is contained in:
sylwiaszunejko
2024-06-25 13:34:50 +02:00
committed by Sylwia Szunejko
parent 6a60650668
commit 7072863b0c
6 changed files with 69 additions and 70 deletions

View File

@@ -264,7 +264,7 @@ func TestIterxUDT(t *testing.T) {
})
t.Run("insert-bind", func(t *testing.T) {
if err := session.Query(insertStmt, nil).Unsafe().Bind(
if err := session.Query(insertStmt, nil).Bind(
testuuid,
tc.insert,
).ExecRelease(); err != nil {
@@ -273,7 +273,7 @@ func TestIterxUDT(t *testing.T) {
// Make sure the UDT was inserted correctly
v := FullUDT{}
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil {
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(&v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expectedOnDB, v)
@@ -281,7 +281,7 @@ func TestIterxUDT(t *testing.T) {
t.Run("scan", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil {
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Scan(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
@@ -289,7 +289,7 @@ func TestIterxUDT(t *testing.T) {
t.Run("get", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil {
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
@@ -305,7 +305,7 @@ func TestIterxUDT(t *testing.T) {
t.Run("insert-bind-struct", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).BindStruct(b).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
@@ -320,7 +320,7 @@ func TestIterxUDT(t *testing.T) {
t.Run("insert-bind-struct-map", func(t *testing.T) {
t.Run("empty-map", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).
BindStructMap(b, nil).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
@@ -334,7 +334,7 @@ func TestIterxUDT(t *testing.T) {
})
t.Run("empty-struct", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).
BindStructMap(struct{}{}, map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
@@ -352,7 +352,7 @@ func TestIterxUDT(t *testing.T) {
})
t.Run("insert-bind-map", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).
BindMap(map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
@@ -736,41 +736,41 @@ func TestIterxStructOnlyUDT(t *testing.T) {
})
}
func TestIterxUnsafe(t *testing.T) {
func TestIterxStrict(t *testing.T) {
session := gocqlxtest.CreateSession(t)
defer session.Close()
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.strict_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil {
t.Fatal("create table:", err)
}
if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil {
if err := session.Query(`INSERT INTO strict_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil {
t.Fatal("insert:", err)
}
type UnsafeTable struct {
type StrictTable struct {
Testtext string
}
m := UnsafeTable{
m := StrictTable{
Testtext: "test",
}
const (
stmt = `SELECT * FROM unsafe_table`
golden = "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable"
stmt = `SELECT * FROM strict_table`
golden = "missing destination name \"testtextunbound\" in gocqlx_test.StrictTable"
)
t.Run("get", func(t *testing.T) {
var v UnsafeTable
err := session.Query(stmt, nil).Get(&v)
t.Run("get strict", func(t *testing.T) {
var v StrictTable
err := session.Query(stmt, nil).Strict().Get(&v)
if err == nil || !strings.HasPrefix(err.Error(), golden) {
t.Fatalf("Get() error=%q expected %s", err, golden)
}
})
t.Run("select", func(t *testing.T) {
var v []UnsafeTable
err := session.Query(stmt, nil).Select(&v)
t.Run("select strict", func(t *testing.T) {
var v []StrictTable
err := session.Query(stmt, nil).Strict().Select(&v)
if err == nil || !strings.HasPrefix(err.Error(), golden) {
t.Fatalf("Select() error=%q expected %s", err, golden)
}
@@ -779,9 +779,9 @@ func TestIterxUnsafe(t *testing.T) {
}
})
t.Run("get unsafe", func(t *testing.T) {
var v UnsafeTable
err := session.Query(stmt, nil).Iter().Unsafe().Get(&v)
t.Run("get", func(t *testing.T) {
var v StrictTable
err := session.Query(stmt, nil).Get(&v)
if err != nil {
t.Fatal("Get() failed:", err)
}
@@ -790,9 +790,9 @@ func TestIterxUnsafe(t *testing.T) {
}
})
t.Run("select unsafe", func(t *testing.T) {
var v []UnsafeTable
err := session.Query(stmt, nil).Iter().Unsafe().Select(&v)
t.Run("select", func(t *testing.T) {
var v []StrictTable
err := session.Query(stmt, nil).Select(&v)
if err != nil {
t.Fatal("Select() failed:", err)
}
@@ -804,9 +804,9 @@ func TestIterxUnsafe(t *testing.T) {
}
})
t.Run("select default unsafe", func(t *testing.T) {
var v []UnsafeTable
err := session.Query(stmt, nil).Unsafe().Iter().Select(&v)
t.Run("select default", func(t *testing.T) {
var v []StrictTable
err := session.Query(stmt, nil).Iter().Select(&v)
if err != nil {
t.Fatal("Select() failed:", err)
}