Automated UDT support
This patch adds the power of GocqlX to UDTs.
Now you can make a struct be UDT compatible by adding a single line.
```
type FullName struct {
gocqlx.UDT
FirstName string
LastName string
}
```
Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
committed by
Michal Jan Matczuk
parent
2569c3dd8f
commit
ab279e68ed
@@ -8,10 +8,11 @@ Package `gocqlx` is an idiomatic extension to `gocql` that provides usability fe
|
||||
|
||||
## Features
|
||||
|
||||
* Binding query parameters form struct or map
|
||||
* Scanning results directly into struct or slice
|
||||
* Binding query parameters form struct
|
||||
* Scanning results into struct or slice
|
||||
* Automated UDT support
|
||||
* CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
|
||||
* CQL query builder ([package qb](https://github.com/scylladb/gocqlx/blob/master/qb))
|
||||
* Super simple CRUD operations based on table model ([package table](https://github.com/scylladb/gocqlx/blob/master/table))
|
||||
* Database migrations ([package migrate](https://github.com/scylladb/gocqlx/blob/master/migrate))
|
||||
* Fast!
|
||||
|
||||
|
||||
27
doc_test.go
Normal file
27
doc_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package gocqlx_test
|
||||
|
||||
import (
|
||||
"github.com/scylladb/gocqlx"
|
||||
)
|
||||
|
||||
func ExampleUDT() {
|
||||
// Just add gocqlx.UDT to a type, no need to implement marshalling functions
|
||||
type FullName struct {
|
||||
gocqlx.UDT
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleUDT_wraper() {
|
||||
type FullName struct {
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
|
||||
// Create new UDT wrapper type
|
||||
type FullNameUDT struct {
|
||||
gocqlx.UDT
|
||||
*FullName
|
||||
}
|
||||
}
|
||||
31
gocqlx.go
31
gocqlx.go
@@ -28,6 +28,10 @@ func structOnlyError(t reflect.Type) error {
|
||||
return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", t.Name())
|
||||
}
|
||||
|
||||
if isAutoUDT := reflect.PtrTo(t).Implements(autoUDTInterface); isAutoUDT {
|
||||
return fmt.Errorf("expected a struct but the provided struct type %s implements gocqlx.UDT", t.Name())
|
||||
}
|
||||
|
||||
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
|
||||
}
|
||||
|
||||
@@ -36,6 +40,7 @@ func structOnlyError(t reflect.Type) error {
|
||||
var (
|
||||
unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
|
||||
udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem()
|
||||
autoUDTInterface = reflect.TypeOf((*UDT)(nil)).Elem()
|
||||
)
|
||||
|
||||
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
|
||||
@@ -46,32 +51,6 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// fieldsByName fills a values interface with fields from the passed value based
|
||||
// on the traversals in int. If ptrs is true, return addresses instead of values.
|
||||
// We write this instead of using FieldsByName to save allocations and map lookups
|
||||
// when iterating over many rows. Empty traversals will get an interface pointer.
|
||||
// Because of the necessity of requesting ptrs or values, it's considered a bit too
|
||||
// specialized for inclusion in reflectx itself.
|
||||
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Kind() != reflect.Struct {
|
||||
return errors.New("argument not a struct")
|
||||
}
|
||||
|
||||
for i, traversal := range traversals {
|
||||
if len(traversal) == 0 {
|
||||
continue
|
||||
}
|
||||
f := reflectx.FieldByIndexes(v, traversal)
|
||||
if ptrs {
|
||||
values[i] = f.Addr().Interface()
|
||||
} else {
|
||||
values[i] = f.Interface()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func missingFields(transversals [][]int) (field int, err error) {
|
||||
for i, t := range transversals {
|
||||
if len(t) == 0 {
|
||||
|
||||
123
iterx.go
123
iterx.go
@@ -24,10 +24,9 @@ type Iterx struct {
|
||||
|
||||
unsafe bool
|
||||
structOnly bool
|
||||
started bool
|
||||
err error
|
||||
|
||||
// Cache memory for a rows during iteration in StructScan.
|
||||
// Cache memory for a rows during iteration in structScan.
|
||||
fields [][]int
|
||||
values []interface{}
|
||||
}
|
||||
@@ -77,24 +76,9 @@ func (iter *Iterx) Get(dest interface{}) error {
|
||||
return iter.checkErrAndNotFound()
|
||||
}
|
||||
|
||||
// isScannable takes the reflect.Type and the actual dest value and returns
|
||||
// whether or not it's Scannable. t is scannable if:
|
||||
// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler
|
||||
// * it is not a struct
|
||||
// * it has no exported fields
|
||||
func (iter *Iterx) isScannable(t reflect.Type) bool {
|
||||
if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) {
|
||||
return true
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return true
|
||||
}
|
||||
|
||||
return len(iter.Mapper.TypeMap(t).Index) == 0
|
||||
}
|
||||
|
||||
func (iter *Iterx) scanAny(dest interface{}) bool {
|
||||
value := reflect.ValueOf(dest)
|
||||
|
||||
if value.Kind() != reflect.Ptr {
|
||||
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
||||
return false
|
||||
@@ -122,10 +106,10 @@ func (iter *Iterx) scanAny(dest interface{}) bool {
|
||||
}
|
||||
|
||||
if scannable {
|
||||
return iter.Scan(dest)
|
||||
return iter.scan(value)
|
||||
}
|
||||
|
||||
return iter.StructScan(dest)
|
||||
return iter.structScan(value)
|
||||
}
|
||||
|
||||
// Select scans all rows into a destination, which must be a pointer to slice
|
||||
@@ -199,9 +183,9 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
||||
|
||||
// scan into the struct field pointers
|
||||
if !scannable {
|
||||
ok = iter.StructScan(vp.Interface())
|
||||
ok = iter.structScan(vp)
|
||||
} else {
|
||||
ok = iter.Scan(vp.Interface())
|
||||
ok = iter.scan(vp)
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
@@ -228,6 +212,34 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// isScannable takes the reflect.Type and the actual dest value and returns
|
||||
// whether or not it's Scannable. t is scannable if:
|
||||
// * ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT
|
||||
// * it is not a struct
|
||||
// * it has no exported fields
|
||||
func (iter *Iterx) isScannable(t reflect.Type) bool {
|
||||
ptr := reflect.PtrTo(t)
|
||||
switch {
|
||||
case ptr.Implements(unmarshallerInterface):
|
||||
return true
|
||||
case ptr.Implements(udtUnmarshallerInterface):
|
||||
return true
|
||||
case ptr.Implements(autoUDTInterface):
|
||||
return true
|
||||
case t.Kind() != reflect.Struct:
|
||||
return true
|
||||
default:
|
||||
return len(iter.Mapper.TypeMap(t).Index) == 0
|
||||
}
|
||||
}
|
||||
|
||||
func (iter *Iterx) scan(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Ptr {
|
||||
panic("value must be a pointer")
|
||||
}
|
||||
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe))
|
||||
}
|
||||
|
||||
// StructScan is like gocql.Iter.Scan, but scans a single row into a single
|
||||
// struct. Use this and iterate manually when the memory load of Select() might
|
||||
// be prohibitive. StructScan caches the reflect work of matching up column
|
||||
@@ -235,37 +247,70 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
||||
// safe to run StructScan on the same Iterx instance with different struct
|
||||
// types.
|
||||
func (iter *Iterx) StructScan(dest interface{}) bool {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
|
||||
value := reflect.ValueOf(dest)
|
||||
|
||||
if value.Kind() != reflect.Ptr {
|
||||
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
||||
return false
|
||||
}
|
||||
if value.IsNil() {
|
||||
iter.err = errors.New("expected a pointer but got nil")
|
||||
return false
|
||||
}
|
||||
|
||||
if !iter.started {
|
||||
columns := columnNames(iter.Iter.Columns())
|
||||
m := iter.Mapper
|
||||
return iter.structScan(value)
|
||||
}
|
||||
|
||||
func (iter *Iterx) structScan(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Ptr {
|
||||
panic("value must be a pointer")
|
||||
}
|
||||
|
||||
if iter.fields == nil {
|
||||
columns := columnNames(iter.Iter.Columns())
|
||||
iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
|
||||
|
||||
iter.fields = m.TraversalsByName(v.Type(), columns)
|
||||
// if we are not unsafe and are missing fields, return an error
|
||||
if !iter.unsafe {
|
||||
if f, err := missingFields(iter.fields); err != nil {
|
||||
iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest)
|
||||
iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
|
||||
return false
|
||||
}
|
||||
}
|
||||
iter.values = make([]interface{}, len(columns))
|
||||
iter.started = true
|
||||
}
|
||||
|
||||
err := fieldsByTraversal(v, iter.fields, iter.values, true)
|
||||
if err != nil {
|
||||
if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil {
|
||||
iter.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
// scan into the struct field pointers and append to our results
|
||||
return iter.Iter.Scan(iter.values...)
|
||||
}
|
||||
|
||||
// fieldsByName fills a values interface with fields from the passed value based
|
||||
// on the traversals in int.
|
||||
// We write this instead of using FieldsByName to save allocations and map
|
||||
// lookups when iterating over many rows.
|
||||
// Empty traversals will get an interface pointer.
|
||||
func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error {
|
||||
value = reflect.Indirect(value)
|
||||
if value.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("expected a struct but got %s", value.Type())
|
||||
}
|
||||
|
||||
for i, traversal := range traversals {
|
||||
if len(traversal) == 0 {
|
||||
continue
|
||||
}
|
||||
f := reflectx.FieldByIndexes(value, traversal).Addr()
|
||||
values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func columnNames(ci []gocql.ColumnInfo) []string {
|
||||
r := make([]string, len(ci))
|
||||
for i, column := range ci {
|
||||
@@ -274,6 +319,18 @@ func columnNames(ci []gocql.ColumnInfo) []string {
|
||||
return r
|
||||
}
|
||||
|
||||
// Scan consumes the next row of the iterator and copies the columns of the
|
||||
// current row into the values pointed at by dest. Use nil as a dest value
|
||||
// to skip the corresponding column. Scan might send additional queries
|
||||
// to the database to retrieve the next set of rows if paging was enabled.
|
||||
//
|
||||
// Scan returns true if the row was successfully unmarshaled or false if the
|
||||
// end of the result set was reached or if an error occurred. Close should
|
||||
// be called afterwards to retrieve any potential errors.
|
||||
func (iter *Iterx) Scan(dest ...interface{}) bool {
|
||||
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...)
|
||||
}
|
||||
|
||||
// Close closes the iterator and returns any errors that happened during
|
||||
// the query or the iteration.
|
||||
func (iter *Iterx) Close() error {
|
||||
|
||||
@@ -36,18 +36,13 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
|
||||
}
|
||||
|
||||
type FullNameUDT struct {
|
||||
FirstName string
|
||||
LastName string
|
||||
gocqlx.UDT
|
||||
FullName
|
||||
}
|
||||
|
||||
func (n FullNameUDT) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
|
||||
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
|
||||
return gocql.Marshal(info, f.Interface())
|
||||
}
|
||||
|
||||
func (n *FullNameUDT) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
|
||||
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
|
||||
return gocql.Unmarshal(info, data, f.Addr().Interface())
|
||||
type FullNamePtrUDT struct {
|
||||
gocqlx.UDT
|
||||
*FullName
|
||||
}
|
||||
|
||||
func TestStruct(t *testing.T) {
|
||||
@@ -75,7 +70,8 @@ func TestStruct(t *testing.T) {
|
||||
testvarint varint,
|
||||
testinet inet,
|
||||
testcustom text,
|
||||
testudt gocqlx_test.FullName
|
||||
testudt gocqlx_test.FullName,
|
||||
testptrudt gocqlx_test.FullName
|
||||
)`); err != nil {
|
||||
t.Fatal("create table:", err)
|
||||
}
|
||||
@@ -98,6 +94,7 @@ func TestStruct(t *testing.T) {
|
||||
Testinet string
|
||||
Testcustom FullName
|
||||
Testudt FullNameUDT
|
||||
Testptrudt FullNamePtrUDT
|
||||
}
|
||||
|
||||
bigInt := new(big.Int)
|
||||
@@ -122,10 +119,11 @@ func TestStruct(t *testing.T) {
|
||||
Testvarint: bigInt,
|
||||
Testinet: "213.212.2.19",
|
||||
Testcustom: FullName{FirstName: "John", LastName: "Doe"},
|
||||
Testudt: FullNameUDT{FirstName: "John", LastName: "Doe"},
|
||||
Testudt: FullNameUDT{FullName: FullName{FirstName: "John", LastName: "Doe"}},
|
||||
Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}},
|
||||
}
|
||||
|
||||
if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind(
|
||||
m.Testuuid,
|
||||
m.Testtimestamp,
|
||||
m.Testvarchar,
|
||||
@@ -142,7 +140,8 @@ func TestStruct(t *testing.T) {
|
||||
m.Testvarint,
|
||||
m.Testinet,
|
||||
m.Testcustom,
|
||||
m.Testudt).Exec(); err != nil {
|
||||
m.Testudt,
|
||||
m.Testptrudt).Exec(); err != nil {
|
||||
t.Fatal("insert:", err)
|
||||
}
|
||||
|
||||
@@ -186,6 +185,28 @@ func TestStruct(t *testing.T) {
|
||||
t.Fatal("not equals")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("struct scan", func(t *testing.T) {
|
||||
var (
|
||||
v StructTable
|
||||
n int
|
||||
)
|
||||
|
||||
i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter()
|
||||
for i.StructScan(&v) {
|
||||
n++
|
||||
}
|
||||
if err := i.Close(); err != nil {
|
||||
t.Fatal("struct scan failed", err)
|
||||
}
|
||||
|
||||
if n != 1 {
|
||||
t.Fatal("struct scan unexpected number of rows", n)
|
||||
}
|
||||
if !reflect.DeepEqual(m, v) {
|
||||
t.Fatal("not equals")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestScannable(t *testing.T) {
|
||||
@@ -320,7 +341,12 @@ func TestStructOnlyUDT(t *testing.T) {
|
||||
t.Fatal("create table:", err)
|
||||
}
|
||||
|
||||
m := FullNameUDT{"John", "Doe"}
|
||||
m := FullNameUDT{
|
||||
FullName: FullName{
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
},
|
||||
}
|
||||
|
||||
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
|
||||
t.Fatal("insert:", err)
|
||||
@@ -401,7 +427,7 @@ func TestUnsafe(t *testing.T) {
|
||||
t.Run("safe get", func(t *testing.T) {
|
||||
var v UnsafeTable
|
||||
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
|
||||
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
|
||||
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
|
||||
t.Fatal("expected ErrNotFound", "got", err)
|
||||
}
|
||||
})
|
||||
@@ -409,7 +435,7 @@ func TestUnsafe(t *testing.T) {
|
||||
t.Run("safe select", func(t *testing.T) {
|
||||
var v []UnsafeTable
|
||||
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
|
||||
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
|
||||
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
|
||||
t.Fatal("expected ErrNotFound", "got", err)
|
||||
}
|
||||
if cap(v) > 0 {
|
||||
@@ -510,6 +536,40 @@ func TestNotFound(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorOnNil(t *testing.T) {
|
||||
session := CreateSession(t)
|
||||
defer session.Close()
|
||||
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
|
||||
t.Fatal("create table:", err)
|
||||
}
|
||||
|
||||
const (
|
||||
stmt = "SELECT * FROM not_found_table WRONG"
|
||||
golden = "expected a pointer but got <nil>"
|
||||
)
|
||||
|
||||
t.Run("get", func(t *testing.T) {
|
||||
err := gocqlx.Iter(session.Query(stmt)).Get(nil)
|
||||
if err == nil || err.Error() != golden {
|
||||
t.Fatalf("Get()=%q expected %q error", err, golden)
|
||||
}
|
||||
})
|
||||
t.Run("select", func(t *testing.T) {
|
||||
err := gocqlx.Iter(session.Query(stmt)).Select(nil)
|
||||
if err == nil || err.Error() != golden {
|
||||
t.Fatalf("Select()=%q expected %q error", err, golden)
|
||||
}
|
||||
})
|
||||
t.Run("struct scan", func(t *testing.T) {
|
||||
i := gocqlx.Iter(session.Query(stmt))
|
||||
i.StructScan(nil)
|
||||
err := i.Close()
|
||||
if err == nil || err.Error() != golden {
|
||||
t.Fatalf("StructScan()=%q expected %q error", err, golden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPaging(t *testing.T) {
|
||||
session := CreateSession(t)
|
||||
defer session.Close()
|
||||
|
||||
@@ -180,6 +180,13 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err
|
||||
return arglist, nil
|
||||
}
|
||||
|
||||
// Bind sets query arguments of query. This can also be used to rebind new query arguments
|
||||
// to an existing query instance.
|
||||
func (q *Queryx) Bind(v ...interface{}) *Queryx {
|
||||
q.Query.Bind(udtWrapSlice(q.Mapper, DefaultUnsafe, v)...)
|
||||
return q
|
||||
}
|
||||
|
||||
// Err returns any binding errors.
|
||||
func (q *Queryx) Err() error {
|
||||
return q.err
|
||||
|
||||
@@ -117,13 +117,6 @@ func (q *Queryx) Idempotent(value bool) *Queryx {
|
||||
return q
|
||||
}
|
||||
|
||||
// Bind sets query arguments of query. This can also be used to rebind new query arguments
|
||||
// to an existing query instance.
|
||||
func (q *Queryx) Bind(v ...interface{}) *Queryx {
|
||||
q.Query.Bind(v...)
|
||||
return q
|
||||
}
|
||||
|
||||
// SerialConsistency sets the consistency level for the
|
||||
// serial phase of conditional updates. That consistency can only be
|
||||
// either SERIAL or LOCAL_SERIAL and if not present, it defaults to
|
||||
|
||||
74
udt.go
Normal file
74
udt.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright (C) 2017 ScyllaDB
|
||||
// Use of this source code is governed by a ALv2-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gocqlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/scylladb/go-reflectx"
|
||||
)
|
||||
|
||||
// UDT is a marker interface that needs to be embedded in a struct if you want
|
||||
// to marshal or unmarshal it as a User Defined Type.
|
||||
type UDT interface {
|
||||
udt()
|
||||
}
|
||||
|
||||
var (
|
||||
_ gocql.UDTMarshaler = udt{}
|
||||
_ gocql.UDTUnmarshaler = udt{}
|
||||
)
|
||||
|
||||
type udt struct {
|
||||
value reflect.Value
|
||||
field map[string]reflect.Value
|
||||
unsafe bool
|
||||
}
|
||||
|
||||
func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {
|
||||
return udt{
|
||||
value: value,
|
||||
field: mapper.FieldMap(value),
|
||||
unsafe: unsafe,
|
||||
}
|
||||
}
|
||||
|
||||
func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
|
||||
value, ok := u.field[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
|
||||
}
|
||||
|
||||
return gocql.Marshal(info, value.Interface())
|
||||
}
|
||||
|
||||
func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
|
||||
value, ok := u.field[name]
|
||||
if !ok && !u.unsafe {
|
||||
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
|
||||
}
|
||||
|
||||
return gocql.Unmarshal(info, data, value.Addr().Interface())
|
||||
}
|
||||
|
||||
// udtWrapValue adds UDT wrapper if needed.
|
||||
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} {
|
||||
if value.Type().Implements(autoUDTInterface) {
|
||||
return makeUDT(value, mapper, unsafe)
|
||||
}
|
||||
return value.Interface()
|
||||
}
|
||||
|
||||
// udtWrapSlice adds UDT wrapper if needed.
|
||||
func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} {
|
||||
for i := range v {
|
||||
if _, ok := v[i].(UDT); ok {
|
||||
v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe)
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
Reference in New Issue
Block a user