Releasing query objects in Get/Select could easily lead to double releases, which can cause dangerous and hard-to-diagnose data races. Remove the query field of Iterx and all of its (release-related) usages. This is a breaking API change because it removes the exported method ReleaseQuery. Update the documentation examples to demonstrate the deferred query-release pattern that clients can use to manage query release themselves.
263 lines
6.6 KiB
Go
263 lines
6.6 KiB
Go
// Copyright (C) 2017 ScyllaDB
|
|
// Use of this source code is governed by a ALv2-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package gocqlx
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/gocql/gocql"
|
|
"github.com/jmoiron/sqlx/reflectx"
|
|
)
|
|
|
|
// Get is a convenience function for creating iterator and calling Get.
|
|
func Get(dest interface{}, q *gocql.Query) error {
|
|
return Iter(q).Get(dest)
|
|
}
|
|
|
|
// Select is a convenience function for creating iterator and calling Select.
|
|
func Select(dest interface{}, q *gocql.Query) error {
|
|
return Iter(q).Select(dest)
|
|
}
|
|
|
|
// Iterx is a wrapper around gocql.Iter which adds struct scanning capabilities.
|
|
type Iterx struct {
|
|
*gocql.Iter
|
|
err error
|
|
|
|
unsafe bool
|
|
Mapper *reflectx.Mapper
|
|
// these fields cache memory use for a rows during iteration w/ structScan
|
|
started bool
|
|
fields [][]int
|
|
values []interface{}
|
|
}
|
|
|
|
// Iter creates a new Iterx from gocql.Query using a default mapper.
|
|
func Iter(q *gocql.Query) *Iterx {
|
|
return &Iterx{
|
|
Iter: q.Iter(),
|
|
Mapper: DefaultMapper,
|
|
}
|
|
}
|
|
|
|
// Unsafe forces the iterator to ignore missing fields. By default when scanning
|
|
// a struct if result row has a column that cannot be mapped to any destination
|
|
// field an error is reported. With unsafe such columns are ignored.
|
|
func (iter *Iterx) Unsafe() *Iterx {
|
|
iter.unsafe = true
|
|
return iter
|
|
}
|
|
|
|
// Get scans first row into a destination and closes the iterator. If the
|
|
// destination type is a Struct, then StructScan will be used. If the
|
|
// destination is some other type, then the row must only have one column which
|
|
// can scan into that type. If no rows were selected, ErrNotFound is returned.
|
|
func (iter *Iterx) Get(dest interface{}) error {
|
|
iter.scanAny(dest, false)
|
|
iter.Close()
|
|
|
|
return iter.checkErrAndNotFound()
|
|
}
|
|
|
|
func (iter *Iterx) scanAny(dest interface{}, structOnly bool) bool {
|
|
value := reflect.ValueOf(dest)
|
|
if value.Kind() != reflect.Ptr {
|
|
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
|
|
return false
|
|
}
|
|
if value.IsNil() {
|
|
iter.err = errors.New("nil pointer passed to StructScan destination")
|
|
return false
|
|
}
|
|
|
|
// no results or query error
|
|
if iter.Iter.NumRows() == 0 {
|
|
return false
|
|
}
|
|
|
|
base := reflectx.Deref(value.Type())
|
|
scannable := isScannable(base)
|
|
|
|
if structOnly && scannable {
|
|
iter.err = structOnlyError(base)
|
|
return false
|
|
}
|
|
|
|
if scannable && len(iter.Columns()) > 1 {
|
|
iter.err = fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(iter.Columns()))
|
|
return false
|
|
}
|
|
|
|
if scannable {
|
|
return iter.Scan(dest)
|
|
}
|
|
|
|
return iter.StructScan(dest)
|
|
}
|
|
|
|
// Select scans all rows into a destination, which must be a slice of any type
|
|
// and closes the iterator. If the destination slice type is a Struct, then
|
|
// StructScan will be used on each row. If the destination is some other type,
|
|
// then each row must only have one column which can scan into that type.
|
|
func (iter *Iterx) Select(dest interface{}) error {
|
|
iter.scanAll(dest, false)
|
|
iter.Close()
|
|
|
|
return iter.err
|
|
}
|
|
|
|
func (iter *Iterx) scanAll(dest interface{}, structOnly bool) bool {
|
|
value := reflect.ValueOf(dest)
|
|
|
|
// json.Unmarshal returns errors for these
|
|
if value.Kind() != reflect.Ptr {
|
|
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
|
|
return false
|
|
}
|
|
if value.IsNil() {
|
|
iter.err = errors.New("nil pointer passed to StructScan destination")
|
|
return false
|
|
}
|
|
|
|
// no results or query error
|
|
if iter.Iter.NumRows() == 0 {
|
|
return false
|
|
}
|
|
|
|
slice, err := baseType(value.Type(), reflect.Slice)
|
|
if err != nil {
|
|
iter.err = err
|
|
return false
|
|
}
|
|
|
|
isPtr := slice.Elem().Kind() == reflect.Ptr
|
|
base := reflectx.Deref(slice.Elem())
|
|
scannable := isScannable(base)
|
|
|
|
if structOnly && scannable {
|
|
iter.err = structOnlyError(base)
|
|
return false
|
|
}
|
|
|
|
// if it's a base type make sure it only has 1 column; if not return an error
|
|
if scannable && len(iter.Columns()) > 1 {
|
|
iter.err = fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(iter.Columns()))
|
|
return false
|
|
}
|
|
|
|
var (
|
|
alloc bool
|
|
v reflect.Value
|
|
vp reflect.Value
|
|
ok bool
|
|
)
|
|
for {
|
|
// create a new struct type (which returns PtrTo) and indirect it
|
|
vp = reflect.New(base)
|
|
|
|
// scan into the struct field pointers
|
|
if !scannable {
|
|
ok = iter.StructScan(vp.Interface())
|
|
} else {
|
|
ok = iter.Scan(vp.Interface())
|
|
}
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
// allocate memory for the page data
|
|
if !alloc {
|
|
v = reflect.MakeSlice(slice, 0, iter.Iter.NumRows())
|
|
alloc = true
|
|
}
|
|
|
|
if isPtr {
|
|
v = reflect.Append(v, vp)
|
|
} else {
|
|
v = reflect.Append(v, reflect.Indirect(vp))
|
|
}
|
|
}
|
|
|
|
// update dest if allocated slice
|
|
if alloc {
|
|
reflect.Indirect(value).Set(v)
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// StructScan is like gocql.Iter.Scan, but scans a single row into a single
|
|
// Struct. Use this and iterate manually when the memory load of Select() might
|
|
// be prohibitive. StructScan caches the reflect work of matching up column
|
|
// positions to fields to avoid that overhead per scan, which means it is not
|
|
// safe to run StructScan on the same Iterx instance with different struct
|
|
// types.
|
|
func (iter *Iterx) StructScan(dest interface{}) bool {
|
|
v := reflect.ValueOf(dest)
|
|
if v.Kind() != reflect.Ptr {
|
|
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
|
|
return false
|
|
}
|
|
|
|
// no results or query error
|
|
if iter.Iter.NumRows() == 0 {
|
|
return false
|
|
}
|
|
|
|
if !iter.started {
|
|
columns := columnNames(iter.Iter.Columns())
|
|
m := iter.Mapper
|
|
|
|
iter.fields = m.TraversalsByName(v.Type(), columns)
|
|
// if we are not unsafe and are missing fields, return an error
|
|
if !iter.unsafe {
|
|
if f, err := missingFields(iter.fields); err != nil {
|
|
iter.err = fmt.Errorf("missing destination name %q in %T", columns[f], dest)
|
|
return false
|
|
}
|
|
}
|
|
iter.values = make([]interface{}, len(columns))
|
|
iter.started = true
|
|
}
|
|
|
|
err := fieldsByTraversal(v, iter.fields, iter.values, true)
|
|
if err != nil {
|
|
iter.err = err
|
|
return false
|
|
}
|
|
// scan into the struct field pointers and append to our results
|
|
return iter.Iter.Scan(iter.values...)
|
|
}
|
|
|
|
func columnNames(ci []gocql.ColumnInfo) []string {
|
|
r := make([]string, len(ci))
|
|
for i, column := range ci {
|
|
r[i] = column.Name
|
|
}
|
|
return r
|
|
}
|
|
|
|
// Close closes the iterator and returns any errors that happened during
|
|
// the query or the iteration.
|
|
func (iter *Iterx) Close() error {
|
|
err := iter.Iter.Close()
|
|
if iter.err == nil {
|
|
iter.err = err
|
|
}
|
|
return iter.err
|
|
}
|
|
|
|
// checkErrAndNotFound handle error and NotFound in one method.
|
|
func (iter *Iterx) checkErrAndNotFound() error {
|
|
if iter.err != nil {
|
|
return iter.err
|
|
} else if iter.Iter.NumRows() == 0 {
|
|
return gocql.ErrNotFound
|
|
}
|
|
return nil
|
|
}
|