Automated UDT support

This patch adds the power of GocqlX to UDTs.
Now you can make a struct be UDT compatible by adding a single line.

```
type FullName struct {
	gocqlx.UDT
	FirstName string
	LastName  string
}
```

Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
Michał Matczuk
2020-04-15 17:06:29 +02:00
committed by Michal Jan Matczuk
parent 2569c3dd8f
commit ab279e68ed
8 changed files with 284 additions and 86 deletions

View File

@@ -8,10 +8,11 @@ Package `gocqlx` is an idiomatic extension to `gocql` that provides usability fe
## Features ## Features
* Binding query parameters form struct or map * Binding query parameters form struct
* Scanning results directly into struct or slice * Scanning results into struct or slice
* Automated UDT support
* CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
* CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb)) * CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb))
* Super simple CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
* Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate)) * Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate))
* Fast! * Fast!

27
doc_test.go Normal file
View File

@@ -0,0 +1,27 @@
package gocqlx_test
import (
"github.com/scylladb/gocqlx"
)
func ExampleUDT() {
// Just add gocqlx.UDT to a type, no need to implement marshalling functions
type FullName struct {
gocqlx.UDT
FirstName string
LastName string
}
}
func ExampleUDT_wraper() {
type FullName struct {
FirstName string
LastName string
}
// Create new UDT wrapper type
type FullNameUDT struct {
gocqlx.UDT
*FullName
}
}

View File

@@ -28,6 +28,10 @@ func structOnlyError(t reflect.Type) error {
return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name()) return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name())
} }
if isAutoUDT := reflect.PtrTo(t).Implements(autoUDTInterface); isAutoUDT {
return fmt.Errorf("expected a struct but the provided struct type %s implements gocqlx.UDT", t.Name())
}
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
} }
@@ -36,6 +40,7 @@ func structOnlyError(t reflect.Type) error {
var ( var (
unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem() unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem() udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem()
autoUDTInterface = reflect.TypeOf((*UDT)(nil)).Elem()
) )
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
@@ -46,32 +51,6 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
return t, nil return t, nil
} }
// fieldsByName fills a values interface with fields from the passed value based
// on the traversals in int. If ptrs is true, return addresses instead of values.
// We write this instead of using FieldsByName to save allocations and map lookups
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
values[i] = f.Interface()
}
}
return nil
}
func missingFields(transversals [][]int) (field int, err error) { func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals { for i, t := range transversals {
if len(t) == 0 { if len(t) == 0 {

123
iterx.go
View File

@@ -24,10 +24,9 @@ type Iterx struct {
unsafe bool unsafe bool
structOnly bool structOnly bool
started bool
err error err error
// Cache memory for a rows during iteration in StructScan. // Cache memory for a rows during iteration in structScan.
fields [][]int fields [][]int
values []interface{} values []interface{}
} }
@@ -77,24 +76,9 @@ func (iter *Iterx) Get(dest interface{}) error {
return iter.checkErrAndNotFound() return iter.checkErrAndNotFound()
} }
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. t is scannable if:
// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler
// * it is not a struct
// * it has no exported fields
func (iter *Iterx) isScannable(t reflect.Type) bool {
if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) {
return true
}
if t.Kind() != reflect.Struct {
return true
}
return len(iter.Mapper.TypeMap(t).Index) == 0
}
func (iter *Iterx) scanAny(dest interface{}) bool { func (iter *Iterx) scanAny(dest interface{}) bool {
value := reflect.ValueOf(dest) value := reflect.ValueOf(dest)
if value.Kind() != reflect.Ptr { if value.Kind() != reflect.Ptr {
iter.err = fmt.Errorf("expected a pointer but got %T", dest) iter.err = fmt.Errorf("expected a pointer but got %T", dest)
return false return false
@@ -122,10 +106,10 @@ func (iter *Iterx) scanAny(dest interface{}) bool {
} }
if scannable { if scannable {
return iter.Scan(dest) return iter.scan(value)
} }
return iter.StructScan(dest) return iter.structScan(value)
} }
// Select scans all rows into a destination, which must be a pointer to slice // Select scans all rows into a destination, which must be a pointer to slice
@@ -199,9 +183,9 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
// scan into the struct field pointers // scan into the struct field pointers
if !scannable { if !scannable {
ok = iter.StructScan(vp.Interface()) ok = iter.structScan(vp)
} else { } else {
ok = iter.Scan(vp.Interface()) ok = iter.scan(vp)
} }
if !ok { if !ok {
break break
@@ -228,6 +212,34 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
return true return true
} }
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. t is scannable if:
// * ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT
// * it is not a struct
// * it has no exported fields
func (iter *Iterx) isScannable(t reflect.Type) bool {
ptr := reflect.PtrTo(t)
switch {
case ptr.Implements(unmarshallerInterface):
return true
case ptr.Implements(udtUnmarshallerInterface):
return true
case ptr.Implements(autoUDTInterface):
return true
case t.Kind() != reflect.Struct:
return true
default:
return len(iter.Mapper.TypeMap(t).Index) == 0
}
}
func (iter *Iterx) scan(value reflect.Value) bool {
if value.Kind() != reflect.Ptr {
panic("value must be a pointer")
}
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe))
}
// StructScan is like gocql.Iter.Scan, but scans a single row into a single // StructScan is like gocql.Iter.Scan, but scans a single row into a single
// struct. Use this and iterate manually when the memory load of Select() might // struct. Use this and iterate manually when the memory load of Select() might
// be prohibitive. StructScan caches the reflect work of matching up column // be prohibitive. StructScan caches the reflect work of matching up column
@@ -235,37 +247,70 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
// safe to run StructScan on the same Iterx instance with different struct // safe to run StructScan on the same Iterx instance with different struct
// types. // types.
func (iter *Iterx) StructScan(dest interface{}) bool { func (iter *Iterx) StructScan(dest interface{}) bool {
v := reflect.ValueOf(dest) value := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination") if value.Kind() != reflect.Ptr {
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
return false
}
if value.IsNil() {
iter.err = errors.New("expected a pointer but got nil")
return false return false
} }
if !iter.started { return iter.structScan(value)
columns := columnNames(iter.Iter.Columns()) }
m := iter.Mapper
func (iter *Iterx) structScan(value reflect.Value) bool {
if value.Kind() != reflect.Ptr {
panic("value must be a pointer")
}
if iter.fields == nil {
columns := columnNames(iter.Iter.Columns())
iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
iter.fields = m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error // if we are not unsafe and are missing fields, return an error
if !iter.unsafe { if !iter.unsafe {
if f, err := missingFields(iter.fields); err != nil { if f, err := missingFields(iter.fields); err != nil {
iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest) iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
return false return false
} }
} }
iter.values = make([]interface{}, len(columns)) iter.values = make([]interface{}, len(columns))
iter.started = true
} }
err := fieldsByTraversal(v, iter.fields, iter.values, true) if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil {
if err != nil {
iter.err = err iter.err = err
return false return false
} }
// scan into the struct field pointers and append to our results // scan into the struct field pointers and append to our results
return iter.Iter.Scan(iter.values...) return iter.Iter.Scan(iter.values...)
} }
// fieldsByName fills a values interface with fields from the passed value based
// on the traversals in int.
// We write this instead of using FieldsByName to save allocations and map
// lookups when iterating over many rows.
// Empty traversals will get an interface pointer.
func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error {
value = reflect.Indirect(value)
if value.Kind() != reflect.Struct {
return fmt.Errorf("expected a struct but got %s", value.Type())
}
for i, traversal := range traversals {
if len(traversal) == 0 {
continue
}
f := reflectx.FieldByIndexes(value, traversal).Addr()
values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe)
}
return nil
}
func columnNames(ci []gocql.ColumnInfo) []string { func columnNames(ci []gocql.ColumnInfo) []string {
r := make([]string, len(ci)) r := make([]string, len(ci))
for i, column := range ci { for i, column := range ci {
@@ -274,6 +319,18 @@ func columnNames(ci []gocql.ColumnInfo) []string {
return r return r
} }
// Scan consumes the next row of the iterator and copies the columns of the
// current row into the values pointed at by dest. Use nil as a dest value
// to skip the corresponding column. Scan might send additional queries
// to the database to retrieve the next set of rows if paging was enabled.
//
// Scan returns true if the row was successfully unmarshaled or false if the
// end of the result set was reached or if an error occurred. Close should
// be called afterwards to retrieve any potential errors.
func (iter *Iterx) Scan(dest ...interface{}) bool {
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...)
}
// Close closes the iterator and returns any errors that happened during // Close closes the iterator and returns any errors that happened during
// the query or the iteration. // the query or the iteration.
func (iter *Iterx) Close() error { func (iter *Iterx) Close() error {

View File

@@ -36,18 +36,13 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
} }
type FullNameUDT struct { type FullNameUDT struct {
FirstName string gocqlx.UDT
LastName string FullName
} }
func (n FullNameUDT) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { type FullNamePtrUDT struct {
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name) gocqlx.UDT
return gocql.Marshal(info, f.Interface()) *FullName
}
func (n *FullNameUDT) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
return gocql.Unmarshal(info, data, f.Addr().Interface())
} }
func TestStruct(t *testing.T) { func TestStruct(t *testing.T) {
@@ -75,7 +70,8 @@ func TestStruct(t *testing.T) {
testvarint varint, testvarint varint,
testinet inet, testinet inet,
testcustom text, testcustom text,
testudt gocqlx_test.FullName testudt gocqlx_test.FullName,
testptrudt gocqlx_test.FullName
)`); err != nil { )`); err != nil {
t.Fatal("create table:", err) t.Fatal("create table:", err)
} }
@@ -98,6 +94,7 @@ func TestStruct(t *testing.T) {
Testinet string Testinet string
Testcustom FullName Testcustom FullName
Testudt FullNameUDT Testudt FullNameUDT
Testptrudt FullNamePtrUDT
} }
bigInt := new(big.Int) bigInt := new(big.Int)
@@ -122,10 +119,11 @@ func TestStruct(t *testing.T) {
Testvarint: bigInt, Testvarint: bigInt,
Testinet: "213.212.2.19", Testinet: "213.212.2.19",
Testcustom: FullName{FirstName: "John", LastName: "Doe"}, Testcustom: FullName{FirstName: "John", LastName: "Doe"},
Testudt: FullNameUDT{FirstName: "John", LastName: "Doe"}, Testudt: FullNameUDT{FullName: FullName{FirstName: "John", LastName: "Doe"}},
Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}},
} }
if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind(
m.Testuuid, m.Testuuid,
m.Testtimestamp, m.Testtimestamp,
m.Testvarchar, m.Testvarchar,
@@ -142,7 +140,8 @@ func TestStruct(t *testing.T) {
m.Testvarint, m.Testvarint,
m.Testinet, m.Testinet,
m.Testcustom, m.Testcustom,
m.Testudt).Exec(); err != nil { m.Testudt,
m.Testptrudt).Exec(); err != nil {
t.Fatal("insert:", err) t.Fatal("insert:", err)
} }
@@ -186,6 +185,28 @@ func TestStruct(t *testing.T) {
t.Fatal("not equals") t.Fatal("not equals")
} }
}) })
t.Run("struct scan", func(t *testing.T) {
var (
v StructTable
n int
)
i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter()
for i.StructScan(&v) {
n++
}
if err := i.Close(); err != nil {
t.Fatal("struct scan failed", err)
}
if n != 1 {
t.Fatal("struct scan unexpected number of rows", n)
}
if !reflect.DeepEqual(m, v) {
t.Fatal("not equals")
}
})
} }
func TestScannable(t *testing.T) { func TestScannable(t *testing.T) {
@@ -320,7 +341,12 @@ func TestStructOnlyUDT(t *testing.T) {
t.Fatal("create table:", err) t.Fatal("create table:", err)
} }
m := FullNameUDT{"John", "Doe"} m := FullNameUDT{
FullName: FullName{
FirstName: "John",
LastName: "Doe",
},
}
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil { if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
t.Fatal("insert:", err) t.Fatal("insert:", err)
@@ -401,7 +427,7 @@ func TestUnsafe(t *testing.T) {
t.Run("safe get", func(t *testing.T) { t.Run("safe get", func(t *testing.T) {
var v UnsafeTable var v UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" { if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err) t.Fatal("expected ErrNotFound", "got", err)
} }
}) })
@@ -409,7 +435,7 @@ func TestUnsafe(t *testing.T) {
t.Run("safe select", func(t *testing.T) { t.Run("safe select", func(t *testing.T) {
var v []UnsafeTable var v []UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" { if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err) t.Fatal("expected ErrNotFound", "got", err)
} }
if cap(v) > 0 { if cap(v) > 0 {
@@ -510,6 +536,40 @@ func TestNotFound(t *testing.T) {
}) })
} }
func TestErrorOnNil(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
const (
stmt = "SELECT * FROM not_found_table WRONG"
golden = "expected a pointer but got <nil>"
)
t.Run("get", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Get(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Get()=%q expected %q error", err, golden)
}
})
t.Run("select", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Select(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Select()=%q expected %q error", err, golden)
}
})
t.Run("struct scan", func(t *testing.T) {
i := gocqlx.Iter(session.Query(stmt))
i.StructScan(nil)
err := i.Close()
if err == nil || err.Error() != golden {
t.Fatalf("StructScan()=%q expected %q error", err, golden)
}
})
}
func TestPaging(t *testing.T) { func TestPaging(t *testing.T) {
session := CreateSession(t) session := CreateSession(t)
defer session.Close() defer session.Close()

View File

@@ -180,6 +180,13 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err
return arglist, nil return arglist, nil
} }
// Bind sets query arguments of query. This can also be used to rebind new query arguments
// to an existing query instance.
func (q *Queryx) Bind(v ...interface{}) *Queryx {
q.Query.Bind(udtWrapSlice(q.Mapper, DefaultUnsafe, v)...)
return q
}
// Err returns any binding errors. // Err returns any binding errors.
func (q *Queryx) Err() error { func (q *Queryx) Err() error {
return q.err return q.err

View File

@@ -117,13 +117,6 @@ func (q *Queryx) Idempotent(value bool) *Queryx {
return q return q
} }
// Bind sets query arguments of query. This can also be used to rebind new query arguments
// to an existing query instance.
func (q *Queryx) Bind(v ...interface{}) *Queryx {
q.Query.Bind(v...)
return q
}
// SerialConsistency sets the consistency level for the // SerialConsistency sets the consistency level for the
// serial phase of conditional updates. That consistency can only be // serial phase of conditional updates. That consistency can only be
// either SERIAL or LOCAL_SERIAL and if not present, it defaults to // either SERIAL or LOCAL_SERIAL and if not present, it defaults to

74
udt.go Normal file
View File

@@ -0,0 +1,74 @@
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.
package gocqlx
import (
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/go-reflectx"
)
// UDT is a marker interface that needs to be embedded in a struct if you want
// to marshal or unmarshal it as a User Defined Type.
type UDT interface {
udt()
}
var (
_ gocql.UDTMarshaler = udt{}
_ gocql.UDTUnmarshaler = udt{}
)
type udt struct {
value reflect.Value
field map[string]reflect.Value
unsafe bool
}
func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {
return udt{
value: value,
field: mapper.FieldMap(value),
unsafe: unsafe,
}
}
func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
value, ok := u.field[name]
if !ok {
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
return gocql.Marshal(info, value.Interface())
}
func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
value, ok := u.field[name]
if !ok && !u.unsafe {
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
return gocql.Unmarshal(info, data, value.Addr().Interface())
}
// udtWrapValue adds UDT wrapper if needed.
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} {
if value.Type().Implements(autoUDTInterface) {
return makeUDT(value, mapper, unsafe)
}
return value.Interface()
}
// udtWrapSlice adds UDT wrapper if needed.
func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} {
for i := range v {
if _, ok := v[i].(UDT); ok {
v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe)
}
}
return v
}