iterx: Allow forcing scanning as struct

We have a structure type that implements UnmarshalCQL method.
We use it to unmarshal a user defined type. We also want to use the same struct for scanning an entire row.

There is StructScan method available in gocqlx for this purpose when iterating over rows, but no equivalent when doing a Select.
This commit introduces the possibility when doing select/get as well.

Co-authored-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
Martin Sucha
2019-11-06 15:33:37 +01:00
committed by Michal Jan Matczuk
parent a9ce16bfc6
commit a08a66ee85
4 changed files with 260 additions and 42 deletions

View File

@@ -22,9 +22,10 @@ type Iterx struct {
*gocql.Iter
Mapper *reflectx.Mapper
unsafe bool
started bool
err error
unsafe bool
structOnly bool
started bool
err error
// Cache memory for a rows during iteration in StructScan.
fields [][]int
@@ -48,14 +49,29 @@ func (iter *Iterx) Unsafe() *Iterx {
return iter
}
// Get scans first row into a destination and closes the iterator. If the
// destination type is a struct pointer, then StructScan will be used.
// StructOnly forces the iterator to treat a single-argument struct as
// non-scannable. This is is useful if you need to scan a row into a struct
// that also implements gocql.UDTUnmarshaler or in rare cases gocql.Unmarshaler.
func (iter *Iterx) StructOnly() *Iterx {
iter.structOnly = true
return iter
}
// Get scans first row into a destination and closes the iterator.
//
// If the destination type is a struct pointer, then StructScan will be
// used.
// If the destination is some other type, then the row must only have one column
// which can scan into that type.
// This includes types that implement gocql.Unmarshaler and gocql.UDTUnmarshaler.
//
// If you'd like to treat a type that implements gocql.Unmarshaler or
// gocql.UDTUnmarshaler as an ordinary struct you should call
// StructOnly().Get(dest) instead.
//
// If no rows were selected, ErrNotFound is returned.
func (iter *Iterx) Get(dest interface{}) error {
iter.scanAny(dest, false)
iter.scanAny(dest)
iter.Close()
return iter.checkErrAndNotFound()
@@ -63,11 +79,11 @@ func (iter *Iterx) Get(dest interface{}) error {
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. t is scannable if:
// * ptr to t implements gocql.Unmarshaler
// * ptr to t implements gocql.Unmarshaler or gocql.UDTUnmarshaler
// * it is not a struct
// * it has no exported fields
func (iter *Iterx) isScannable(t reflect.Type) bool {
if reflect.PtrTo(t).Implements(_unmarshallerInterface) {
if ptr := reflect.PtrTo(t); ptr.Implements(unmarshallerInterface) || ptr.Implements(udtUnmarshallerInterface) {
return true
}
if t.Kind() != reflect.Struct {
@@ -77,7 +93,7 @@ func (iter *Iterx) isScannable(t reflect.Type) bool {
return len(iter.Mapper.TypeMap(t).Index) == 0
}
func (iter *Iterx) scanAny(dest interface{}, structOnly bool) bool {
func (iter *Iterx) scanAny(dest interface{}) bool {
value := reflect.ValueOf(dest)
if value.Kind() != reflect.Ptr {
iter.err = fmt.Errorf("expected a pointer but got %T", dest)
@@ -91,13 +107,17 @@ func (iter *Iterx) scanAny(dest interface{}, structOnly bool) bool {
base := reflectx.Deref(value.Type())
scannable := iter.isScannable(base)
if structOnly && scannable {
iter.err = structOnlyError(base)
return false
if iter.structOnly && scannable {
if base.Kind() == reflect.Struct {
scannable = false
} else {
iter.err = structOnlyError(base)
return false
}
}
if scannable && len(iter.Columns()) > 1 {
iter.err = fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(iter.Columns()))
iter.err = fmt.Errorf("expected 1 column in result while scanning scannable type %s but got %d", base.Kind(), len(iter.Columns()))
return false
}
@@ -109,20 +129,27 @@ func (iter *Iterx) scanAny(dest interface{}, structOnly bool) bool {
}
// Select scans all rows into a destination, which must be a pointer to slice
// of any type and closes the iterator. If the destination slice type is
// a struct, then StructScan will be used on each row. If the destination is
// some other type, then each row must only have one column which can scan into
// that type.
// of any type, and closes the iterator.
//
// If the destination slice type is a struct, then StructScan will be used
// on each row.
// If the destination is some other type, then each row must only have one
// column which can scan into that type.
// This includes types that implement gocql.Unmarshaler and gocql.UDTUnmarshaler.
//
// If you'd like to treat a type that implements gocql.Unmarshaler or
// gocql.UDTUnmarshaler as an ordinary struct you should call
// StructOnly().Select(dest) instead.
//
// If no rows were selected, ErrNotFound is NOT returned.
func (iter *Iterx) Select(dest interface{}) error {
iter.scanAll(dest, false)
iter.scanAll(dest)
iter.Close()
return iter.err
}
func (iter *Iterx) scanAll(dest interface{}, structOnly bool) bool {
func (iter *Iterx) scanAll(dest interface{}) bool {
value := reflect.ValueOf(dest)
// json.Unmarshal returns errors for these
@@ -145,14 +172,18 @@ func (iter *Iterx) scanAll(dest interface{}, structOnly bool) bool {
base := reflectx.Deref(slice.Elem())
scannable := iter.isScannable(base)
if structOnly && scannable {
iter.err = structOnlyError(base)
return false
if iter.structOnly && scannable {
if base.Kind() == reflect.Struct {
scannable = false
} else {
iter.err = structOnlyError(base)
return false
}
}
// if it's a base type make sure it only has 1 column; if not return an error
if scannable && len(iter.Columns()) > 1 {
iter.err = fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(iter.Columns()))
iter.err = fmt.Errorf("expected 1 column in result while scanning scannable type %s but got %d", base.Kind(), len(iter.Columns()))
return false
}