Automated UDT support
This patch adds the power of GocqlX to UDTs.
Now you can make a struct be UDT compatible by adding a single line.
```
type FullName struct {
gocqlx.UDT
FirstName string
LastName string
}
```
Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
committed by
Michal Jan Matczuk
parent
2569c3dd8f
commit
ab279e68ed
123
iterx.go
123
iterx.go
@@ -24,10 +24,9 @@ type Iterx struct {
|
||||
|
||||
unsafe bool
|
||||
structOnly bool
|
||||
started bool
|
||||
err error
|
||||
|
||||
// Cache memory for a rows during iteration in StructScan.
|
||||
// Cache memory for a rows during iteration in structScan.
|
||||
fields [][]int
|
||||
values []interface{}
|
||||
}
|
||||
@@ -77,24 +76,9 @@ func (iter *Iterx) Get(dest interface{}) error {
|
||||
return iter.checkErrAndNotFound()
|
||||
}
|
||||
|
||||
// isScannable takes the reflect.Type and the actual dest value and returns
|
||||
// whether or not it's Scannable. t is scannable if:
|
||||
// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler
|
||||
// * it is not a struct
|
||||
// * it has no exported fields
|
||||
func (iter *Iterx) isScannable(t reflect.Type) bool {
|
||||
if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) {
|
||||
return true
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return true
|
||||
}
|
||||
|
||||
return len(iter.Mapper.TypeMap(t).Index) == 0
|
||||
}
|
||||
|
||||
func (iter *Iterx) scanAny(dest interface{}) bool {
|
||||
value := reflect.ValueOf(dest)
|
||||
|
||||
if value.Kind() != reflect.Ptr {
|
||||
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
||||
return false
|
||||
@@ -122,10 +106,10 @@ func (iter *Iterx) scanAny(dest interface{}) bool {
|
||||
}
|
||||
|
||||
if scannable {
|
||||
return iter.Scan(dest)
|
||||
return iter.scan(value)
|
||||
}
|
||||
|
||||
return iter.StructScan(dest)
|
||||
return iter.structScan(value)
|
||||
}
|
||||
|
||||
// Select scans all rows into a destination, which must be a pointer to slice
|
||||
@@ -199,9 +183,9 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
||||
|
||||
// scan into the struct field pointers
|
||||
if !scannable {
|
||||
ok = iter.StructScan(vp.Interface())
|
||||
ok = iter.structScan(vp)
|
||||
} else {
|
||||
ok = iter.Scan(vp.Interface())
|
||||
ok = iter.scan(vp)
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
@@ -228,6 +212,34 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// isScannable takes the reflect.Type and the actual dest value and returns
|
||||
// whether or not it's Scannable. t is scannable if:
|
||||
// * ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT
|
||||
// * it is not a struct
|
||||
// * it has no exported fields
|
||||
func (iter *Iterx) isScannable(t reflect.Type) bool {
|
||||
ptr := reflect.PtrTo(t)
|
||||
switch {
|
||||
case ptr.Implements(unmarshallerInterface):
|
||||
return true
|
||||
case ptr.Implements(udtUnmarshallerInterface):
|
||||
return true
|
||||
case ptr.Implements(autoUDTInterface):
|
||||
return true
|
||||
case t.Kind() != reflect.Struct:
|
||||
return true
|
||||
default:
|
||||
return len(iter.Mapper.TypeMap(t).Index) == 0
|
||||
}
|
||||
}
|
||||
|
||||
func (iter *Iterx) scan(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Ptr {
|
||||
panic("value must be a pointer")
|
||||
}
|
||||
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe))
|
||||
}
|
||||
|
||||
// StructScan is like gocql.Iter.Scan, but scans a single row into a single
|
||||
// struct. Use this and iterate manually when the memory load of Select() might
|
||||
// be prohibitive. StructScan caches the reflect work of matching up column
|
||||
@@ -235,37 +247,70 @@ func (iter *Iterx) scanAll(dest interface{}) bool {
|
||||
// safe to run StructScan on the same Iterx instance with different struct
|
||||
// types.
|
||||
func (iter *Iterx) StructScan(dest interface{}) bool {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
|
||||
value := reflect.ValueOf(dest)
|
||||
|
||||
if value.Kind() != reflect.Ptr {
|
||||
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
|
||||
return false
|
||||
}
|
||||
if value.IsNil() {
|
||||
iter.err = errors.New("expected a pointer but got nil")
|
||||
return false
|
||||
}
|
||||
|
||||
if !iter.started {
|
||||
columns := columnNames(iter.Iter.Columns())
|
||||
m := iter.Mapper
|
||||
return iter.structScan(value)
|
||||
}
|
||||
|
||||
func (iter *Iterx) structScan(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Ptr {
|
||||
panic("value must be a pointer")
|
||||
}
|
||||
|
||||
if iter.fields == nil {
|
||||
columns := columnNames(iter.Iter.Columns())
|
||||
iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
|
||||
|
||||
iter.fields = m.TraversalsByName(v.Type(), columns)
|
||||
// if we are not unsafe and are missing fields, return an error
|
||||
if !iter.unsafe {
|
||||
if f, err := missingFields(iter.fields); err != nil {
|
||||
iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest)
|
||||
iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
|
||||
return false
|
||||
}
|
||||
}
|
||||
iter.values = make([]interface{}, len(columns))
|
||||
iter.started = true
|
||||
}
|
||||
|
||||
err := fieldsByTraversal(v, iter.fields, iter.values, true)
|
||||
if err != nil {
|
||||
if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil {
|
||||
iter.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
// scan into the struct field pointers and append to our results
|
||||
return iter.Iter.Scan(iter.values...)
|
||||
}
|
||||
|
||||
// fieldsByName fills a values interface with fields from the passed value based
|
||||
// on the traversals in int.
|
||||
// We write this instead of using FieldsByName to save allocations and map
|
||||
// lookups when iterating over many rows.
|
||||
// Empty traversals will get an interface pointer.
|
||||
func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error {
|
||||
value = reflect.Indirect(value)
|
||||
if value.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("expected a struct but got %s", value.Type())
|
||||
}
|
||||
|
||||
for i, traversal := range traversals {
|
||||
if len(traversal) == 0 {
|
||||
continue
|
||||
}
|
||||
f := reflectx.FieldByIndexes(value, traversal).Addr()
|
||||
values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func columnNames(ci []gocql.ColumnInfo) []string {
|
||||
r := make([]string, len(ci))
|
||||
for i, column := range ci {
|
||||
@@ -274,6 +319,18 @@ func columnNames(ci []gocql.ColumnInfo) []string {
|
||||
return r
|
||||
}
|
||||
|
||||
// Scan consumes the next row of the iterator and copies the columns of the
|
||||
// current row into the values pointed at by dest. Use nil as a dest value
|
||||
// to skip the corresponding column. Scan might send additional queries
|
||||
// to the database to retrieve the next set of rows if paging was enabled.
|
||||
//
|
||||
// Scan returns true if the row was successfully unmarshaled or false if the
|
||||
// end of the result set was reached or if an error occurred. Close should
|
||||
// be called afterwards to retrieve any potential errors.
|
||||
func (iter *Iterx) Scan(dest ...interface{}) bool {
|
||||
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...)
|
||||
}
|
||||
|
||||
// Close closes the iterator and returns any errors that happened during
|
||||
// the query or the iteration.
|
||||
func (iter *Iterx) Close() error {
|
||||
|
||||
Reference in New Issue
Block a user