BindStructMap

This commit is contained in:
Michał Matczuk
2017-08-01 13:29:52 +02:00
parent 906f9433fe
commit 711e3369d6
4 changed files with 71 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
// +build integration // +build all integration
package gocqlx_test package gocqlx_test

View File

@@ -4,6 +4,7 @@ package gocqlx_test
import ( import (
"testing" "testing"
"time"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/scylladb/gocqlx" "github.com/scylladb/gocqlx"
@@ -134,6 +135,17 @@ func TestExample(t *testing.T) {
mustExec(q.Query) mustExec(q.Query)
} }
// Insert with TTL
{
q := Query(qb.Insert("gocqlx_test.person").Columns("first_name", "last_name", "email").TTL().ToCql())
if err := q.BindStructMap(p, map[string]interface{}{
"_ttl": qb.TTL(86400 * time.Second),
}); err != nil {
t.Fatal("bind:", err)
}
mustExec(q.Query)
}
// Update // Update
{ {
p.Email = append(p.Email, "patricia1.citzen@gocqlx_test.com") p.Email = append(p.Email, "patricia1.citzen@gocqlx_test.com")

View File

@@ -91,9 +91,10 @@ func Query(q *gocql.Query, names []string) Queryx {
} }
} }
// BindStruct binds query named parameters using mapper. // BindStruct binds query named parameters to values from arg using mapper. If
// value cannot be found error is reported.
func (q Queryx) BindStruct(arg interface{}) error { func (q Queryx) BindStruct(arg interface{}) error {
arglist, err := bindStructArgs(q.Names, arg, q.Mapper) arglist, err := bindStructArgs(q.Names, arg, nil, q.Mapper)
if err != nil { if err != nil {
return err return err
} }
@@ -103,22 +104,41 @@ func (q Queryx) BindStruct(arg interface{}) error {
return nil return nil
} }
func bindStructArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { // BindStructMap binds query named parameters to values from arg0 and arg1
// using a mapper. If value cannot be found in arg0 it's looked up in arg1
// before reporting an error.
func (q Queryx) BindStructMap(arg0 interface{}, arg1 map[string]interface{}) error {
arglist, err := bindStructArgs(q.Names, arg0, arg1, q.Mapper)
if err != nil {
return err
}
q.Bind(arglist...)
return nil
}
func bindStructArgs(names []string, arg0 interface{}, arg1 map[string]interface{}, m *reflectx.Mapper) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names)) arglist := make([]interface{}, 0, len(names))
// grab the indirected value of arg // grab the indirected value of arg
v := reflect.ValueOf(arg) v := reflect.ValueOf(arg0)
for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { for v = reflect.ValueOf(arg0); v.Kind() == reflect.Ptr; {
v = v.Elem() v = v.Elem()
} }
fields := m.TraversalsByName(v.Type(), names) fields := m.TraversalsByName(v.Type(), names)
for i, t := range fields { for i, t := range fields {
if len(t) == 0 { if len(t) != 0 {
return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg) val := reflectx.FieldByIndexesReadOnly(v, t)
arglist = append(arglist, val.Interface())
} else {
val, ok := arg1[names[i]]
if !ok {
return arglist, fmt.Errorf("could not find name %s in %#v and %#v", names[i], arg0, arg1)
}
arglist = append(arglist, val)
} }
val := reflectx.FieldByIndexesReadOnly(v, t)
arglist = append(arglist, val.Interface())
} }
return arglist, nil return arglist, nil

View File

@@ -83,7 +83,7 @@ func TestBindStruct(t *testing.T) {
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
names := []string{"name", "age", "first", "last"} names := []string{"name", "age", "first", "last"}
args, err := bindStructArgs(names, v, DefaultMapper) args, err := bindStructArgs(names, v, nil, DefaultMapper)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -94,8 +94,34 @@ func TestBindStruct(t *testing.T) {
}) })
t.Run("error", func(t *testing.T) { t.Run("error", func(t *testing.T) {
names := []string{"name", "first", "not_found"} names := []string{"name", "age", "first", "not_found"}
_, err := bindStructArgs(names, v, DefaultMapper) _, err := bindStructArgs(names, v, nil, DefaultMapper)
if err == nil {
t.Fatal("unexpected error")
}
})
t.Run("fallback", func(t *testing.T) {
names := []string{"name", "age", "first", "not_found"}
m := map[string]interface{}{
"not_found": "last",
}
args, err := bindStructArgs(names, v, m, DefaultMapper)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(args, []interface{}{"name", 30, "first", "last"}); diff != "" {
t.Error("args mismatch", diff)
}
})
t.Run("fallback error", func(t *testing.T) {
names := []string{"name", "age", "first", "not_found", "really_not_found"}
m := map[string]interface{}{
"not_found": "last",
}
_, err := bindStructArgs(names, v, m, DefaultMapper)
if err == nil { if err == nil {
t.Fatal("unexpected error") t.Fatal("unexpected error")
} }