diff --git a/benchmark_test.go b/benchmark_test.go index f7755b1..7852971 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -1,4 +1,4 @@ -// +build integration +// +build all integration package gocqlx_test diff --git a/example_test.go b/example_test.go index 15c0f08..9aede55 100644 --- a/example_test.go +++ b/example_test.go @@ -4,6 +4,7 @@ package gocqlx_test import ( "testing" + "time" "github.com/gocql/gocql" "github.com/scylladb/gocqlx" @@ -134,6 +135,17 @@ func TestExample(t *testing.T) { mustExec(q.Query) } + // Insert with TTL + { + q := Query(qb.Insert("gocqlx_test.person").Columns("first_name", "last_name", "email").TTL().ToCql()) + if err := q.BindStructMap(p, map[string]interface{}{ + "_ttl": qb.TTL(86400 * time.Second), + }); err != nil { + t.Fatal("bind:", err) + } + mustExec(q.Query) + } + // Update { p.Email = append(p.Email, "patricia1.citzen@gocqlx_test.com") diff --git a/queryx.go b/queryx.go index d7e7da0..f26392e 100644 --- a/queryx.go +++ b/queryx.go @@ -91,9 +91,10 @@ func Query(q *gocql.Query, names []string) Queryx { } } -// BindStruct binds query named parameters using mapper. +// BindStruct binds query named parameters to values from arg using mapper. If +// value cannot be found error is reported. func (q Queryx) BindStruct(arg interface{}) error { - arglist, err := bindStructArgs(q.Names, arg, q.Mapper) + arglist, err := bindStructArgs(q.Names, arg, nil, q.Mapper) if err != nil { return err } @@ -103,22 +104,41 @@ func (q Queryx) BindStruct(arg interface{}) error { return nil } -func bindStructArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { +// BindStructMap binds query named parameters to values from arg0 and arg1 +// using a mapper. If value cannot be found in arg0 it's looked up in arg1 +// before reporting an error. +func (q Queryx) BindStructMap(arg0 interface{}, arg1 map[string]interface{}) error { + arglist, err := bindStructArgs(q.Names, arg0, arg1, q.Mapper) + if err != nil { + return err + } + + q.Bind(arglist...) + + return nil +} + +func bindStructArgs(names []string, arg0 interface{}, arg1 map[string]interface{}, m *reflectx.Mapper) ([]interface{}, error) { arglist := make([]interface{}, 0, len(names)) // grab the indirected value of arg - v := reflect.ValueOf(arg) - for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { + v := reflect.ValueOf(arg0) + for v = reflect.ValueOf(arg0); v.Kind() == reflect.Ptr; { v = v.Elem() } fields := m.TraversalsByName(v.Type(), names) for i, t := range fields { - if len(t) == 0 { - return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg) + if len(t) != 0 { + val := reflectx.FieldByIndexesReadOnly(v, t) + arglist = append(arglist, val.Interface()) + } else { + val, ok := arg1[names[i]] + if !ok { + return arglist, fmt.Errorf("could not find name %s in %#v and %#v", names[i], arg0, arg1) + } + arglist = append(arglist, val) } - val := reflectx.FieldByIndexesReadOnly(v, t) - arglist = append(arglist, val.Interface()) } return arglist, nil diff --git a/queryx_test.go b/queryx_test.go index 2c74f85..df0f212 100644 --- a/queryx_test.go +++ b/queryx_test.go @@ -83,7 +83,7 @@ func TestBindStruct(t *testing.T) { t.Run("simple", func(t *testing.T) { names := []string{"name", "age", "first", "last"} - args, err := bindStructArgs(names, v, DefaultMapper) + args, err := bindStructArgs(names, v, nil, DefaultMapper) if err != nil { t.Fatal(err) } @@ -94,8 +94,34 @@ func TestBindStruct(t *testing.T) { }) t.Run("error", func(t *testing.T) { - names := []string{"name", "first", "not_found"} - _, err := bindStructArgs(names, v, DefaultMapper) + names := []string{"name", "age", "first", "not_found"} + _, err := bindStructArgs(names, v, nil, DefaultMapper) + if err == nil { + t.Fatal("unexpected error") + } + }) + + t.Run("fallback", func(t *testing.T) { + names := []string{"name", "age", "first", "not_found"} + m := map[string]interface{}{ + "not_found": "last", + } + args, err := bindStructArgs(names, v, m, DefaultMapper) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(args, []interface{}{"name", 30, "first", "last"}); diff != "" { + t.Error("args mismatch", diff) + } + }) + + t.Run("fallback error", func(t *testing.T) { + names := []string{"name", "age", "first", "not_found", "really_not_found"} + m := map[string]interface{}{ + "not_found": "last", + } + _, err := bindStructArgs(names, v, m, DefaultMapper) if err == nil { t.Fatal("unexpected error") }