Automated UDT support

This patch adds the power of GocqlX to UDTs.
Now you can make a struct be UDT compatible by adding a single line.

```
type FullName struct {
	gocqlx.UDT
	FirstName string
	LastName  string
}
```

Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
Michał Matczuk
2020-04-15 17:06:29 +02:00
committed by Michal Jan Matczuk
parent 2569c3dd8f
commit ab279e68ed
8 changed files with 284 additions and 86 deletions

View File

@@ -8,10 +8,11 @@ Package `gocqlx` is an idiomatic extension to `gocql` that provides usability fe
## Features
* Binding query parameters form struct or map
* Scanning results directly into struct or slice
* Binding query parameters form struct
* Scanning results into struct or slice
* Automated UDT support
* CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
* CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb))
* Super simple CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
* Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate))
* Fast!

27
doc_test.go Normal file
View File

@@ -0,0 +1,27 @@
package gocqlx_test
import (
"github.com/scylladb/gocqlx"
)
func ExampleUDT() {
// Just add gocqlx.UDT to a type, no need to implement marshalling functions
type FullName struct {
gocqlx.UDT
FirstName string
LastName string
}
}
func ExampleUDT_wraper() {
type FullName struct {
FirstName string
LastName string
}
// Create new UDT wrapper type
type FullNameUDT struct {
gocqlx.UDT
*FullName
}
}

View File

@@ -28,6 +28,10 @@ func structOnlyError(t reflect.Type) error {
return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name())
}
if isAutoUDT := reflect.PtrTo(t).Implements(autoUDTInterface); isAutoUDT {
return fmt.Errorf("expected a struct but the provided struct type %s implements gocqlx.UDT", t.Name())
}
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
}
@@ -36,6 +40,7 @@ func structOnlyError(t reflect.Type) error {
var (
unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem()
autoUDTInterface = reflect.TypeOf((*UDT)(nil)).Elem()
)
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
@@ -46,32 +51,6 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
return t, nil
}
// fieldsByName fills a values interface with fields from the passed value based
// on the traversals in int. If ptrs is true, return addresses instead of values.
// We write this instead of using FieldsByName to save allocations and map lookups
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
values[i] = f.Interface()
}
}
return nil
}
func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals {
if len(t) == 0 {

123
iterx.go
View File

@@ -24,10 +24,9 @@ type Iterx struct {
unsafe bool
structOnly bool
started bool
err error
// Cache memory for a rows during iteration in StructScan.
// Cache memory for a rows during iteration in structScan.
fields [][]int
values []interface{}
}
@@ -77,24 +76,9 @@ func (iter *Iterx) Get(dest interface{}) error {
return iter.checkErrAndNotFound()
}
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. t is scannable if:
// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler
// * it is not a struct
// * it has no exported fields
func (iter *Iterx) isScannable(t reflect.Type) bool {
if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) {
return true
}
if t.Kind() != reflect.Struct {
return true
}
return len(iter.Mapper.TypeMap(t).Index) == 0
}
func (iter *Iterx) scanAny(dest interface{}) bool {
value := reflect.ValueOf(dest)
if value.Kind() != reflect.Ptr {
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
return false
@@ -122,10 +106,10 @@ func (iter *Iterx) scanAny(dest interface{}) bool {
}
if scannable {
return iter.Scan(dest)
return iter.scan(value)
}
return iter.StructScan(dest)
return iter.structScan(value)
}
// Select scans all rows into a destination, which must be a pointer to slice
@@ -199,9 +183,9 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
// scan into the struct field pointers
if !scannable {
ok = iter.StructScan(vp.Interface())
ok = iter.structScan(vp)
} else {
ok = iter.Scan(vp.Interface())
ok = iter.scan(vp)
}
if !ok {
break
@@ -228,6 +212,34 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
return true
}
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. t is scannable if:
// * ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT
// * it is not a struct
// * it has no exported fields
func (iter *Iterx) isScannable(t reflect.Type) bool {
ptr := reflect.PtrTo(t)
switch {
case ptr.Implements(unmarshallerInterface):
return true
case ptr.Implements(udtUnmarshallerInterface):
return true
case ptr.Implements(autoUDTInterface):
return true
case t.Kind() != reflect.Struct:
return true
default:
return len(iter.Mapper.TypeMap(t).Index) == 0
}
}
func (iter *Iterx) scan(value reflect.Value) bool {
if value.Kind() != reflect.Ptr {
panic("value must be a pointer")
}
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe))
}
// StructScan is like gocql.Iter.Scan, but scans a single row into a single
// struct. Use this and iterate manually when the memory load of Select() might
// be prohibitive. StructScan caches the reflect work of matching up column
@@ -235,37 +247,70 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
// safe to run StructScan on the same Iterx instance with different struct
// types.
func (iter *Iterx) StructScan(dest interface{}) bool {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
value := reflect.ValueOf(dest)
if value.Kind() != reflect.Ptr {
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
return false
}
if value.IsNil() {
iter.err = errors.New("expected a pointer but got nil")
return false
}
if !iter.started {
columns := columnNames(iter.Iter.Columns())
m := iter.Mapper
return iter.structScan(value)
}
func (iter *Iterx) structScan(value reflect.Value) bool {
if value.Kind() != reflect.Ptr {
panic("value must be a pointer")
}
if iter.fields == nil {
columns := columnNames(iter.Iter.Columns())
iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
iter.fields = m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if !iter.unsafe {
if f, err := missingFields(iter.fields); err != nil {
iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest)
iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
return false
}
}
iter.values = make([]interface{}, len(columns))
iter.started = true
}
err := fieldsByTraversal(v, iter.fields, iter.values, true)
if err != nil {
if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil {
iter.err = err
return false
}
// scan into the struct field pointers and append to our results
return iter.Iter.Scan(iter.values...)
}
// fieldsByName fills a values interface with fields from the passed value based
// on the traversals in int.
// We write this instead of using FieldsByName to save allocations and map
// lookups when iterating over many rows.
// Empty traversals will get an interface pointer.
func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error {
value = reflect.Indirect(value)
if value.Kind() != reflect.Struct {
return fmt.Errorf("expected a struct but got %s", value.Type())
}
for i, traversal := range traversals {
if len(traversal) == 0 {
continue
}
f := reflectx.FieldByIndexes(value, traversal).Addr()
values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe)
}
return nil
}
func columnNames(ci []gocql.ColumnInfo) []string {
r := make([]string, len(ci))
for i, column := range ci {
@@ -274,6 +319,18 @@ func columnNames(ci []gocql.ColumnInfo) []string {
return r
}
// Scan consumes the next row of the iterator and copies the columns of the
// current row into the values pointed at by dest. Use nil as a dest value
// to skip the corresponding column. Scan might send additional queries
// to the database to retrieve the next set of rows if paging was enabled.
//
// Scan returns true if the row was successfully unmarshaled or false if the
// end of the result set was reached or if an error occurred. Close should
// be called afterwards to retrieve any potential errors.
func (iter *Iterx) Scan(dest ...interface{}) bool {
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...)
}
// Close closes the iterator and returns any errors that happened during
// the query or the iteration.
func (iter *Iterx) Close() error {

View File

@@ -36,18 +36,13 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
}
type FullNameUDT struct {
FirstName string
LastName string
gocqlx.UDT
FullName
}
func (n FullNameUDT) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
return gocql.Marshal(info, f.Interface())
}
func (n *FullNameUDT) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
return gocql.Unmarshal(info, data, f.Addr().Interface())
type FullNamePtrUDT struct {
gocqlx.UDT
*FullName
}
func TestStruct(t *testing.T) {
@@ -75,7 +70,8 @@ func TestStruct(t *testing.T) {
testvarint varint,
testinet inet,
testcustom text,
testudt gocqlx_test.FullName
testudt gocqlx_test.FullName,
testptrudt gocqlx_test.FullName
)`); err != nil {
t.Fatal("create table:", err)
}
@@ -98,6 +94,7 @@ func TestStruct(t *testing.T) {
Testinet string
Testcustom FullName
Testudt FullNameUDT
Testptrudt FullNamePtrUDT
}
bigInt := new(big.Int)
@@ -122,10 +119,11 @@ func TestStruct(t *testing.T) {
Testvarint: bigInt,
Testinet: "213.212.2.19",
Testcustom: FullName{FirstName: "John", LastName: "Doe"},
Testudt: FullNameUDT{FirstName: "John", LastName: "Doe"},
Testudt: FullNameUDT{FullName: FullName{FirstName: "John", LastName: "Doe"}},
Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}},
}
if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind(
m.Testuuid,
m.Testtimestamp,
m.Testvarchar,
@@ -142,7 +140,8 @@ func TestStruct(t *testing.T) {
m.Testvarint,
m.Testinet,
m.Testcustom,
m.Testudt).Exec(); err != nil {
m.Testudt,
m.Testptrudt).Exec(); err != nil {
t.Fatal("insert:", err)
}
@@ -186,6 +185,28 @@ func TestStruct(t *testing.T) {
t.Fatal("not equals")
}
})
t.Run("struct scan", func(t *testing.T) {
var (
v StructTable
n int
)
i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter()
for i.StructScan(&v) {
n++
}
if err := i.Close(); err != nil {
t.Fatal("struct scan failed", err)
}
if n != 1 {
t.Fatal("struct scan unexpected number of rows", n)
}
if !reflect.DeepEqual(m, v) {
t.Fatal("not equals")
}
})
}
func TestScannable(t *testing.T) {
@@ -320,7 +341,12 @@ func TestStructOnlyUDT(t *testing.T) {
t.Fatal("create table:", err)
}
m := FullNameUDT{"John", "Doe"}
m := FullNameUDT{
FullName: FullName{
FirstName: "John",
LastName: "Doe",
},
}
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
t.Fatal("insert:", err)
@@ -401,7 +427,7 @@ func TestUnsafe(t *testing.T) {
t.Run("safe get", func(t *testing.T) {
var v UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err)
}
})
@@ -409,7 +435,7 @@ func TestUnsafe(t *testing.T) {
t.Run("safe select", func(t *testing.T) {
var v []UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err)
}
if cap(v) > 0 {
@@ -510,6 +536,40 @@ func TestNotFound(t *testing.T) {
})
}
func TestErrorOnNil(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
const (
stmt = "SELECT * FROM not_found_table WRONG"
golden = "expected a pointer but got <nil>"
)
t.Run("get", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Get(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Get()=%q expected %q error", err, golden)
}
})
t.Run("select", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Select(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Select()=%q expected %q error", err, golden)
}
})
t.Run("struct scan", func(t *testing.T) {
i := gocqlx.Iter(session.Query(stmt))
i.StructScan(nil)
err := i.Close()
if err == nil || err.Error() != golden {
t.Fatalf("StructScan()=%q expected %q error", err, golden)
}
})
}
func TestPaging(t *testing.T) {
session := CreateSession(t)
defer session.Close()

View File

@@ -180,6 +180,13 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err
return arglist, nil
}
// Bind sets query arguments of query. This can also be used to rebind new query arguments
// to an existing query instance.
func (q *Queryx) Bind(v ...interface{}) *Queryx {
q.Query.Bind(udtWrapSlice(q.Mapper, DefaultUnsafe, v)...)
return q
}
// Err returns any binding errors.
func (q *Queryx) Err() error {
return q.err

View File

@@ -117,13 +117,6 @@ func (q *Queryx) Idempotent(value bool) *Queryx {
return q
}
// Bind sets query arguments of query. This can also be used to rebind new query arguments
// to an existing query instance.
func (q *Queryx) Bind(v ...interface{}) *Queryx {
q.Query.Bind(v...)
return q
}
// SerialConsistency sets the consistency level for the
// serial phase of conditional updates. That consistency can only be
// either SERIAL or LOCAL_SERIAL and if not present, it defaults to

74
udt.go Normal file
View File

@@ -0,0 +1,74 @@
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.
package gocqlx
import (
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/go-reflectx"
)
// UDT is a marker interface that needs to be embedded in a struct if you want
// to marshal or unmarshal it as a User Defined Type.
type UDT interface {
udt()
}
var (
_ gocql.UDTMarshaler = udt{}
_ gocql.UDTUnmarshaler = udt{}
)
type udt struct {
value reflect.Value
field map[string]reflect.Value
unsafe bool
}
func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {
return udt{
value: value,
field: mapper.FieldMap(value),
unsafe: unsafe,
}
}
func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
value, ok := u.field[name]
if !ok {
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
return gocql.Marshal(info, value.Interface())
}
func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
value, ok := u.field[name]
if !ok && !u.unsafe {
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
}
return gocql.Unmarshal(info, data, value.Addr().Interface())
}
// udtWrapValue adds UDT wrapper if needed.
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} {
if value.Type().Implements(autoUDTInterface) {
return makeUDT(value, mapper, unsafe)
}
return value.Interface()
}
// udtWrapSlice adds UDT wrapper if needed.
func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} {
for i := range v {
if _, ok := v[i].(UDT); ok {
v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe)
}
}
return v
}