// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.

package gocqlx

import (
	"errors"
	"fmt"
	"reflect"

	gocql "github.com/apache/cassandra-gocql-driver/v2"
	"github.com/scylladb/go-reflectx"
)

// DefaultStrict disables the behavior of forcing queries and iterators to ignore
// missing fields for all queries. See Strict below for more information.
var DefaultStrict bool

// Iterx is a wrapper around gocql.Iter which adds struct scanning capabilities.
type Iterx struct {
	// err holds the error reported by the scanning helpers or, when none was
	// recorded before, the error returned by the underlying iterator on Close.
	err error
	*gocql.Iter
	// Mapper resolves column names to struct fields.
	Mapper *reflectx.Mapper

	// Cached memory for rows during iteration in structScan.
	// fields holds one field traversal per result column; values holds the
	// scan destinations that are refilled for every row.
	fields [][]int
	values []interface{}

	// strict is set by Strict, structOnly is set by StructOnly.
	strict     bool
	structOnly bool
	// applied is the scan destination for the leading "[applied]" column when
	// structScan detects one.
	applied bool
}

// Strict forces the iterator to disable ignoring missing fields. In Strict mode
// when scanning a struct if result row has a column that cannot be mapped to any
// destination field an error is reported. By default such columns are ignored.
func (iter *Iterx) Strict() *Iterx {
	iter.strict = true
	return iter
}

// StructOnly forces the iterator to treat a single-argument struct as
// non-scannable. This is useful if you need to scan a row into a struct
// that also implements gocql.UDTUnmarshaler or in rare cases gocql.Unmarshaler.
func (iter *Iterx) StructOnly() *Iterx {
	iter.structOnly = true
	return iter
}

// Get scans first row into a destination and closes the iterator.
//
// If the destination type is a struct pointer, then StructScan will be
// used.
// If the destination is some other type, then the row must only have one column
// which can scan into that type.
// This includes types that implement gocql.Unmarshaler and gocql.UDTUnmarshaler.
//
// If you'd like to treat a type that implements gocql.Unmarshaler or
// gocql.UDTUnmarshaler as an ordinary struct you should call
// StructOnly().Get(dest) instead.
//
// If no rows were selected, ErrNotFound is returned.
func (iter *Iterx) Get(dest interface{}) error { iter.scanAny(dest) _ = iter.Close() return iter.checkErrAndNotFound() } func (iter *Iterx) scanAny(dest interface{}) bool { value := reflect.ValueOf(dest) if value.Kind() != reflect.Ptr { iter.err = fmt.Errorf("expected a pointer but got %T", dest) return false } if value.IsNil() { iter.err = errors.New("expected a pointer but got nil") return false } base := reflectx.Deref(value.Type()) scannable := iter.isScannable(base) if iter.structOnly && scannable { if base.Kind() == reflect.Struct { scannable = false } else { iter.err = structOnlyError(base) return false } } if scannable && len(iter.Columns()) > 1 { iter.err = fmt.Errorf("expected 1 column in result while scanning scannable type %s but got %d", base.Kind(), len(iter.Columns())) return false } if scannable { return iter.scan(value) } return iter.structScan(value) } // Select scans all rows into a destination, which must be a pointer to slice // of any type, and closes the iterator. // // If the destination slice type is a struct, then StructScan will be used // on each row. // If the destination is some other type, then each row must only have one // column which can scan into that type. // This includes types that implement gocql.Unmarshaler and gocql.UDTUnmarshaler. // // If you'd like to treat a type that implements gocql.Unmarshaler or // gocql.UDTUnmarshaler as an ordinary struct you should call // StructOnly().Select(dest) instead. // // If no rows were selected, ErrNotFound is NOT returned. 
func (iter *Iterx) Select(dest interface{}) error { iter.scanAll(dest) _ = iter.Close() return iter.err } func (iter *Iterx) scanAll(dest interface{}) bool { value := reflect.ValueOf(dest) // json.Unmarshal returns errors for these if value.Kind() != reflect.Ptr { iter.err = fmt.Errorf("expected a pointer but got %T", dest) return false } if value.IsNil() { iter.err = errors.New("expected a pointer but got nil") return false } slice, err := baseType(value.Type(), reflect.Slice) if err != nil { iter.err = err return false } isPtr := slice.Elem().Kind() == reflect.Ptr base := reflectx.Deref(slice.Elem()) scannable := iter.isScannable(base) if iter.structOnly && scannable { if base.Kind() == reflect.Struct { scannable = false } else { iter.err = structOnlyError(base) return false } } // if it's a base type make sure it only has 1 column; if not return an error if scannable && len(iter.Columns()) > 1 { iter.err = fmt.Errorf("expected 1 column in result while scanning scannable type %s but got %d", base.Kind(), len(iter.Columns())) return false } var ( alloc bool v reflect.Value vp reflect.Value ok bool ) for { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) // scan into the struct field pointers if !scannable { ok = iter.structScan(vp) } else { ok = iter.scan(vp) } if !ok { break } // allocate memory for the page data if !alloc { v = reflect.MakeSlice(slice, 0, iter.NumRows()) alloc = true } if isPtr { v = reflect.Append(v, vp) } else { v = reflect.Append(v, reflect.Indirect(vp)) } } // update dest if allocated slice if alloc { reflect.Indirect(value).Set(v) } return true } // isScannable takes the reflect.Type and the actual dest value and returns // whether or not it's Scannable. 
t is scannable if: // - ptr to t implements gocql.Unmarshaler, gocql.UDTUnmarshaler or UDT // - it is not a struct // - it has no exported fields func (iter *Iterx) isScannable(t reflect.Type) bool { ptr := reflect.PointerTo(t) switch { case ptr.Implements(unmarshallerInterface): return true case ptr.Implements(udtUnmarshallerInterface): return true case ptr.Implements(autoUDTInterface): return true case t.Kind() != reflect.Struct: return true default: return len(iter.Mapper.TypeMap(t).Index) == 0 } } func (iter *Iterx) scan(value reflect.Value) bool { if value.Kind() != reflect.Ptr { panic("value must be a pointer") } return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.strict)) } // StructScan is like gocql.Iter.Scan, but scans a single row into a single // struct. Use this and iterate manually when the memory load of Select() might // be prohibitive. StructScan caches the reflect work of matching up column // positions to fields to avoid that overhead per scan, which means it is not // safe to run StructScan on the same Iterx instance with different struct // types. 
func (iter *Iterx) StructScan(dest interface{}) bool { value := reflect.ValueOf(dest) if value.Kind() != reflect.Ptr { iter.err = fmt.Errorf("expected a pointer but got %T", dest) return false } if value.IsNil() { iter.err = errors.New("expected a pointer but got nil") return false } return iter.structScan(value) } const appliedColumn = "[applied]" func (iter *Iterx) structScan(value reflect.Value) bool { if value.Kind() != reflect.Ptr { panic("value must be a pointer") } if iter.fields == nil { columns := columnNames(iter.Columns()) cas := len(columns) > 0 && columns[0] == appliedColumn iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns) // if we are strict and it's not CAS query and are missing fields, return an error if iter.strict && !cas { if f, err := missingFields(iter.fields); err != nil { iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type()) return false } } iter.values = make([]interface{}, len(columns)) if cas { iter.values[0] = &iter.applied } } if err := iter.fieldsByTraversal(value, iter.fields, iter.values); err != nil { iter.err = err return false } // scan into the struct field pointers and append to our results return iter.Iter.Scan(iter.values...) } // fieldsByName fills a values interface with fields from the passed value based // on the traversals in int. // We write this instead of using FieldsByName to save allocations and map // lookups when iterating over many rows. // Empty traversals will get an interface pointer. 
func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, values []interface{}) error { value = reflect.Indirect(value) if value.Kind() != reflect.Struct { return fmt.Errorf("expected a struct but got %s", value.Type()) } for i, traversal := range traversals { if len(traversal) == 0 { continue } f := reflectx.FieldByIndexes(value, traversal).Addr() values[i] = udtWrapValue(f, iter.Mapper, iter.strict) } return nil } func columnNames(ci []gocql.ColumnInfo) []string { r := make([]string, len(ci)) for i, column := range ci { r[i] = column.Name } return r } // Scan consumes the next row of the iterator and copies the columns of the // current row into the values pointed at by dest. Use nil as a dest value // to skip the corresponding column. Scan might send additional queries // to the database to retrieve the next set of rows if paging was enabled. // // Scan returns true if the row was successfully unmarshaled or false if the // end of the result set was reached or if an error occurred. Close should // be called afterwards to retrieve any potential errors. func (iter *Iterx) Scan(dest ...interface{}) bool { return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.strict, dest)...) } // Close closes the iterator and returns any errors that happened during // the query or the iteration. func (iter *Iterx) Close() error { err := iter.Iter.Close() if iter.err == nil { iter.err = err } return iter.err } // checkErrAndNotFound handle error and NotFound in one method. func (iter *Iterx) checkErrAndNotFound() error { if iter.err != nil { return iter.err } else if iter.NumRows() == 0 { return gocql.ErrNotFound } return nil }