Marshal/Unmarshal missing UDT fields as null instead of failing in unsafe mode

We can't return an error in case a field is added to the UDT,
otherwise existing code would break by simply altering the UDT in the
database. For extra fields at the end of the UDT put nulls to be in
line with gocql, but also python-driver and java-driver.

In gocql it was fixed in d2ed1bb74f
This commit is contained in:
Dmitry Kropachev
2024-06-14 11:28:30 -04:00
committed by Sylwia Szunejko
parent c6f942afc7
commit 6a60650668
3 changed files with 341 additions and 15 deletions

View File

@@ -9,6 +9,7 @@ package gocqlx_test
import (
"math/big"
"reflect"
"strings"
"testing"
"time"
@@ -48,6 +49,328 @@ type FullNamePtrUDT struct {
*FullName
}
func diff(t *testing.T, expected, got interface{}) {
t.Helper()
if d := cmp.Diff(expected, got, diffOpts); d != "" {
t.Errorf("got %+v expected %+v, diff: %s", got, expected, d)
}
}
var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})
func TestIterxUDT(t *testing.T) {
session := gocqlxtest.CreateSession(t)
t.Cleanup(func() {
session.Close()
})
if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil {
t.Fatal("create type:", err)
}
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table (
testuuid timeuuid PRIMARY KEY,
testudt gocqlx_test.UDTTest_Full
)`); err != nil {
t.Fatal("create table:", err)
}
type Full struct {
First string
Second string
}
type Part struct {
First string
}
type Extra struct {
First string
Second string
Third string
}
type FullUDT struct {
gocqlx.UDT
Full
}
type PartUDT struct {
gocqlx.UDT
Part
}
type ExtraUDT struct {
gocqlx.UDT
Extra
}
type FullUDTPtr struct {
gocqlx.UDT
*Full
}
type PartUDTPtr struct {
gocqlx.UDT
*Part
}
type ExtraUDTPtr struct {
gocqlx.UDT
*Extra
}
full := FullUDT{
Full: Full{
First: "John",
Second: "Doe",
},
}
makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} {
b := reflect.New(reflect.StructOf([]reflect.StructField{
{
Name: "TestUUID",
Type: reflect.TypeOf(gocql.UUID{}),
},
{
Name: "TestUDT",
Type: reflect.TypeOf(insert),
},
})).Interface()
reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid))
reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert))
return b
}
tcases := []struct {
name string
insert interface{}
expected interface{}
expectedOnDB FullUDT
}{
{
name: "exact-match",
insert: full,
expectedOnDB: full,
expected: full,
},
{
name: "exact-match-ptr",
insert: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
expectedOnDB: full,
expected: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
},
{
name: "extra-field",
insert: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "", // Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "extra-field-ptr",
insert: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "", // Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "absent-field",
insert: PartUDT{
Part: Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDT{
Part: Part{
First: "John",
},
},
},
{
name: "absent-field-ptr",
insert: PartUDTPtr{
Part: &Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDTPtr{
Part: &Part{
First: "John",
},
},
},
}
const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)`
const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?`
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
testuuid := gocql.TimeUUID()
if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) {
t.Fatalf("insert and expectedOnDB must have the same type")
}
t.Cleanup(func() {
session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() // nolint:errcheck
})
t.Run("insert-bind", func(t *testing.T) {
if err := session.Query(insertStmt, nil).Unsafe().Bind(
testuuid,
tc.insert,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
// Make sure the UDT was inserted correctly
v := FullUDT{}
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expectedOnDB, v)
})
t.Run("scan", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})
t.Run("get", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})
t.Run("delete", func(t *testing.T) {
if err := session.Query(deleteStmt, nil).Bind(
testuuid,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
})
t.Run("insert-bind-struct", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
t.Run("insert-bind-struct-map", func(t *testing.T) {
t.Run("empty-map", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(b, nil).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
t.Run("empty-struct", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(struct{}{}, map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})
t.Run("insert-bind-map", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindMap(map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})
}
}
func TestIterxStruct(t *testing.T) {
session := gocqlxtest.CreateSession(t)
defer session.Close()
@@ -153,8 +476,6 @@ func TestIterxStruct(t *testing.T) {
t.Fatal("insert:", err)
}
diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})
const stmt = `SELECT * FROM struct_table`
t.Run("get", func(t *testing.T) {

View File

@@ -215,6 +215,13 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx {
return q
}
// Scan executes the query, copies the columns of the first selected
// row into the values pointed at by dest and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Queryx) Scan(v ...interface{}) error {
return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...)
}
// Err returns any binding errors.
func (q *Queryx) Err() error {
return q.err

22
udt.go
View File

@@ -39,27 +39,25 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {
func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
value, ok := u.field[name]
var data []byte
var err error
if ok {
data, err = gocql.Marshal(info, value.Interface())
if err != nil {
return nil, err
return gocql.Marshal(info, value.Interface())
}
if u.unsafe {
return nil, nil
}
return data, err
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
value, ok := u.field[name]
if !ok && !u.unsafe {
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
if ok {
return gocql.Unmarshal(info, data, value.Addr().Interface())
}
if u.unsafe {
return nil
}
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
// udtWrapValue adds UDT wrapper if needed.
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} {