diff --git a/iterx_test.go b/iterx_test.go index d8652a0..96de405 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -9,6 +9,7 @@ package gocqlx_test import ( "math/big" + "reflect" "strings" "testing" "time" @@ -48,6 +49,328 @@ type FullNamePtrUDT struct { *FullName } +func diff(t *testing.T, expected, got interface{}) { + t.Helper() + + if d := cmp.Diff(expected, got, diffOpts); d != "" { + t.Errorf("got %+v expected %+v, diff: %s", got, expected, d) + } +} + +var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) + +func TestIterxUDT(t *testing.T) { + session := gocqlxtest.CreateSession(t) + t.Cleanup(func() { + session.Close() + }) + + if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil { + t.Fatal("create type:", err) + } + + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table ( + testuuid timeuuid PRIMARY KEY, + testudt gocqlx_test.UDTTest_Full + )`); err != nil { + t.Fatal("create table:", err) + } + + type Full struct { + First string + Second string + } + + type Part struct { + First string + } + + type Extra struct { + First string + Second string + Third string + } + + type FullUDT struct { + gocqlx.UDT + Full + } + + type PartUDT struct { + gocqlx.UDT + Part + } + + type ExtraUDT struct { + gocqlx.UDT + Extra + } + + type FullUDTPtr struct { + gocqlx.UDT + *Full + } + + type PartUDTPtr struct { + gocqlx.UDT + *Part + } + + type ExtraUDTPtr struct { + gocqlx.UDT + *Extra + } + + full := FullUDT{ + Full: Full{ + First: "John", + Second: "Doe", + }, + } + + makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} { + b := reflect.New(reflect.StructOf([]reflect.StructField{ + { + Name: "TestUUID", + Type: reflect.TypeOf(gocql.UUID{}), + }, + { + Name: "TestUDT", + Type: reflect.TypeOf(insert), + }, + })).Interface() + reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid)) + reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert)) + return b + } + + tcases := []struct { + name string + insert interface{} + expected interface{} + expectedOnDB FullUDT + }{ + { + name: "exact-match", + insert: full, + expectedOnDB: full, + expected: full, + }, + { + name: "exact-match-ptr", + insert: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + expectedOnDB: full, + expected: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + }, + { + name: "extra-field", + insert: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "", // Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "extra-field-ptr", + insert: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "", // Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "absent-field", + insert: PartUDT{ + Part: Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDT{ + Part: Part{ + First: "John", + }, + }, + }, + { + name: "absent-field-ptr", + insert: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + }, + } + + const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)` + const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?` + + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + testuuid := gocql.TimeUUID() + + if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) { + t.Fatalf("insert and expectedOnDB must have the same type") + } + + t.Cleanup(func() { + session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() // nolint:errcheck + }) + + t.Run("insert-bind", func(t *testing.T) { + if err := session.Query(insertStmt, nil).Unsafe().Bind( + testuuid, + tc.insert, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := FullUDT{} + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expectedOnDB, v) + }) + + t.Run("scan", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("get", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("delete", func(t *testing.T) { + if err := session.Query(deleteStmt, nil).Bind( + testuuid, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + }) + + t.Run("insert-bind-struct", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("insert-bind-struct-map", func(t *testing.T) { + t.Run("empty-map", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(b, nil).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("empty-struct", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(struct{}{}, map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + + t.Run("insert-bind-map", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindMap(map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + } +} + func TestIterxStruct(t *testing.T) { session := gocqlxtest.CreateSession(t) defer session.Close() @@ -153,8 +476,6 @@ func TestIterxStruct(t *testing.T) { t.Fatal("insert:", err) } - diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) - const stmt = `SELECT * FROM struct_table` t.Run("get", func(t *testing.T) { diff --git a/queryx.go b/queryx.go index 331512c..9b67f9b 100644 --- a/queryx.go +++ b/queryx.go @@ -215,6 +215,13 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx { return q } +// Scan executes the query, copies the columns of the first selected +// row into the values pointed at by dest and discards the rest. If no rows +// were selected, ErrNotFound is returned. +func (q *Queryx) Scan(v ...interface{}) error { + return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...) +} + // Err returns any binding errors. func (q *Queryx) Err() error { return q.err diff --git a/udt.go b/udt.go index 63551f7..2785151 100644 --- a/udt.go +++ b/udt.go @@ -39,26 +39,24 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { value, ok := u.field[name] - - var data []byte - var err error if ok { - data, err = gocql.Marshal(info, value.Interface()) - if err != nil { - return nil, err - } + return gocql.Marshal(info, value.Interface()) } - - return data, err + if u.unsafe { + return nil, nil + } + return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) } func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { value, ok := u.field[name] - if !ok && !u.unsafe { - return fmt.Errorf("missing name %q in %s", name, u.value.Type()) + if ok { + return gocql.Unmarshal(info, data, value.Addr().Interface()) } - - return gocql.Unmarshal(info, data, value.Addr().Interface()) + if u.unsafe { + return nil + } + return fmt.Errorf("missing name %q in %s", name, u.value.Type()) } // udtWrapValue adds UDT wrapper if needed.