Automated UDT support
This patch adds the power of GocqlX to UDTs.
Now you can make a struct be UDT compatible by adding a single line.
```
type FullName struct {
gocqlx.UDT
FirstName string
LastName string
}
```
Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
committed by
Michal Jan Matczuk
parent
2569c3dd8f
commit
ab279e68ed
@@ -8,10 +8,11 @@ Package `gocqlx` is an idiomatic extension to `gocql` that provides usability fe
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* Binding query parameters form struct or map
|
* Binding query parameters form struct
|
||||||
* Scanning results directly into struct or slice
|
* Scanning results into struct or slice
|
||||||
|
* Automated UDT support
|
||||||
|
* CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
|
||||||
* CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb))
|
* CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb))
|
||||||
* Super simple CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
|
|
||||||
* Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate))
|
* Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate))
|
||||||
* Fast!
|
* Fast!
|
||||||
|
|
||||||
|
|||||||
27
doc_test.go
Normal file
27
doc_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package gocqlx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/scylladb/gocqlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleUDT() {
|
||||||
|
// Just add gocqlx.UDT to a type, no need to implement marshalling functions
|
||||||
|
type FullName struct {
|
||||||
|
gocqlx.UDT
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleUDT_wraper() {
|
||||||
|
type FullName struct {
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new UDT wrapper type
|
||||||
|
type FullNameUDT struct {
|
||||||
|
gocqlx.UDT
|
||||||
|
*FullName
|
||||||
|
}
|
||||||
|
}
|
||||||
31
gocqlx.go
31
gocqlx.go
@@ -28,6 +28,10 @@ func structOnlyError(t reflect.Type) error {
|
|||||||
return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name())
|
return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isAutoUDT := reflect.PtrTo(t).Implements(autoUDTInterface); isAutoUDT {
|
||||||
|
return fmt.Errorf("expected a struct but the provided struct type %s implements gocqlx.UDT", t.Name())
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
|
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,6 +40,7 @@ func structOnlyError(t reflect.Type) error {
|
|||||||
var (
|
var (
|
||||||
unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
|
unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
|
||||||
udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem()
|
udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem()
|
||||||
|
autoUDTInterface = reflect.TypeOf((*UDT)(nil)).Elem()
|
||||||
)
|
)
|
||||||
|
|
||||||
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
|
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
|
||||||
@@ -46,32 +51,6 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fieldsByName fills a values interface with fields from the passed value based
|
|
||||||
// on the traversals in int. If ptrs is true, return addresses instead of values.
|
|
||||||
// We write this instead of using FieldsByName to save allocations and map lookups
|
|
||||||
// when iterating over many rows. Empty traversals will get an interface pointer.
|
|
||||||
// Because of the necessity of requesting ptrs or values, it's considered a bit too
|
|
||||||
// specialized for inclusion in reflectx itself.
|
|
||||||
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
|
|
||||||
v = reflect.Indirect(v)
|
|
||||||
if v.Kind() != reflect.Struct {
|
|
||||||
return errors.New("argument not a struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, traversal := range traversals {
|
|
||||||
if len(traversal) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
f := reflectx.FieldByIndexes(v, traversal)
|
|
||||||
if ptrs {
|
|
||||||
values[i] = f.Addr().Interface()
|
|
||||||
} else {
|
|
||||||
values[i] = f.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func missingFields(transversals [][]int) (field int, err error) {
|
func missingFields(transversals [][]int) (field int, err error) {
|
||||||
for i, t := range transversals {
|
for i, t := range transversals {
|
||||||
if len(t) == 0 {
|
if len(t) == 0 {
|
||||||
|
|||||||
123
iterx.go
123
iterx.go
@@ -24,10 +24,9 @@ type Iterx struct {
|
|||||||
|
|
||||||
unsafe bool
|
unsafe bool
|
||||||
structOnly bool
|
structOnly bool
|
||||||
started bool
|
|
||||||
err error
|
err error
|
||||||
|
|
||||||
// Cache memory for a rows during iteration in StructScan.
|
// Cache memory for a rows during iteration in structScan.
|
||||||
fields [][]int
|
fields [][]int
|
||||||
values []interface{}
|
values []interface{}
|
||||||
}
|
}
|
||||||
@@ -77,24 +76,9 @@ func (iter *Iterx) Get(dest interface{}) error {
|
|||||||
return iter.checkErrAndNotFound()
|
return iter.checkErrAndNotFound()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isScannable takes the reflect.Type and the actual dest value and returns
|
|
||||||
// whether or not it's Scannable. t is scannable if:
|
|
||||||
// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler
|
|
||||||
// * it is not a struct
|
|
||||||
// * it has no exported fields
|
|
||||||
func (iter *Iterx) isScannable(t reflect.Type) bool {
|
|
||||||
if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if t.Kind() != reflect.Struct {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(iter.Mapper.TypeMap(t).Index) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iterx) scanAny(dest interface{}) bool {
|
func (iter *Iterx) scanAny(dest interface{}) bool {
|
||||||
value := reflect.ValueOf(dest)
|
value := reflect.ValueOf(dest)
|
||||||
|
|
||||||
if value.Kind() != reflect.Ptr {
|
if value.Kind() != reflect.Ptr {
|
||||||
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
||||||
return false
|
return false
|
||||||
@@ -122,10 +106,10 @@ func (iter *Iterx) scanAny(dest interface{}) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if scannable {
|
if scannable {
|
||||||
return iter.Scan(dest)
|
return iter.scan(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return iter.StructScan(dest)
|
return iter.structScan(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select scans all rows into a destination, which must be a pointer to slice
|
// Select scans all rows into a destination, which must be a pointer to slice
|
||||||
@@ -199,9 +183,9 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
|||||||
|
|
||||||
// scan into the struct field pointers
|
// scan into the struct field pointers
|
||||||
if !scannable {
|
if !scannable {
|
||||||
ok = iter.StructScan(vp.Interface())
|
ok = iter.structScan(vp)
|
||||||
} else {
|
} else {
|
||||||
ok = iter.Scan(vp.Interface())
|
ok = iter.scan(vp)
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
break
|
break
|
||||||
@@ -228,6 +212,34 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isScannable takes the reflect.Type and the actual dest value and returns
|
||||||
|
// whether or not it's Scannable. t is scannable if:
|
||||||
|
// * ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT
|
||||||
|
// * it is not a struct
|
||||||
|
// * it has no exported fields
|
||||||
|
func (iter *Iterx) isScannable(t reflect.Type) bool {
|
||||||
|
ptr := reflect.PtrTo(t)
|
||||||
|
switch {
|
||||||
|
case ptr.Implements(unmarshallerInterface):
|
||||||
|
return true
|
||||||
|
case ptr.Implements(udtUnmarshallerInterface):
|
||||||
|
return true
|
||||||
|
case ptr.Implements(autoUDTInterface):
|
||||||
|
return true
|
||||||
|
case t.Kind() != reflect.Struct:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return len(iter.Mapper.TypeMap(t).Index) == 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (iter *Iterx) scan(value reflect.Value) bool {
|
||||||
|
if value.Kind() != reflect.Ptr {
|
||||||
|
panic("value must be a pointer")
|
||||||
|
}
|
||||||
|
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe))
|
||||||
|
}
|
||||||
|
|
||||||
// StructScan is like gocql.Iter.Scan, but scans a single row into a single
|
// StructScan is like gocql.Iter.Scan, but scans a single row into a single
|
||||||
// struct. Use this and iterate manually when the memory load of Select() might
|
// struct. Use this and iterate manually when the memory load of Select() might
|
||||||
// be prohibitive. StructScan caches the reflect work of matching up column
|
// be prohibitive. StructScan caches the reflect work of matching up column
|
||||||
@@ -235,37 +247,70 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
|||||||
// safe to run StructScan on the same Iterx instance with different struct
|
// safe to run StructScan on the same Iterx instance with different struct
|
||||||
// types.
|
// types.
|
||||||
func (iter *Iterx) StructScan(dest interface{}) bool {
|
func (iter *Iterx) StructScan(dest interface{}) bool {
|
||||||
v := reflect.ValueOf(dest)
|
value := reflect.ValueOf(dest)
|
||||||
if v.Kind() != reflect.Ptr {
|
|
||||||
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
|
if value.Kind() != reflect.Ptr {
|
||||||
|
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if value.IsNil() {
|
||||||
|
iter.err = errors.New("expected a pointer but got nil")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !iter.started {
|
return iter.structScan(value)
|
||||||
columns := columnNames(iter.Iter.Columns())
|
}
|
||||||
m := iter.Mapper
|
|
||||||
|
func (iter *Iterx) structScan(value reflect.Value) bool {
|
||||||
|
if value.Kind() != reflect.Ptr {
|
||||||
|
panic("value must be a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
if iter.fields == nil {
|
||||||
|
columns := columnNames(iter.Iter.Columns())
|
||||||
|
iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
|
||||||
|
|
||||||
iter.fields = m.TraversalsByName(v.Type(), columns)
|
|
||||||
// if we are not unsafe and are missing fields, return an error
|
// if we are not unsafe and are missing fields, return an error
|
||||||
if !iter.unsafe {
|
if !iter.unsafe {
|
||||||
if f, err := missingFields(iter.fields); err != nil {
|
if f, err := missingFields(iter.fields); err != nil {
|
||||||
iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest)
|
iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
iter.values = make([]interface{}, len(columns))
|
iter.values = make([]interface{}, len(columns))
|
||||||
iter.started = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := fieldsByTraversal(v, iter.fields, iter.values, true)
|
if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil {
|
||||||
if err != nil {
|
|
||||||
iter.err = err
|
iter.err = err
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// scan into the struct field pointers and append to our results
|
// scan into the struct field pointers and append to our results
|
||||||
return iter.Iter.Scan(iter.values...)
|
return iter.Iter.Scan(iter.values...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fieldsByName fills a values interface with fields from the passed value based
|
||||||
|
// on the traversals in int.
|
||||||
|
// We write this instead of using FieldsByName to save allocations and map
|
||||||
|
// lookups when iterating over many rows.
|
||||||
|
// Empty traversals will get an interface pointer.
|
||||||
|
func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error {
|
||||||
|
value = reflect.Indirect(value)
|
||||||
|
if value.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("expected a struct but got %s", value.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, traversal := range traversals {
|
||||||
|
if len(traversal) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
f := reflectx.FieldByIndexes(value, traversal).Addr()
|
||||||
|
values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func columnNames(ci []gocql.ColumnInfo) []string {
|
func columnNames(ci []gocql.ColumnInfo) []string {
|
||||||
r := make([]string, len(ci))
|
r := make([]string, len(ci))
|
||||||
for i, column := range ci {
|
for i, column := range ci {
|
||||||
@@ -274,6 +319,18 @@ func columnNames(ci []gocql.ColumnInfo) []string {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scan consumes the next row of the iterator and copies the columns of the
|
||||||
|
// current row into the values pointed at by dest. Use nil as a dest value
|
||||||
|
// to skip the corresponding column. Scan might send additional queries
|
||||||
|
// to the database to retrieve the next set of rows if paging was enabled.
|
||||||
|
//
|
||||||
|
// Scan returns true if the row was successfully unmarshaled or false if the
|
||||||
|
// end of the result set was reached or if an error occurred. Close should
|
||||||
|
// be called afterwards to retrieve any potential errors.
|
||||||
|
func (iter *Iterx) Scan(dest ...interface{}) bool {
|
||||||
|
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...)
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the iterator and returns any errors that happened during
|
// Close closes the iterator and returns any errors that happened during
|
||||||
// the query or the iteration.
|
// the query or the iteration.
|
||||||
func (iter *Iterx) Close() error {
|
func (iter *Iterx) Close() error {
|
||||||
|
|||||||
@@ -36,18 +36,13 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type FullNameUDT struct {
|
type FullNameUDT struct {
|
||||||
FirstName string
|
gocqlx.UDT
|
||||||
LastName string
|
FullName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n FullNameUDT) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
|
type FullNamePtrUDT struct {
|
||||||
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
|
gocqlx.UDT
|
||||||
return gocql.Marshal(info, f.Interface())
|
*FullName
|
||||||
}
|
|
||||||
|
|
||||||
func (n *FullNameUDT) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
|
|
||||||
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
|
|
||||||
return gocql.Unmarshal(info, data, f.Addr().Interface())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStruct(t *testing.T) {
|
func TestStruct(t *testing.T) {
|
||||||
@@ -75,7 +70,8 @@ func TestStruct(t *testing.T) {
|
|||||||
testvarint varint,
|
testvarint varint,
|
||||||
testinet inet,
|
testinet inet,
|
||||||
testcustom text,
|
testcustom text,
|
||||||
testudt gocqlx_test.FullName
|
testudt gocqlx_test.FullName,
|
||||||
|
testptrudt gocqlx_test.FullName
|
||||||
)`); err != nil {
|
)`); err != nil {
|
||||||
t.Fatal("create table:", err)
|
t.Fatal("create table:", err)
|
||||||
}
|
}
|
||||||
@@ -98,6 +94,7 @@ func TestStruct(t *testing.T) {
|
|||||||
Testinet string
|
Testinet string
|
||||||
Testcustom FullName
|
Testcustom FullName
|
||||||
Testudt FullNameUDT
|
Testudt FullNameUDT
|
||||||
|
Testptrudt FullNamePtrUDT
|
||||||
}
|
}
|
||||||
|
|
||||||
bigInt := new(big.Int)
|
bigInt := new(big.Int)
|
||||||
@@ -122,10 +119,11 @@ func TestStruct(t *testing.T) {
|
|||||||
Testvarint: bigInt,
|
Testvarint: bigInt,
|
||||||
Testinet: "213.212.2.19",
|
Testinet: "213.212.2.19",
|
||||||
Testcustom: FullName{FirstName: "John", LastName: "Doe"},
|
Testcustom: FullName{FirstName: "John", LastName: "Doe"},
|
||||||
Testudt: FullNameUDT{FirstName: "John", LastName: "Doe"},
|
Testudt: FullNameUDT{FullName: FullName{FirstName: "John", LastName: "Doe"}},
|
||||||
|
Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind(
|
||||||
m.Testuuid,
|
m.Testuuid,
|
||||||
m.Testtimestamp,
|
m.Testtimestamp,
|
||||||
m.Testvarchar,
|
m.Testvarchar,
|
||||||
@@ -142,7 +140,8 @@ func TestStruct(t *testing.T) {
|
|||||||
m.Testvarint,
|
m.Testvarint,
|
||||||
m.Testinet,
|
m.Testinet,
|
||||||
m.Testcustom,
|
m.Testcustom,
|
||||||
m.Testudt).Exec(); err != nil {
|
m.Testudt,
|
||||||
|
m.Testptrudt).Exec(); err != nil {
|
||||||
t.Fatal("insert:", err)
|
t.Fatal("insert:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,6 +185,28 @@ func TestStruct(t *testing.T) {
|
|||||||
t.Fatal("not equals")
|
t.Fatal("not equals")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("struct scan", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
v StructTable
|
||||||
|
n int
|
||||||
|
)
|
||||||
|
|
||||||
|
i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter()
|
||||||
|
for i.StructScan(&v) {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if err := i.Close(); err != nil {
|
||||||
|
t.Fatal("struct scan failed", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatal("struct scan unexpected number of rows", n)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(m, v) {
|
||||||
|
t.Fatal("not equals")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScannable(t *testing.T) {
|
func TestScannable(t *testing.T) {
|
||||||
@@ -320,7 +341,12 @@ func TestStructOnlyUDT(t *testing.T) {
|
|||||||
t.Fatal("create table:", err)
|
t.Fatal("create table:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := FullNameUDT{"John", "Doe"}
|
m := FullNameUDT{
|
||||||
|
FullName: FullName{
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
|
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
|
||||||
t.Fatal("insert:", err)
|
t.Fatal("insert:", err)
|
||||||
@@ -401,7 +427,7 @@ func TestUnsafe(t *testing.T) {
|
|||||||
t.Run("safe get", func(t *testing.T) {
|
t.Run("safe get", func(t *testing.T) {
|
||||||
var v UnsafeTable
|
var v UnsafeTable
|
||||||
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
|
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
|
||||||
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
|
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
|
||||||
t.Fatal("expected ErrNotFound", "got", err)
|
t.Fatal("expected ErrNotFound", "got", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -409,7 +435,7 @@ func TestUnsafe(t *testing.T) {
|
|||||||
t.Run("safe select", func(t *testing.T) {
|
t.Run("safe select", func(t *testing.T) {
|
||||||
var v []UnsafeTable
|
var v []UnsafeTable
|
||||||
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
|
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
|
||||||
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
|
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
|
||||||
t.Fatal("expected ErrNotFound", "got", err)
|
t.Fatal("expected ErrNotFound", "got", err)
|
||||||
}
|
}
|
||||||
if cap(v) > 0 {
|
if cap(v) > 0 {
|
||||||
@@ -510,6 +536,40 @@ func TestNotFound(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestErrorOnNil(t *testing.T) {
|
||||||
|
session := CreateSession(t)
|
||||||
|
defer session.Close()
|
||||||
|
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
|
||||||
|
t.Fatal("create table:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
stmt = "SELECT * FROM not_found_table WRONG"
|
||||||
|
golden = "expected a pointer but got <nil>"
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Run("get", func(t *testing.T) {
|
||||||
|
err := gocqlx.Iter(session.Query(stmt)).Get(nil)
|
||||||
|
if err == nil || err.Error() != golden {
|
||||||
|
t.Fatalf("Get()=%q expected %q error", err, golden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("select", func(t *testing.T) {
|
||||||
|
err := gocqlx.Iter(session.Query(stmt)).Select(nil)
|
||||||
|
if err == nil || err.Error() != golden {
|
||||||
|
t.Fatalf("Select()=%q expected %q error", err, golden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("struct scan", func(t *testing.T) {
|
||||||
|
i := gocqlx.Iter(session.Query(stmt))
|
||||||
|
i.StructScan(nil)
|
||||||
|
err := i.Close()
|
||||||
|
if err == nil || err.Error() != golden {
|
||||||
|
t.Fatalf("StructScan()=%q expected %q error", err, golden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestPaging(t *testing.T) {
|
func TestPaging(t *testing.T) {
|
||||||
session := CreateSession(t)
|
session := CreateSession(t)
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|||||||
@@ -180,6 +180,13 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err
|
|||||||
return arglist, nil
|
return arglist, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bind sets query arguments of query. This can also be used to rebind new query arguments
|
||||||
|
// to an existing query instance.
|
||||||
|
func (q *Queryx) Bind(v ...interface{}) *Queryx {
|
||||||
|
q.Query.Bind(udtWrapSlice(q.Mapper, DefaultUnsafe, v)...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
// Err returns any binding errors.
|
// Err returns any binding errors.
|
||||||
func (q *Queryx) Err() error {
|
func (q *Queryx) Err() error {
|
||||||
return q.err
|
return q.err
|
||||||
|
|||||||
@@ -117,13 +117,6 @@ func (q *Queryx) Idempotent(value bool) *Queryx {
|
|||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind sets query arguments of query. This can also be used to rebind new query arguments
|
|
||||||
// to an existing query instance.
|
|
||||||
func (q *Queryx) Bind(v ...interface{}) *Queryx {
|
|
||||||
q.Query.Bind(v...)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// SerialConsistency sets the consistency level for the
|
// SerialConsistency sets the consistency level for the
|
||||||
// serial phase of conditional updates. That consistency can only be
|
// serial phase of conditional updates. That consistency can only be
|
||||||
// either SERIAL or LOCAL_SERIAL and if not present, it defaults to
|
// either SERIAL or LOCAL_SERIAL and if not present, it defaults to
|
||||||
|
|||||||
74
udt.go
Normal file
74
udt.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
// Copyright (C) 2017 ScyllaDB
|
||||||
|
// Use of this source code is governed by a ALv2-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocqlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql"
|
||||||
|
"github.com/scylladb/go-reflectx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDT is a marker interface that needs to be embedded in a struct if you want
|
||||||
|
// to marshal or unmarshal it as a User Defined Type.
|
||||||
|
type UDT interface {
|
||||||
|
udt()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ gocql.UDTMarshaler = udt{}
|
||||||
|
_ gocql.UDTUnmarshaler = udt{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type udt struct {
|
||||||
|
value reflect.Value
|
||||||
|
field map[string]reflect.Value
|
||||||
|
unsafe bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {
|
||||||
|
return udt{
|
||||||
|
value: value,
|
||||||
|
field: mapper.FieldMap(value),
|
||||||
|
unsafe: unsafe,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
|
||||||
|
value, ok := u.field[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
return gocql.Marshal(info, value.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
|
||||||
|
value, ok := u.field[name]
|
||||||
|
if !ok && !u.unsafe {
|
||||||
|
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
return gocql.Unmarshal(info, data, value.Addr().Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
// udtWrapValue adds UDT wrapper if needed.
|
||||||
|
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} {
|
||||||
|
if value.Type().Implements(autoUDTInterface) {
|
||||||
|
return makeUDT(value, mapper, unsafe)
|
||||||
|
}
|
||||||
|
return value.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// udtWrapSlice adds UDT wrapper if needed.
|
||||||
|
func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} {
|
||||||
|
for i := range v {
|
||||||
|
if _, ok := v[i].(UDT); ok {
|
||||||
|
v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user