From 6653e63afd7215786271829a5a95945b99113784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Thu, 20 Jul 2017 15:55:19 +0200 Subject: [PATCH] initial version --- .gitignore | 15 +++ casssandra_test.go | 212 +++++++++++++++++++++++++++++++ common_test.go | 112 +++++++++++++++++ gocqlx.go | 308 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 647 insertions(+) create mode 100644 .gitignore create mode 100644 casssandra_test.go create mode 100644 common_test.go create mode 100644 gocqlx.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ada6c14 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ + diff --git a/casssandra_test.go b/casssandra_test.go new file mode 100644 index 0000000..101ae6c --- /dev/null +++ b/casssandra_test.go @@ -0,0 +1,212 @@ +package gocqlx + +import ( + "math/big" + "reflect" + "strings" + "testing" + "time" + + "github.com/gocql/gocql" + + "gopkg.in/inf.v0" +) + +type FullName struct { + FirstName string + LastName string +} + +func (n FullName) MarshalCQL(info gocql.TypeInfo) ([]byte, error) { + return []byte(n.FirstName + " " + n.LastName), nil +} + +func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { + t := strings.SplitN(string(data), " ", 2) + n.FirstName, n.LastName = t[0], t[1] + return nil +} + +func TestScannable(t *testing.T) { + session := createSession(t) + defer session.Close() + if err := createTable(session, `CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil { + t.Fatal("create table:", err) + } + m := FullName{"John", "Doe"} + + if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, m).Exec(); err != nil { + t.Fatal("insert:", err) + } + + t.Run("get", func(t *testing.T) { + var v FullName + if err := Get(session.Query(`SELECT testfullname FROM scannable_table`), &v); err != nil { + t.Fatal("get failed", err) + } + + if !reflect.DeepEqual(m, v) { + t.Fatal("not equals") + } + }) + + t.Run("select", func(t *testing.T) { + var v []FullName + if err := Select(session.Query(`SELECT testfullname FROM scannable_table`), &v); err != nil { + t.Fatal("get failed", err) + } + + if len(v) != 1 { + t.Fatal("select unexpecrted number of rows", len(v)) + } + + if !reflect.DeepEqual(m, v[0]) { + t.Fatal("not equals") + } + }) + + t.Run("select ptr", func(t *testing.T) { + var v []*FullName + if err := Select(session.Query(`SELECT testfullname FROM scannable_table`), &v); err != nil { + t.Fatal("get failed", err) + } + + if len(v) != 1 { + t.Fatal("select unexpecrted number of rows", len(v)) + } + + if !reflect.DeepEqual(&m, v[0]) { + t.Fatal("not equals") + } + }) +} + +func TestStruct(t *testing.T) { + session := createSession(t) + defer session.Close() + if err := createTable(session, `CREATE TABLE gocqlx_test.struct_table ( + testuuid timeuuid PRIMARY KEY, + testtimestamp timestamp, + testvarchar varchar, + testbigint bigint, + testblob blob, + testbool boolean, + testfloat float, + testdouble double, + testint int, + testdecimal decimal, + testlist list, + testset set, + testmap map, + testvarint varint, + testinet inet, + testcustom text + + )`); err != nil { + t.Fatal("create table:", err) + } + + type StructTable struct { + Testuuid gocql.UUID + Testvarchar string + Testbigint int64 + Testtimestamp time.Time + Testblob []byte + Testbool bool + Testfloat float32 + Testdouble float64 + Testint int + Testdecimal *inf.Dec + Testlist []string + Testset []int + Testmap map[string]string + Testvarint *big.Int + Testinet string + Testcustom FullName + } + + bigInt := new(big.Int) + if _, ok := bigInt.SetString("830169365738487321165427203929228", 10); !ok { + t.Fatal("failed setting bigint by string") + } + + m := StructTable{ + Testuuid: gocql.TimeUUID(), + Testvarchar: "Test VarChar", + Testbigint: time.Now().Unix(), + Testtimestamp: time.Now().Truncate(time.Millisecond).UTC(), + Testblob: []byte("test blob"), + Testbool: true, + Testfloat: float32(4.564), + Testdouble: float64(4.815162342), + Testint: 2343, + Testdecimal: inf.NewDec(100, 0), + Testlist: []string{"quux", "foo", "bar", "baz", "quux"}, + Testset: []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + Testmap: map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"}, + Testvarint: bigInt, + Testinet: "213.212.2.19", + Testcustom: FullName{FirstName: "John", LastName: "Doe"}, + } + + if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + m.Testuuid, + m.Testtimestamp, + m.Testvarchar, + m.Testbigint, + m.Testblob, + m.Testbool, + m.Testfloat, + m.Testdouble, + m.Testint, + m.Testdecimal, + m.Testlist, + m.Testset, + m.Testmap, + m.Testvarint, + m.Testinet, + m.Testcustom).Exec(); err != nil { + t.Fatal("insert:", err) + } + + t.Run("get", func(t *testing.T) { + var v StructTable + if err := Get(session.Query(`SELECT * FROM struct_table`), &v); err != nil { + t.Fatal("get failed", err) + } + + if !reflect.DeepEqual(m, v) { + t.Fatal("not equals") + } + }) + + t.Run("select", func(t *testing.T) { + var v []StructTable + if err := Select(session.Query(`SELECT * FROM struct_table`), &v); err != nil { + t.Fatal("select failed", err) + } + + if len(v) != 1 { + t.Fatal("select unexpecrted number of rows", len(v)) + } + + if !reflect.DeepEqual(m, v[0]) { + t.Fatal("not equals") + } + }) + + t.Run("select ptr", func(t *testing.T) { + var v []*StructTable + if err := Select(session.Query(`SELECT * FROM struct_table`), &v); err != nil { + t.Fatal("select failed", err) + } + + if len(v) != 1 { + t.Fatal("select unexpecrted number of rows", len(v)) + } + + if !reflect.DeepEqual(&m, v[0]) { + t.Fatal("not equals") + } + }) +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..15a79a0 --- /dev/null +++ b/common_test.go @@ -0,0 +1,112 @@ +package gocqlx + +import ( + "flag" + "fmt" + "log" + "strings" + "sync" + "testing" + "time" + + "github.com/gocql/gocql" +) + +var ( + flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + flagProto = flag.Int("proto", 0, "protcol version") + flagCQL = flag.String("cql", "3.0.0", "CQL version") + flagRF = flag.Int("rf", 1, "replication factor for test keyspace") + flagRetry = flag.Int("retries", 5, "number of times to retry queries") + flagCompressTest = flag.String("compressor", "", "compressor to use") + flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") + + clusterHosts []string +) + +func init() { + flag.Parse() + clusterHosts = strings.Split(*flagCluster, ",") + log.SetFlags(log.Lshortfile | log.LstdFlags) +} + +var initOnce sync.Once + +func createTable(s *gocql.Session, table string) error { + if err := s.Query(table).RetryPolicy(nil).Exec(); err != nil { + log.Printf("error creating table table=%q err=%v\n", table, err) + return err + } + + return nil +} + +func createCluster() *gocql.ClusterConfig { + cluster := gocql.NewCluster(clusterHosts...) + cluster.ProtoVersion = *flagProto + cluster.CQLVersion = *flagCQL + cluster.Timeout = *flagTimeout + cluster.Consistency = gocql.Quorum + cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow + if *flagRetry > 0 { + cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry} + } + + switch *flagCompressTest { + case "snappy": + cluster.Compressor = &gocql.SnappyCompressor{} + case "": + default: + panic("invalid compressor: " + *flagCompressTest) + } + + return cluster +} + +func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) { + c := *cluster + c.Keyspace = "system" + c.Timeout = 30 * time.Second + session, err := c.CreateSession() + if err != nil { + panic(err) + } + defer session.Close() + defer tb.Log("closing keyspace session") + + err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace) + if err != nil { + panic(fmt.Sprintf("unable to drop keyspace: %v", err)) + } + + err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : %d + }`, keyspace, *flagRF)) + + if err != nil { + panic(fmt.Sprintf("unable to create keyspace: %v", err)) + } +} + +func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocql.Session { + // Drop and re-create the keyspace once. Different tests should use their own + // individual tables, but can assume that the table does not exist before. + initOnce.Do(func() { + createKeyspace(tb, cluster, "gocqlx_test") + }) + + cluster.Keyspace = "gocqlx_test" + session, err := cluster.CreateSession() + if err != nil { + tb.Fatal("createSession:", err) + } + + return session +} + +func createSession(tb testing.TB) *gocql.Session { + cluster := createCluster() + return createSessionFromCluster(cluster, tb) +} diff --git a/gocqlx.go b/gocqlx.go new file mode 100644 index 0000000..b3bb9f1 --- /dev/null +++ b/gocqlx.go @@ -0,0 +1,308 @@ +package gocqlx + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/gocql/gocql" + "github.com/jmoiron/sqlx/reflectx" +) + +// DefaultMapper uses `db` tag and strings.ToLower to lowercase struct field +// names. It can be set to whatever you want, but it is encouraged to be set +// before gocqlx is used as name-to-field mappings are cached after first +// use on a type. +var DefaultMapper = reflectx.NewMapperFunc("db", strings.ToLower) + +// Get is a convenience function for creating iterator and calling Get on it. +func Get(q *gocql.Query, dest interface{}) error { + return Iter(q).Get(dest) +} + +// Select is a convenience function for creating iterator and calling Select on it. +func Select(q *gocql.Query, dest interface{}) error { + return Iter(q).Select(dest) +} + +// Iterx is a wrapper around gocql.Iter which adds struct scanning capabilities. +type Iterx struct { + *gocql.Iter + unsafe bool + Mapper *reflectx.Mapper + // these fields cache memory use for a rows during iteration w/ structScan + started bool + fields [][]int + values []interface{} + err error +} + +// Iter creates a new Iterx from gocql.Query using a default mapper. +func Iter(q *gocql.Query) *Iterx { + return &Iterx{ + Iter: q.Iter(), + Mapper: DefaultMapper, + } +} + +// Get scans first row into a destination and closes the iterator. If the +// destination type is a Struct, then StructScan will be used. If the +// destination is some other type, then the row must only have one column which +// can scan into that type. +func (iter *Iterx) Get(dest interface{}) error { + if err := iter.scanAny(dest, false); err != nil { + iter.err = err + } + + if err := iter.Close(); err != nil { + iter.err = err + } + + return iter.err +} + +func (iter *Iterx) scanAny(dest interface{}, structOnly bool) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if v.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + + base := reflectx.Deref(v.Type()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + if scannable && len(iter.Columns()) > 1 { + return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(iter.Columns())) + } + + if !scannable { + iter.StructScan(dest) + } else { + iter.Scan(dest) + } + + return nil +} + +// Select scans all rows into a destination, which must be a slice of any type +// and closes the iterator. If the destination slice type is a Struct, then +// StructScan will be used on each row. If the destination is some other type, +// then each row must only have one column which can scan into that type. +func (iter *Iterx) Select(dest interface{}) error { + if err := iter.scanAll(dest, false); err != nil { + iter.err = err + } + + if err := iter.Close(); err != nil { + iter.err = err + } + + return iter.err +} + +func (iter *Iterx) scanAll(dest interface{}, structOnly bool) error { + var v, vp reflect.Value + + value := reflect.ValueOf(dest) + + // json.Unmarshal returns errors for these + if value.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + if value.IsNil() { + return errors.New("nil pointer passed to StructScan destination") + } + direct := reflect.Indirect(value) + + slice, err := baseType(value.Type(), reflect.Slice) + if err != nil { + return err + } + + isPtr := slice.Elem().Kind() == reflect.Ptr + base := reflectx.Deref(slice.Elem()) + scannable := isScannable(base) + + if structOnly && scannable { + return structOnlyError(base) + } + + // if it's a base type make sure it only has 1 column; if not return an error + if scannable && len(iter.Columns()) > 1 { + return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(iter.Columns())) + } + + if !scannable { + for { + // create a new struct type (which returns PtrTo) and indirect it + vp = reflect.New(base) + v = reflect.Indirect(vp) + // scan into the struct field pointers and append to our results + if ok := iter.StructScan(vp.Interface()); !ok { + break + } + + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, v)) + } + } + } else { + for { + vp = reflect.New(base) + if ok := iter.Scan(vp.Interface()); !ok { + break + } + + // append + if isPtr { + direct.Set(reflect.Append(direct, vp)) + } else { + direct.Set(reflect.Append(direct, reflect.Indirect(vp))) + } + } + } + + return iter.Err() +} + +// StructScan is like gocql.Scan, but scans a single row into a single Struct. +// Use this and iterate manually when the memory load of Select() might be +// prohibitive. StructScan caches the reflect work of matching up column +// positions to fields to avoid that overhead per scan, which means it is not +// safe to run StructScan on the same Iterx instance with different struct +// types. +func (iter *Iterx) StructScan(dest interface{}) bool { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + iter.err = errors.New("must pass a pointer, not a value, to StructScan destination") + return false + } + + if !iter.started { + columns := columnNames(iter.Iter.Columns()) + m := iter.Mapper + + iter.fields = m.TraversalsByName(v.Type(), columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(iter.fields); err != nil && !iter.unsafe { + iter.err = fmt.Errorf("missing destination name %s in %T", columns[f], dest) + return false + } + iter.values = make([]interface{}, len(columns)) + iter.started = true + } + + err := fieldsByTraversal(v, iter.fields, iter.values, true) + if err != nil { + iter.err = err + return false + } + // scan into the struct field pointers and append to our results + return iter.Iter.Scan(iter.values...) +} + +func columnNames(ci []gocql.ColumnInfo) []string { + r := make([]string, len(ci)) + for i, column := range ci { + r[i] = column.Name + } + return r +} + +// Err returns the error encountered while scanning. +func (iter *Iterx) Err() error { + return iter.err +} + +// structOnlyError returns an error appropriate for type when a non-scannable +// struct is expected but something else is given +func structOnlyError(t reflect.Type) error { + isStruct := t.Kind() == reflect.Struct + isScanner := reflect.PtrTo(t).Implements(_unmarshallerInterface) + if !isStruct { + return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) + } + if isScanner { + return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements unmarshaler", t.Name()) + } + return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) +} + +// reflect helpers + +var _unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem() + +func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { + t = reflectx.Deref(t) + if t.Kind() != expected { + return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) + } + return t, nil +} + +// isScannable takes the reflect.Type and the actual dest value and returns +// whether or not it's Scannable. Something is scannable if: +// * it is not a struct +// * it implements gocql.Unmarshaler +// * it has no exported fields +func isScannable(t reflect.Type) bool { + if reflect.PtrTo(t).Implements(_unmarshallerInterface) { + return true + } + if t.Kind() != reflect.Struct { + return true + } + + // it's not important that we use the right mapper for this particular object, + // we're only concerned on how many exported fields this struct has + m := DefaultMapper + if len(m.TypeMap(t).Index) == 0 { + return true + } + return false +} + +// fieldsByName fills a values interface with fields from the passed value based +// on the traversals in int. If ptrs is true, return addresses instead of values. +// We write this instead of using FieldsByName to save allocations and map lookups +// when iterating over many rows. Empty traversals will get an interface pointer. +// Because of the necessity of requesting ptrs or values, it's considered a bit too +// specialized for inclusion in reflectx itself. +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return errors.New("argument not a struct") + } + + for i, traversal := range traversals { + if len(traversal) == 0 { + values[i] = new(interface{}) + continue + } + f := reflectx.FieldByIndexes(v, traversal) + if ptrs { + values[i] = f.Addr().Interface() + } else { + values[i] = f.Interface() + } + } + return nil +} + +func missingFields(transversals [][]int) (field int, err error) { + for i, t := range transversals { + if len(t) == 0 { + return i, errors.New("missing field") + } + } + return 0, nil +}