initial version

This commit is contained in:
Michał Matczuk
2017-07-20 15:55:19 +02:00
parent 13fab055f7
commit 6653e63afd
4 changed files with 647 additions and 0 deletions

15
.gitignore vendored Normal file
View File

@@ -0,0 +1,15 @@
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/

212
casssandra_test.go Normal file
View File

@@ -0,0 +1,212 @@
package gocqlx
import (
"math/big"
"reflect"
"strings"
"testing"
"time"
"github.com/gocql/gocql"
"gopkg.in/inf.v0"
)
type FullName struct {
FirstName string
LastName string
}
func (n FullName) MarshalCQL(info gocql.TypeInfo) ([]byte, error) {
return []byte(n.FirstName + " " + n.LastName), nil
}
func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
t := strings.SplitN(string(data), " ", 2)
n.FirstName, n.LastName = t[0], t[1]
return nil
}
func TestScannable(t *testing.T) {
session := createSession(t)
defer session.Close()
if err := createTable(session, `CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
m := FullName{"John", "Doe"}
if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, m).Exec(); err != nil {
t.Fatal("insert:", err)
}
t.Run("get", func(t *testing.T) {
var v FullName
if err := Get(session.Query(`SELECT testfullname FROM scannable_table`), &v); err != nil {
t.Fatal("get failed", err)
}
if !reflect.DeepEqual(m, v) {
t.Fatal("not equals")
}
})
t.Run("select", func(t *testing.T) {
var v []FullName
if err := Select(session.Query(`SELECT testfullname FROM scannable_table`), &v); err != nil {
t.Fatal("get failed", err)
}
if len(v) != 1 {
t.Fatal("select unexpecrted number of rows", len(v))
}
if !reflect.DeepEqual(m, v[0]) {
t.Fatal("not equals")
}
})
t.Run("select ptr", func(t *testing.T) {
var v []*FullName
if err := Select(session.Query(`SELECT testfullname FROM scannable_table`), &v); err != nil {
t.Fatal("get failed", err)
}
if len(v) != 1 {
t.Fatal("select unexpecrted number of rows", len(v))
}
if !reflect.DeepEqual(&m, v[0]) {
t.Fatal("not equals")
}
})
}
func TestStruct(t *testing.T) {
session := createSession(t)
defer session.Close()
if err := createTable(session, `CREATE TABLE gocqlx_test.struct_table (
testuuid timeuuid PRIMARY KEY,
testtimestamp timestamp,
testvarchar varchar,
testbigint bigint,
testblob blob,
testbool boolean,
testfloat float,
testdouble double,
testint int,
testdecimal decimal,
testlist list<text>,
testset set<int>,
testmap map<varchar, varchar>,
testvarint varint,
testinet inet,
testcustom text
)`); err != nil {
t.Fatal("create table:", err)
}
type StructTable struct {
Testuuid gocql.UUID
Testvarchar string
Testbigint int64
Testtimestamp time.Time
Testblob []byte
Testbool bool
Testfloat float32
Testdouble float64
Testint int
Testdecimal *inf.Dec
Testlist []string
Testset []int
Testmap map[string]string
Testvarint *big.Int
Testinet string
Testcustom FullName
}
bigInt := new(big.Int)
if _, ok := bigInt.SetString("830169365738487321165427203929228", 10); !ok {
t.Fatal("failed setting bigint by string")
}
m := StructTable{
Testuuid: gocql.TimeUUID(),
Testvarchar: "Test VarChar",
Testbigint: time.Now().Unix(),
Testtimestamp: time.Now().Truncate(time.Millisecond).UTC(),
Testblob: []byte("test blob"),
Testbool: true,
Testfloat: float32(4.564),
Testdouble: float64(4.815162342),
Testint: 2343,
Testdecimal: inf.NewDec(100, 0),
Testlist: []string{"quux", "foo", "bar", "baz", "quux"},
Testset: []int{1, 2, 3, 4, 5, 6, 7, 8, 9},
Testmap: map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"},
Testvarint: bigInt,
Testinet: "213.212.2.19",
Testcustom: FullName{FirstName: "John", LastName: "Doe"},
}
if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
m.Testuuid,
m.Testtimestamp,
m.Testvarchar,
m.Testbigint,
m.Testblob,
m.Testbool,
m.Testfloat,
m.Testdouble,
m.Testint,
m.Testdecimal,
m.Testlist,
m.Testset,
m.Testmap,
m.Testvarint,
m.Testinet,
m.Testcustom).Exec(); err != nil {
t.Fatal("insert:", err)
}
t.Run("get", func(t *testing.T) {
var v StructTable
if err := Get(session.Query(`SELECT * FROM struct_table`), &v); err != nil {
t.Fatal("get failed", err)
}
if !reflect.DeepEqual(m, v) {
t.Fatal("not equals")
}
})
t.Run("select", func(t *testing.T) {
var v []StructTable
if err := Select(session.Query(`SELECT * FROM struct_table`), &v); err != nil {
t.Fatal("select failed", err)
}
if len(v) != 1 {
t.Fatal("select unexpecrted number of rows", len(v))
}
if !reflect.DeepEqual(m, v[0]) {
t.Fatal("not equals")
}
})
t.Run("select ptr", func(t *testing.T) {
var v []*StructTable
if err := Select(session.Query(`SELECT * FROM struct_table`), &v); err != nil {
t.Fatal("select failed", err)
}
if len(v) != 1 {
t.Fatal("select unexpecrted number of rows", len(v))
}
if !reflect.DeepEqual(&m, v[0]) {
t.Fatal("not equals")
}
})
}

112
common_test.go Normal file
View File

@@ -0,0 +1,112 @@
package gocqlx
import (
"flag"
"fmt"
"log"
"strings"
"sync"
"testing"
"time"
"github.com/gocql/gocql"
)
var (
flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagProto = flag.Int("proto", 0, "protcol version")
flagCQL = flag.String("cql", "3.0.0", "CQL version")
flagRF = flag.Int("rf", 1, "replication factor for test keyspace")
flagRetry = flag.Int("retries", 5, "number of times to retry queries")
flagCompressTest = flag.String("compressor", "", "compressor to use")
flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations")
clusterHosts []string
)
func init() {
flag.Parse()
clusterHosts = strings.Split(*flagCluster, ",")
log.SetFlags(log.Lshortfile | log.LstdFlags)
}
var initOnce sync.Once
func createTable(s *gocql.Session, table string) error {
if err := s.Query(table).RetryPolicy(nil).Exec(); err != nil {
log.Printf("error creating table table=%q err=%v\n", table, err)
return err
}
return nil
}
func createCluster() *gocql.ClusterConfig {
cluster := gocql.NewCluster(clusterHosts...)
cluster.ProtoVersion = *flagProto
cluster.CQLVersion = *flagCQL
cluster.Timeout = *flagTimeout
cluster.Consistency = gocql.Quorum
cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow
if *flagRetry > 0 {
cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry}
}
switch *flagCompressTest {
case "snappy":
cluster.Compressor = &gocql.SnappyCompressor{}
case "":
default:
panic("invalid compressor: " + *flagCompressTest)
}
return cluster
}
func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) {
c := *cluster
c.Keyspace = "system"
c.Timeout = 30 * time.Second
session, err := c.CreateSession()
if err != nil {
panic(err)
}
defer session.Close()
defer tb.Log("closing keyspace session")
err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace)
if err != nil {
panic(fmt.Sprintf("unable to drop keyspace: %v", err))
}
err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : %d
}`, keyspace, *flagRF))
if err != nil {
panic(fmt.Sprintf("unable to create keyspace: %v", err))
}
}
func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocql.Session {
// Drop and re-create the keyspace once. Different tests should use their own
// individual tables, but can assume that the table does not exist before.
initOnce.Do(func() {
createKeyspace(tb, cluster, "gocqlx_test")
})
cluster.Keyspace = "gocqlx_test"
session, err := cluster.CreateSession()
if err != nil {
tb.Fatal("createSession:", err)
}
return session
}
func createSession(tb testing.TB) *gocql.Session {
cluster := createCluster()
return createSessionFromCluster(cluster, tb)
}

308
gocqlx.go Normal file
View File

@@ -0,0 +1,308 @@
package gocqlx
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/gocql/gocql"
"github.com/jmoiron/sqlx/reflectx"
)
// DefaultMapper uses `db` tag and strings.ToLower to lowercase struct field
// names. It can be set to whatever you want, but it is encouraged to be set
// before gocqlx is used as name-to-field mappings are cached after first
// use on a type.
var DefaultMapper = reflectx.NewMapperFunc("db", strings.ToLower)
// Get is a convenience function for creating iterator and calling Get on it.
func Get(q *gocql.Query, dest interface{}) error {
return Iter(q).Get(dest)
}
// Select is a convenience function for creating iterator and calling Select on it.
func Select(q *gocql.Query, dest interface{}) error {
return Iter(q).Select(dest)
}
// Iterx is a wrapper around gocql.Iter which adds struct scanning capabilities.
type Iterx struct {
*gocql.Iter
unsafe bool
Mapper *reflectx.Mapper
// these fields cache memory use for a rows during iteration w/ structScan
started bool
fields [][]int
values []interface{}
err error
}
// Iter creates a new Iterx from gocql.Query using a default mapper.
func Iter(q *gocql.Query) *Iterx {
return &Iterx{
Iter: q.Iter(),
Mapper: DefaultMapper,
}
}
// Get scans first row into a destination and closes the iterator. If the
// destination type is a Struct, then StructScan will be used. If the
// destination is some other type, then the row must only have one column which
// can scan into that type.
func (iter *Iterx) Get(dest interface{}) error {
if err := iter.scanAny(dest, false); err != nil {
iter.err = err
}
if err := iter.Close(); err != nil {
iter.err = err
}
return iter.err
}
func (iter *Iterx) scanAny(dest interface{}, structOnly bool) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
if v.IsNil() {
return errors.New("nil pointer passed to StructScan destination")
}
base := reflectx.Deref(v.Type())
scannable := isScannable(base)
if structOnly && scannable {
return structOnlyError(base)
}
if scannable && len(iter.Columns()) > 1 {
return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(iter.Columns()))
}
if !scannable {
iter.StructScan(dest)
} else {
iter.Scan(dest)
}
return nil
}
// Select scans all rows into a destination, which must be a slice of any type
// and closes the iterator. If the destination slice type is a Struct, then
// StructScan will be used on each row. If the destination is some other type,
// then each row must only have one column which can scan into that type.
func (iter *Iterx) Select(dest interface{}) error {
if err := iter.scanAll(dest, false); err != nil {
iter.err = err
}
if err := iter.Close(); err != nil {
iter.err = err
}
return iter.err
}
func (iter *Iterx) scanAll(dest interface{}, structOnly bool) error {
var v, vp reflect.Value
value := reflect.ValueOf(dest)
// json.Unmarshal returns errors for these
if value.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
if value.IsNil() {
return errors.New("nil pointer passed to StructScan destination")
}
direct := reflect.Indirect(value)
slice, err := baseType(value.Type(), reflect.Slice)
if err != nil {
return err
}
isPtr := slice.Elem().Kind() == reflect.Ptr
base := reflectx.Deref(slice.Elem())
scannable := isScannable(base)
if structOnly && scannable {
return structOnlyError(base)
}
// if it's a base type make sure it only has 1 column; if not return an error
if scannable && len(iter.Columns()) > 1 {
return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(iter.Columns()))
}
if !scannable {
for {
// create a new struct type (which returns PtrTo) and indirect it
vp = reflect.New(base)
v = reflect.Indirect(vp)
// scan into the struct field pointers and append to our results
if ok := iter.StructScan(vp.Interface()); !ok {
break
}
if isPtr {
direct.Set(reflect.Append(direct, vp))
} else {
direct.Set(reflect.Append(direct, v))
}
}
} else {
for {
vp = reflect.New(base)
if ok := iter.Scan(vp.Interface()); !ok {
break
}
// append
if isPtr {
direct.Set(reflect.Append(direct, vp))
} else {
direct.Set(reflect.Append(direct, reflect.Indirect(vp)))
}
}
}
return iter.Err()
}
// StructScan is like gocql.Scan, but scans a single row into a single Struct.
// Use this and iterate manually when the memory load of Select() might be
// prohibitive. StructScan caches the reflect work of matching up column
// positions to fields to avoid that overhead per scan, which means it is not
// safe to run StructScan on the same Iterx instance with different struct
// types.
func (iter *Iterx) StructScan(dest interface{}) bool {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
iter.err = errors.New("must pass a pointer, not a value, to StructScan destination")
return false
}
if !iter.started {
columns := columnNames(iter.Iter.Columns())
m := iter.Mapper
iter.fields = m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(iter.fields); err != nil && !iter.unsafe {
iter.err = fmt.Errorf("missing destination name %s in %T", columns[f], dest)
return false
}
iter.values = make([]interface{}, len(columns))
iter.started = true
}
err := fieldsByTraversal(v, iter.fields, iter.values, true)
if err != nil {
iter.err = err
return false
}
// scan into the struct field pointers and append to our results
return iter.Iter.Scan(iter.values...)
}
func columnNames(ci []gocql.ColumnInfo) []string {
r := make([]string, len(ci))
for i, column := range ci {
r[i] = column.Name
}
return r
}
// Err returns the error encountered while scanning.
func (iter *Iterx) Err() error {
return iter.err
}
// structOnlyError returns an error appropriate for type when a non-scannable
// struct is expected but something else is given
func structOnlyError(t reflect.Type) error {
isStruct := t.Kind() == reflect.Struct
isScanner := reflect.PtrTo(t).Implements(_unmarshallerInterface)
if !isStruct {
return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind())
}
if isScanner {
return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements unmarshaler", t.Name())
}
return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
}
// reflect helpers
var _unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
t = reflectx.Deref(t)
if t.Kind() != expected {
return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind())
}
return t, nil
}
// isScannable takes the reflect.Type and the actual dest value and returns
// whether or not it's Scannable. Something is scannable if:
// * it is not a struct
// * it implements gocql.Unmarshaler
// * it has no exported fields
func isScannable(t reflect.Type) bool {
if reflect.PtrTo(t).Implements(_unmarshallerInterface) {
return true
}
if t.Kind() != reflect.Struct {
return true
}
// it's not important that we use the right mapper for this particular object,
// we're only concerned on how many exported fields this struct has
m := DefaultMapper
if len(m.TypeMap(t).Index) == 0 {
return true
}
return false
}
// fieldsByName fills a values interface with fields from the passed value based
// on the traversals in int. If ptrs is true, return addresses instead of values.
// We write this instead of using FieldsByName to save allocations and map lookups
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
values[i] = f.Interface()
}
}
return nil
}
func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals {
if len(t) == 0 {
return i, errors.New("missing field")
}
}
return 0, nil
}