From ab279e68ed5258d2e683a9db20a4d83780429a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Wed, 15 Apr 2020 17:06:29 +0200 Subject: [PATCH] Automated UDT support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch adds the power of GocqlX to UDTs. Now you can make a struct be UDT compatible by adding a single line. ``` type FullName struct { gocqlx.UDT FirstName string LastName string } ``` Signed-off-by: MichaƂ Matczuk --- README.md | 7 +-- doc_test.go | 27 +++++++++++ gocqlx.go | 31 ++----------- iterx.go | 123 ++++++++++++++++++++++++++++++++++++------------- iterx_test.go | 94 ++++++++++++++++++++++++++++++------- queryx.go | 7 +++ queryx_wrap.go | 7 --- udt.go | 74 +++++++++++++++++++++++++++++ 8 files changed, 284 insertions(+), 86 deletions(-) create mode 100644 doc_test.go create mode 100644 udt.go diff --git a/README.md b/README.md index a3496bf..107b279 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,11 @@ Package `gocqlx` is an idiomatic extension to `gocql` that provides usability fe ## Features -* Binding query parameters form struct or map -* Scanning results directly into struct or slice +* Binding query parameters form struct +* Scanning results into struct or slice +* Automated UDT support +* CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table)) * CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb)) -* Super simple CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table)) * Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate)) * Fast! diff --git a/doc_test.go b/doc_test.go new file mode 100644 index 0000000..7af22d4 --- /dev/null +++ b/doc_test.go @@ -0,0 +1,27 @@ +package gocqlx_test + +import ( + "github.com/scylladb/gocqlx" +) + +func ExampleUDT() { + // Just add gocqlx.UDT to a type, no need to implement marshalling functions + type FullName struct { + gocqlx.UDT + FirstName string + LastName string + } +} + +func ExampleUDT_wraper() { + type FullName struct { + FirstName string + LastName string + } + + // Create new UDT wrapper type + type FullNameUDT struct { + gocqlx.UDT + *FullName + } +} diff --git a/gocqlx.go b/gocqlx.go index 4888ba6..edcc06a 100644 --- a/gocqlx.go +++ b/gocqlx.go @@ -28,6 +28,10 @@ func structOnlyError(t reflect.Type) error { return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name()) } + if isAutoUDT := reflect.PtrTo(t).Implements(autoUDTInterface); isAutoUDT { + return fmt.Errorf("expected a struct but the provided struct type %s implements gocqlx.UDT", t.Name()) + } + return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) } @@ -36,6 +40,7 @@ func structOnlyError(t reflect.Type) error { var ( unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem() udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem() + autoUDTInterface = reflect.TypeOf((*UDT)(nil)).Elem() ) func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { @@ -46,32 +51,6 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { return t, nil } -// fieldsByName fills a values interface with fields from the passed value based -// on the traversals in int. If ptrs is true, return addresses instead of values. -// We write this instead of using FieldsByName to save allocations and map lookups -// when iterating over many rows. Empty traversals will get an interface pointer. -// Because of the necessity of requesting ptrs or values, it's considered a bit too -// specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { - v = reflect.Indirect(v) - if v.Kind() != reflect.Struct { - return errors.New("argument not a struct") - } - - for i, traversal := range traversals { - if len(traversal) == 0 { - continue - } - f := reflectx.FieldByIndexes(v, traversal) - if ptrs { - values[i] = f.Addr().Interface() - } else { - values[i] = f.Interface() - } - } - return nil -} - func missingFields(transversals [][]int) (field int, err error) { for i, t := range transversals { if len(t) == 0 { diff --git a/iterx.go b/iterx.go index 5cd885b..6633be8 100644 --- a/iterx.go +++ b/iterx.go @@ -24,10 +24,9 @@ type Iterx struct { unsafe bool structOnly bool - started bool err error - // Cache memory for a rows during iteration in StructScan. + // Cache memory for a rows during iteration in structScan. fields [][]int values []interface{} } @@ -77,24 +76,9 @@ func (iter *Iterx) Get(dest interface{}) error { return iter.checkErrAndNotFound() } -// isScannable takes the reflect.Type and the actual dest value and returns -// whether or not it's Scannable. t is scannable if: -// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler -// * it is not a struct -// * it has no exported fields -func (iter *Iterx) isScannable(t reflect.Type) bool { - if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) { - return true - } - if t.Kind() != reflect.Struct { - return true - } - - return len(iter.Mapper.TypeMap(t).Index) == 0 -} - func (iter *Iterx) scanAny(dest interface{}) bool { value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr { iter.err = fmt.Errorf("expected a pointer but got %T", dest) return false @@ -122,10 +106,10 @@ func (iter *Iterx) scanAny(dest interface{}) bool { } if scannable { - return iter.Scan(dest) + return iter.scan(value) } - return iter.StructScan(dest) + return iter.structScan(value) } // Select scans all rows into a destination, which must be a pointer to slice @@ -199,9 +183,9 @@ func (iter *Iterx) scanAll(dest interface{}) bool { // scan into the struct field pointers if !scannable { - ok = iter.StructScan(vp.Interface()) + ok = iter.structScan(vp) } else { - ok = iter.Scan(vp.Interface()) + ok = iter.scan(vp) } if !ok { break @@ -228,6 +212,34 @@ func (iter *Iterx) scanAll(dest interface{}) bool { return true } +// isScannable takes the reflect.Type and the actual dest value and returns +// whether or not it's Scannable. t is scannable if: +// * ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT +// * it is not a struct +// * it has no exported fields +func (iter *Iterx) isScannable(t reflect.Type) bool { + ptr := reflect.PtrTo(t) + switch { + case ptr.Implements(unmarshallerInterface): + return true + case ptr.Implements(udtUnmarshallerInterface): + return true + case ptr.Implements(autoUDTInterface): + return true + case t.Kind() != reflect.Struct: + return true + default: + return len(iter.Mapper.TypeMap(t).Index) == 0 + } +} + +func (iter *Iterx) scan(value reflect.Value) bool { + if value.Kind() != reflect.Ptr { + panic("value must be a pointer") + } + return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe)) +} + // StructScan is like gocql.Iter.Scan, but scans a single row into a single // struct. Use this and iterate manually when the memory load of Select() might // be prohibitive. StructScan caches the reflect work of matching up column @@ -235,37 +247,70 @@ func (iter *Iterx) scanAll(dest interface{}) bool { // safe to run StructScan on the same Iterx instance with different struct // types. func (iter *Iterx) StructScan(dest interface{}) bool { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - iter.err = errors.New("must pass a pointer, not a value, to StructScan destination") + value := reflect.ValueOf(dest) + + if value.Kind() != reflect.Ptr { + iter.err = fmt.Errorf("expected a pointer but got %T", dest) + return false + } + if value.IsNil() { + iter.err = errors.New("expected a pointer but got nil") return false } - if !iter.started { - columns := columnNames(iter.Iter.Columns()) - m := iter.Mapper + return iter.structScan(value) +} + +func (iter *Iterx) structScan(value reflect.Value) bool { + if value.Kind() != reflect.Ptr { + panic("value must be a pointer") + } + + if iter.fields == nil { + columns := columnNames(iter.Iter.Columns()) + iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns) - iter.fields = m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if !iter.unsafe { if f, err := missingFields(iter.fields); err != nil { - iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest) + iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type()) return false } } iter.values = make([]interface{}, len(columns)) - iter.started = true } - err := fieldsByTraversal(v, iter.fields, iter.values, true) - if err != nil { + if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil { iter.err = err return false } + // scan into the struct field pointers and append to our results return iter.Iter.Scan(iter.values...) } +// fieldsByName fills a values interface with fields from the passed value based +// on the traversals in int. +// We write this instead of using FieldsByName to save allocations and map +// lookups when iterating over many rows. +// Empty traversals will get an interface pointer. +func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error { + value = reflect.Indirect(value) + if value.Kind() != reflect.Struct { + return fmt.Errorf("expected a struct but got %s", value.Type()) + } + + for i, traversal := range traversals { + if len(traversal) == 0 { + continue + } + f := reflectx.FieldByIndexes(value, traversal).Addr() + values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe) + } + + return nil +} + func columnNames(ci []gocql.ColumnInfo) []string { r := make([]string, len(ci)) for i, column := range ci { @@ -274,6 +319,18 @@ func columnNames(ci []gocql.ColumnInfo) []string { return r } +// Scan consumes the next row of the iterator and copies the columns of the +// current row into the values pointed at by dest. Use nil as a dest value +// to skip the corresponding column. Scan might send additional queries +// to the database to retrieve the next set of rows if paging was enabled. +// +// Scan returns true if the row was successfully unmarshaled or false if the +// end of the result set was reached or if an error occurred. Close should +// be called afterwards to retrieve any potential errors. +func (iter *Iterx) Scan(dest ...interface{}) bool { + return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...) +} + // Close closes the iterator and returns any errors that happened during // the query or the iteration. func (iter *Iterx) Close() error { diff --git a/iterx_test.go b/iterx_test.go index 01a1888..b429093 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -36,18 +36,13 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { } type FullNameUDT struct { - FirstName string - LastName string + gocqlx.UDT + FullName } -func (n FullNameUDT) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { - f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name) - return gocql.Marshal(info, f.Interface()) -} - -func (n *FullNameUDT) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { - f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name) - return gocql.Unmarshal(info, data, f.Addr().Interface()) +type FullNamePtrUDT struct { + gocqlx.UDT + *FullName } func TestStruct(t *testing.T) { @@ -75,7 +70,8 @@ func TestStruct(t *testing.T) { testvarint varint, testinet inet, testcustom text, - testudt gocqlx_test.FullName + testudt gocqlx_test.FullName, + testptrudt gocqlx_test.FullName )`); err != nil { t.Fatal("create table:", err) } @@ -98,6 +94,7 @@ func TestStruct(t *testing.T) { Testinet string Testcustom FullName Testudt FullNameUDT + Testptrudt FullNamePtrUDT } bigInt := new(big.Int) @@ -122,10 +119,11 @@ func TestStruct(t *testing.T) { Testvarint: bigInt, Testinet: "213.212.2.19", Testcustom: FullName{FirstName: "John", LastName: "Doe"}, - Testudt: FullNameUDT{FirstName: "John", LastName: "Doe"}, + Testudt: FullNameUDT{FullName: FullName{FirstName: "John", LastName: "Doe"}}, + Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}}, } - if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind( m.Testuuid, m.Testtimestamp, m.Testvarchar, @@ -142,7 +140,8 @@ func TestStruct(t *testing.T) { m.Testvarint, m.Testinet, m.Testcustom, - m.Testudt).Exec(); err != nil { + m.Testudt, + m.Testptrudt).Exec(); err != nil { t.Fatal("insert:", err) } @@ -186,6 +185,28 @@ func TestStruct(t *testing.T) { t.Fatal("not equals") } }) + + t.Run("struct scan", func(t *testing.T) { + var ( + v StructTable + n int + ) + + i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter() + for i.StructScan(&v) { + n++ + } + if err := i.Close(); err != nil { + t.Fatal("struct scan failed", err) + } + + if n != 1 { + t.Fatal("struct scan unexpected number of rows", n) + } + if !reflect.DeepEqual(m, v) { + t.Fatal("not equals") + } + }) } func TestScannable(t *testing.T) { @@ -320,7 +341,12 @@ func TestStructOnlyUDT(t *testing.T) { t.Fatal("create table:", err) } - m := FullNameUDT{"John", "Doe"} + m := FullNameUDT{ + FullName: FullName{ + FirstName: "John", + LastName: "Doe", + }, + } if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil { t.Fatal("insert:", err) @@ -401,7 +427,7 @@ func TestUnsafe(t *testing.T) { t.Run("safe get", func(t *testing.T) { var v UnsafeTable i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" { + if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" { t.Fatal("expected ErrNotFound", "got", err) } }) @@ -409,7 +435,7 @@ func TestUnsafe(t *testing.T) { t.Run("safe select", func(t *testing.T) { var v []UnsafeTable i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) - if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" { + if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" { t.Fatal("expected ErrNotFound", "got", err) } if cap(v) > 0 { @@ -510,6 +536,40 @@ func TestNotFound(t *testing.T) { }) } +func TestErrorOnNil(t *testing.T) { + session := CreateSession(t) + defer session.Close() + if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil { + t.Fatal("create table:", err) + } + + const ( + stmt = "SELECT * FROM not_found_table WRONG" + golden = "expected a pointer but got " + ) + + t.Run("get", func(t *testing.T) { + err := gocqlx.Iter(session.Query(stmt)).Get(nil) + if err == nil || err.Error() != golden { + t.Fatalf("Get()=%q expected %q error", err, golden) + } + }) + t.Run("select", func(t *testing.T) { + err := gocqlx.Iter(session.Query(stmt)).Select(nil) + if err == nil || err.Error() != golden { + t.Fatalf("Select()=%q expected %q error", err, golden) + } + }) + t.Run("struct scan", func(t *testing.T) { + i := gocqlx.Iter(session.Query(stmt)) + i.StructScan(nil) + err := i.Close() + if err == nil || err.Error() != golden { + t.Fatalf("StructScan()=%q expected %q error", err, golden) + } + }) +} + func TestPaging(t *testing.T) { session := CreateSession(t) defer session.Close() diff --git a/queryx.go b/queryx.go index 297fb87..025a675 100644 --- a/queryx.go +++ b/queryx.go @@ -180,6 +180,13 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err return arglist, nil } +// Bind sets query arguments of query. This can also be used to rebind new query arguments +// to an existing query instance. +func (q *Queryx) Bind(v ...interface{}) *Queryx { + q.Query.Bind(udtWrapSlice(q.Mapper, DefaultUnsafe, v)...) + return q +} + // Err returns any binding errors. func (q *Queryx) Err() error { return q.err diff --git a/queryx_wrap.go b/queryx_wrap.go index c7649f2..e63bb4a 100644 --- a/queryx_wrap.go +++ b/queryx_wrap.go @@ -117,13 +117,6 @@ func (q *Queryx) Idempotent(value bool) *Queryx { return q } -// Bind sets query arguments of query. This can also be used to rebind new query arguments -// to an existing query instance. -func (q *Queryx) Bind(v ...interface{}) *Queryx { - q.Query.Bind(v...) - return q -} - // SerialConsistency sets the consistency level for the // serial phase of conditional updates. That consistency can only be // either SERIAL or LOCAL_SERIAL and if not present, it defaults to diff --git a/udt.go b/udt.go new file mode 100644 index 0000000..e17ac06 --- /dev/null +++ b/udt.go @@ -0,0 +1,74 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + +package gocqlx + +import ( + "fmt" + "reflect" + + "github.com/gocql/gocql" + "github.com/scylladb/go-reflectx" +) + +// UDT is a marker interface that needs to be embedded in a struct if you want +// to marshal or unmarshal it as a User Defined Type. +type UDT interface { + udt() +} + +var ( + _ gocql.UDTMarshaler = udt{} + _ gocql.UDTUnmarshaler = udt{} +) + +type udt struct { + value reflect.Value + field map[string]reflect.Value + unsafe bool +} + +func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { + return udt{ + value: value, + field: mapper.FieldMap(value), + unsafe: unsafe, + } +} + +func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { + value, ok := u.field[name] + if !ok { + return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) + } + + return gocql.Marshal(info, value.Interface()) +} + +func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { + value, ok := u.field[name] + if !ok && !u.unsafe { + return fmt.Errorf("missing name %q in %s", name, u.value.Type()) + } + + return gocql.Unmarshal(info, data, value.Addr().Interface()) +} + +// udtWrapValue adds UDT wrapper if needed. +func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} { + if value.Type().Implements(autoUDTInterface) { + return makeUDT(value, mapper, unsafe) + } + return value.Interface() +} + +// udtWrapSlice adds UDT wrapper if needed. +func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} { + for i := range v { + if _, ok := v[i].(UDT); ok { + v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe) + } + } + return v +}