diff --git a/example_test.go b/example_test.go index 33f3108..22cd792 100644 --- a/example_test.go +++ b/example_test.go @@ -345,8 +345,8 @@ func basicReadScyllaVersion(t *testing.T, session gocqlx.Session) { } // This examples shows how to bind data from a map using "BindMap" function, -// override field name mapping using the "db" tags, and use "Unsafe" function -// to handle situations where driver returns more coluns that we are ready to +// override field name mapping using the "db" tags, with the default mechanism of +// handling situations where driver returns more coluns that we are ready to // consume. func datatypesBlob(t *testing.T, session gocqlx.Session) { t.Helper() @@ -384,9 +384,8 @@ func datatypesBlob(t *testing.T, session gocqlx.Session) { }{} q := qb.Select("examples.blobs").Where(qb.EqLit("k", "1")).Query(session) - // Unsafe is used here to override validation error that check if all - // requested columns are consumed `failed: missing destination name "k" in struct` error - if err := q.Iter().Unsafe().Get(row); err != nil { + // By default missing UDT fields are treated as null instead of failing + if err := q.Iter().Get(row); err != nil { t.Fatal("Get() failed:", err) } diff --git a/iterx.go b/iterx.go index 9ebfedf..8d4b093 100644 --- a/iterx.go +++ b/iterx.go @@ -13,9 +13,9 @@ import ( "github.com/scylladb/go-reflectx" ) -// DefaultUnsafe enables the behavior of forcing queries and iterators to ignore -// missing fields for all queries. See Unsafe below for more information. -var DefaultUnsafe bool +// DefaultStrict disables the behavior of forcing queries and iterators to ignore +// missing fields for all queries. See Strict below for more information. +var DefaultStrict bool // Iterx is a wrapper around gocql.Iter which adds struct scanning capabilities. type Iterx struct { @@ -26,16 +26,16 @@ type Iterx struct { // Cache memory for a rows during iteration in structScan. fields [][]int values []interface{} - unsafe bool + strict bool structOnly bool applied bool } -// Unsafe forces the iterator to ignore missing fields. By default when scanning -// a struct if result row has a column that cannot be mapped to any destination -// field an error is reported. With unsafe such columns are ignored. -func (iter *Iterx) Unsafe() *Iterx { - iter.unsafe = true +// Strict forces the iterator to disable ignoring missing fields. In Strict mode +// when scanning a struct if result row has a column that cannot be mapped to any +// destination field an error is reported. By default such columns are ignored. +func (iter *Iterx) Strict() *Iterx { + iter.strict = true return iter } @@ -228,7 +228,7 @@ func (iter *Iterx) scan(value reflect.Value) bool { if value.Kind() != reflect.Ptr { panic("value must be a pointer") } - return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe)) + return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.strict)) } // StructScan is like gocql.Iter.Scan, but scans a single row into a single @@ -264,8 +264,8 @@ func (iter *Iterx) structScan(value reflect.Value) bool { cas := len(columns) > 0 && columns[0] == appliedColumn iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns) - // if we are not unsafe and it's not CAS query and are missing fields, return an error - if !iter.unsafe && !cas { + // if we are strict and it's not CAS query and are missing fields, return an error + if iter.strict && !cas { if f, err := missingFields(iter.fields); err != nil { iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type()) return false @@ -302,7 +302,7 @@ func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, va continue } f := reflectx.FieldByIndexes(value, traversal).Addr() - values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe) + values[i] = udtWrapValue(f, iter.Mapper, iter.strict) } return nil @@ -325,7 +325,7 @@ func columnNames(ci []gocql.ColumnInfo) []string { // end of the result set was reached or if an error occurred. Close should // be called afterwards to retrieve any potential errors. func (iter *Iterx) Scan(dest ...interface{}) bool { - return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...) + return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.strict, dest)...) } // Close closes the iterator and returns any errors that happened during diff --git a/iterx_test.go b/iterx_test.go index 96de405..3d217bb 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -264,7 +264,7 @@ func TestIterxUDT(t *testing.T) { }) t.Run("insert-bind", func(t *testing.T) { - if err := session.Query(insertStmt, nil).Unsafe().Bind( + if err := session.Query(insertStmt, nil).Bind( testuuid, tc.insert, ).ExecRelease(); err != nil { @@ -273,7 +273,7 @@ func TestIterxUDT(t *testing.T) { // Make sure the UDT was inserted correctly v := FullUDT{} - if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil { + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(&v); err != nil { t.Fatal(err.Error()) } diff(t, tc.expectedOnDB, v) @@ -281,7 +281,7 @@ func TestIterxUDT(t *testing.T) { t.Run("scan", func(t *testing.T) { v := reflect.New(reflect.TypeOf(tc.expected)).Interface() - if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil { + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Scan(v); err != nil { t.Fatal(err.Error()) } diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) @@ -289,7 +289,7 @@ func TestIterxUDT(t *testing.T) { t.Run("get", func(t *testing.T) { v := reflect.New(reflect.TypeOf(tc.expected)).Interface() - if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil { + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { t.Fatal(err.Error()) } diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) @@ -305,7 +305,7 @@ func TestIterxUDT(t *testing.T) { t.Run("insert-bind-struct", func(t *testing.T) { b := makeStruct(testuuid, tc.insert) - if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).BindStruct(b).ExecRelease(); err != nil { t.Fatal(err.Error()) } @@ -320,7 +320,7 @@ func TestIterxUDT(t *testing.T) { t.Run("insert-bind-struct-map", func(t *testing.T) { t.Run("empty-map", func(t *testing.T) { b := makeStruct(testuuid, tc.insert) - if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}). BindStructMap(b, nil).ExecRelease(); err != nil { t.Fatal(err.Error()) } @@ -334,7 +334,7 @@ func TestIterxUDT(t *testing.T) { }) t.Run("empty-struct", func(t *testing.T) { - if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}). BindStructMap(struct{}{}, map[string]interface{}{ "test_uuid": testuuid, "test_udt": tc.insert, @@ -352,7 +352,7 @@ func TestIterxUDT(t *testing.T) { }) t.Run("insert-bind-map", func(t *testing.T) { - if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}). BindMap(map[string]interface{}{ "test_uuid": testuuid, "test_udt": tc.insert, @@ -736,41 +736,41 @@ func TestIterxStructOnlyUDT(t *testing.T) { }) } -func TestIterxUnsafe(t *testing.T) { +func TestIterxStrict(t *testing.T) { session := gocqlxtest.CreateSession(t) defer session.Close() - if err := session.ExecStmt(`CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil { + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.strict_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil { t.Fatal("create table:", err) } - if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil { + if err := session.Query(`INSERT INTO strict_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil { t.Fatal("insert:", err) } - type UnsafeTable struct { + type StrictTable struct { Testtext string } - m := UnsafeTable{ + m := StrictTable{ Testtext: "test", } const ( - stmt = `SELECT * FROM unsafe_table` - golden = "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" + stmt = `SELECT * FROM strict_table` + golden = "missing destination name \"testtextunbound\" in gocqlx_test.StrictTable" ) - t.Run("get", func(t *testing.T) { - var v UnsafeTable - err := session.Query(stmt, nil).Get(&v) + t.Run("get strict", func(t *testing.T) { + var v StrictTable + err := session.Query(stmt, nil).Strict().Get(&v) if err == nil || !strings.HasPrefix(err.Error(), golden) { t.Fatalf("Get() error=%q expected %s", err, golden) } }) - t.Run("select", func(t *testing.T) { - var v []UnsafeTable - err := session.Query(stmt, nil).Select(&v) + t.Run("select strict", func(t *testing.T) { + var v []StrictTable + err := session.Query(stmt, nil).Strict().Select(&v) if err == nil || !strings.HasPrefix(err.Error(), golden) { t.Fatalf("Select() error=%q expected %s", err, golden) } @@ -779,9 +779,9 @@ func TestIterxUnsafe(t *testing.T) { } }) - t.Run("get unsafe", func(t *testing.T) { - var v UnsafeTable - err := session.Query(stmt, nil).Iter().Unsafe().Get(&v) + t.Run("get", func(t *testing.T) { + var v StrictTable + err := session.Query(stmt, nil).Get(&v) if err != nil { t.Fatal("Get() failed:", err) } @@ -790,9 +790,9 @@ func TestIterxUnsafe(t *testing.T) { } }) - t.Run("select unsafe", func(t *testing.T) { - var v []UnsafeTable - err := session.Query(stmt, nil).Iter().Unsafe().Select(&v) + t.Run("select", func(t *testing.T) { + var v []StrictTable + err := session.Query(stmt, nil).Select(&v) if err != nil { t.Fatal("Select() failed:", err) } @@ -804,9 +804,9 @@ func TestIterxUnsafe(t *testing.T) { } }) - t.Run("select default unsafe", func(t *testing.T) { - var v []UnsafeTable - err := session.Query(stmt, nil).Unsafe().Iter().Select(&v) + t.Run("select default", func(t *testing.T) { + var v []StrictTable + err := session.Query(stmt, nil).Iter().Select(&v) if err != nil { t.Fatal("Select() failed:", err) } diff --git a/queryx.go b/queryx.go index 9b67f9b..04b72cf 100644 --- a/queryx.go +++ b/queryx.go @@ -95,7 +95,7 @@ type Queryx struct { Mapper *reflectx.Mapper *gocql.Query Names []string - unsafe bool + strict bool } // Query creates a new Queryx from gocql.Query using a default mapper. @@ -107,7 +107,7 @@ func Query(q *gocql.Query, names []string) *Queryx { Names: names, Mapper: DefaultMapper, tr: DefaultBindTransformer, - unsafe: DefaultUnsafe, + strict: DefaultStrict, } } @@ -211,7 +211,7 @@ func (q *Queryx) bindMapArgs(arg map[string]interface{}) ([]interface{}, error) // Bind sets query arguments of query. This can also be used to rebind new query arguments // to an existing query instance. func (q *Queryx) Bind(v ...interface{}) *Queryx { - q.Query.Bind(udtWrapSlice(q.Mapper, q.unsafe, v)...) + q.Query.Bind(udtWrapSlice(q.Mapper, q.strict, v)...) return q } @@ -219,7 +219,7 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx { // row into the values pointed at by dest and discards the rest. If no rows // were selected, ErrNotFound is returned. func (q *Queryx) Scan(v ...interface{}) error { - return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...) + return q.Query.Scan(udtWrapSlice(q.Mapper, q.strict, v)...) } // Err returns any binding errors. @@ -351,14 +351,14 @@ func (q *Queryx) Iter() *Iterx { return &Iterx{ Iter: q.Query.Iter(), Mapper: q.Mapper, - unsafe: q.unsafe, + strict: q.strict, } } -// Unsafe forces the query and iterators to ignore missing fields. By default when scanning -// a struct if result row has a column that cannot be mapped to any destination -// field an error is reported. With unsafe such columns are ignored. -func (q *Queryx) Unsafe() *Queryx { - q.unsafe = true +// Strict forces the query and iterators to report an error if there are missing fields. +// By default when scanning a struct if result row has a column that cannot be mapped to +// any destination it is ignored. With strict error is reported. +func (q *Queryx) Strict() *Queryx { + q.strict = true return q } diff --git a/session.go b/session.go index cb82c3c..ea92478 100644 --- a/session.go +++ b/session.go @@ -51,7 +51,7 @@ func (s Session) ContextQuery(ctx context.Context, stmt string, names []string) Names: names, Mapper: s.Mapper, tr: DefaultBindTransformer, - unsafe: DefaultUnsafe, + strict: DefaultStrict, } } @@ -66,7 +66,7 @@ func (s Session) Query(stmt string, names []string) *Queryx { Names: names, Mapper: s.Mapper, tr: DefaultBindTransformer, - unsafe: DefaultUnsafe, + strict: DefaultStrict, } } diff --git a/udt.go b/udt.go index 2785151..cc11f4e 100644 --- a/udt.go +++ b/udt.go @@ -26,14 +26,14 @@ var ( type udt struct { field map[string]reflect.Value value reflect.Value - unsafe bool + strict bool } -func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { +func makeUDT(value reflect.Value, mapper *reflectx.Mapper, strict bool) udt { return udt{ value: value, field: mapper.FieldMap(value), - unsafe: unsafe, + strict: strict, } } @@ -42,7 +42,7 @@ func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { if ok { return gocql.Marshal(info, value.Interface()) } - if u.unsafe { + if !u.strict { return nil, nil } return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) @@ -53,25 +53,25 @@ func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { if ok { return gocql.Unmarshal(info, data, value.Addr().Interface()) } - if u.unsafe { + if !u.strict { return nil } return fmt.Errorf("missing name %q in %s", name, u.value.Type()) } // udtWrapValue adds UDT wrapper if needed. -func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} { +func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, strict bool) interface{} { if value.Type().Implements(autoUDTInterface) { - return makeUDT(value, mapper, unsafe) + return makeUDT(value, mapper, strict) } return value.Interface() } // udtWrapSlice adds UDT wrapper if needed. -func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} { +func udtWrapSlice(mapper *reflectx.Mapper, strict bool, v []interface{}) []interface{} { for i := range v { if _, ok := v[i].(UDT); ok { - v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe) + v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, strict) } } return v