From e079a7b36a57ccd1b09122d1adc72abad4545a07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Wed, 23 Aug 2017 17:56:05 +0200 Subject: [PATCH] iter: return ErrNotFound for no results --- integration_test.go | 198 ++++++++++++++++++++++++++++++++------------ iterx.go | 90 ++++++++++++-------- 2 files changed, 200 insertions(+), 88 deletions(-) diff --git a/integration_test.go b/integration_test.go index 07566ae..002c18a 100644 --- a/integration_test.go +++ b/integration_test.go @@ -30,60 +30,6 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { return nil } -func TestScannable(t *testing.T) { - session := createSession(t) - defer session.Close() - if err := createTable(session, `CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil { - t.Fatal("create table:", err) - } - m := FullName{"John", "Doe"} - - if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, m).Exec(); err != nil { - t.Fatal("insert:", err) - } - - t.Run("get", func(t *testing.T) { - var v FullName - if err := gocqlx.Get(&v, session.Query(`SELECT testfullname FROM scannable_table`)); err != nil { - t.Fatal("get failed", err) - } - - if !reflect.DeepEqual(m, v) { - t.Fatal("not equals") - } - }) - - t.Run("select", func(t *testing.T) { - var v []FullName - if err := gocqlx.Select(&v, session.Query(`SELECT testfullname FROM scannable_table`)); err != nil { - t.Fatal("get failed", err) - } - - if len(v) != 1 { - t.Fatal("select unexpecrted number of rows", len(v)) - } - - if !reflect.DeepEqual(m, v[0]) { - t.Fatal("not equals") - } - }) - - t.Run("select ptr", func(t *testing.T) { - var v []*FullName - if err := gocqlx.Select(&v, session.Query(`SELECT testfullname FROM scannable_table`)); err != nil { - t.Fatal("get failed", err) - } - - if len(v) != 1 { - t.Fatal("select unexpecrted number of rows", len(v)) - } - - if !reflect.DeepEqual(&m, v[0]) { - t.Fatal("not equals") - } - }) -} - func TestStruct(t *testing.T) { session := createSession(t) defer session.Close() @@ -213,3 +159,147 @@ func TestStruct(t *testing.T) { } }) } + +func TestScannable(t *testing.T) { + session := createSession(t) + defer session.Close() + if err := createTable(session, `CREATE TABLE gocqlx_test.scannable_table (testfullname text PRIMARY KEY)`); err != nil { + t.Fatal("create table:", err) + } + m := FullName{"John", "Doe"} + + if err := session.Query(`INSERT INTO scannable_table (testfullname) values (?)`, m).Exec(); err != nil { + t.Fatal("insert:", err) + } + + t.Run("get", func(t *testing.T) { + var v FullName + if err := gocqlx.Get(&v, session.Query(`SELECT testfullname FROM scannable_table`)); err != nil { + t.Fatal("get failed", err) + } + + if !reflect.DeepEqual(m, v) { + t.Fatal("not equals") + } + }) + + t.Run("select", func(t *testing.T) { + var v []FullName + if err := gocqlx.Select(&v, session.Query(`SELECT testfullname FROM scannable_table`)); err != nil { + t.Fatal("get failed", err) + } + + if len(v) != 1 { + t.Fatal("select unexpecrted number of rows", len(v)) + } + + if !reflect.DeepEqual(m, v[0]) { + t.Fatal("not equals") + } + }) + + t.Run("select ptr", func(t *testing.T) { + var v []*FullName + if err := gocqlx.Select(&v, session.Query(`SELECT testfullname FROM scannable_table`)); err != nil { + t.Fatal("get failed", err) + } + + if len(v) != 1 { + t.Fatal("select unexpecrted number of rows", len(v)) + } + + if !reflect.DeepEqual(&m, v[0]) { + t.Fatal("not equals") + } + }) +} + +func TestUnsafe(t *testing.T) { + session := createSession(t) + defer session.Close() + if err := createTable(session, `CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil { + t.Fatal("create table:", err) + } + if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, "test", "test").Exec(); err != nil { + t.Fatal("insert:", err) + } + + type UnsafeTable struct { + Testtext string + } + + t.Run("safe get", func(t *testing.T) { + var v UnsafeTable + i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) + if err := i.Get(&v); err == nil || err.Error() != "missing destination name testtextunbound in *gocqlx_test.UnsafeTable" { + t.Fatal("expected ErrNotFound", "got", err) + } + }) + + t.Run("safe select", func(t *testing.T) { + var v []UnsafeTable + i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) + if err := i.Select(&v); err == nil || err.Error() != "missing destination name testtextunbound in *gocqlx_test.UnsafeTable" { + t.Fatal("expected ErrNotFound", "got", err) + } + if cap(v) > 0 { + t.Fatal("side effect alloc") + } + }) + + t.Run("unsafe get", func(t *testing.T) { + var v UnsafeTable + i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) + if err := i.Unsafe().Get(&v); err != nil { + t.Fatal(err) + } + if v.Testtext != "test" { + t.Fatal("get failed") + } + }) + + t.Run("unsafe select", func(t *testing.T) { + var v []UnsafeTable + i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`)) + if err := i.Unsafe().Select(&v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatal("select failed") + } + if v[0].Testtext != "test" { + t.Fatal("select failed") + } + }) +} + +func TestNotFound(t *testing.T) { + session := createSession(t) + defer session.Close() + if err := createTable(session, `CREATE TABLE gocqlx_test.not_found_table (testtext text PRIMARY KEY)`); err != nil { + t.Fatal("create table:", err) + } + + type NotFoundTable struct { + Testtext string + } + + t.Run("get", func(t *testing.T) { + var v NotFoundTable + i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table`)) + if err := i.Get(&v); err != gocql.ErrNotFound { + t.Fatal("expected ErrNotFound", "got", err) + } + }) + + t.Run("select", func(t *testing.T) { + var v []NotFoundTable + i := gocqlx.Iter(session.Query(`SELECT * FROM not_found_table`)) + if err := i.Select(&v); err != gocql.ErrNotFound { + t.Fatal("expected ErrNotFound", "got", err) + } + if cap(v) > 0 { + t.Fatal("side effect alloc") + } + }) +} diff --git a/iterx.go b/iterx.go index e92864e..860d845 100644 --- a/iterx.go +++ b/iterx.go @@ -53,105 +53,114 @@ func (iter *Iterx) Unsafe() *Iterx { // Get scans first row into a destination and closes the iterator. If the // destination type is a Struct, then StructScan will be used. If the // destination is some other type, then the row must only have one column which -// can scan into that type. +// can scan into that type. If no rows were selected, ErrNotFound is returned. func (iter *Iterx) Get(dest interface{}) error { if iter.query == nil { return errors.New("using released query") } - if err := iter.scanAny(dest, false); err != nil { - iter.err = err - } - + iter.scanAny(dest, false) iter.Close() iter.ReleaseQuery() return iter.err } -func (iter *Iterx) scanAny(dest interface{}, structOnly bool) error { +func (iter *Iterx) scanAny(dest interface{}, structOnly bool) bool { value := reflect.ValueOf(dest) if value.Kind() != reflect.Ptr { - return errors.New("must pass a pointer, not a value, to StructScan destination") + iter.err = errors.New("must pass a pointer, not a value, to StructScan destination") + return false } if value.IsNil() { - return errors.New("nil pointer passed to StructScan destination") + iter.err = errors.New("nil pointer passed to StructScan destination") + return false + } + if iter.Iter.NumRows() == 0 { + iter.err = gocql.ErrNotFound + return false } base := reflectx.Deref(value.Type()) scannable := isScannable(base) if structOnly && scannable { - return structOnlyError(base) + iter.err = structOnlyError(base) + return false } if scannable && len(iter.Columns()) > 1 { - return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(iter.Columns())) + iter.err = fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(iter.Columns())) + return false } - if !scannable { - iter.StructScan(dest) - } else { - iter.Scan(dest) + if scannable { + return iter.Scan(dest) } - return iter.err + return iter.StructScan(dest) } // Select scans all rows into a destination, which must be a slice of any type // and closes the iterator. If the destination slice type is a Struct, then // StructScan will be used on each row. If the destination is some other type, // then each row must only have one column which can scan into that type. +// If no rows were selected, ErrNotFound is returned. func (iter *Iterx) Select(dest interface{}) error { if iter.query == nil { return errors.New("using released query") } - if err := iter.scanAll(dest, false); err != nil { - iter.err = err - } - + iter.scanAll(dest, false) iter.Close() iter.ReleaseQuery() return iter.err } -func (iter *Iterx) scanAll(dest interface{}, structOnly bool) error { +func (iter *Iterx) scanAll(dest interface{}, structOnly bool) bool { value := reflect.ValueOf(dest) // json.Unmarshal returns errors for these if value.Kind() != reflect.Ptr { - return errors.New("must pass a pointer, not a value, to StructScan destination") + iter.err = errors.New("must pass a pointer, not a value, to StructScan destination") + return false } if value.IsNil() { - return errors.New("nil pointer passed to StructScan destination") + iter.err = errors.New("nil pointer passed to StructScan destination") + return false + } + if iter.Iter.NumRows() == 0 { + iter.err = gocql.ErrNotFound + return false } slice, err := baseType(value.Type(), reflect.Slice) if err != nil { - return err + iter.err = err + return false } - // allocate memory for the page data - v := reflect.MakeSlice(slice, 0, iter.Iter.NumRows()) - isPtr := slice.Elem().Kind() == reflect.Ptr base := reflectx.Deref(slice.Elem()) scannable := isScannable(base) if structOnly && scannable { - return structOnlyError(base) + iter.err = structOnlyError(base) + return false } // if it's a base type make sure it only has 1 column; if not return an error if scannable && len(iter.Columns()) > 1 { - return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(iter.Columns())) + iter.err = fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(iter.Columns())) + return false } var ( - vp reflect.Value - ok bool + alloc bool + v reflect.Value + vp reflect.Value + ok bool ) for { // create a new struct type (which returns PtrTo) and indirect it @@ -167,6 +176,12 @@ func (iter *Iterx) scanAll(dest interface{}, structOnly bool) error { break } + // allocate memory for the page data + if !alloc { + v = reflect.MakeSlice(slice, 0, iter.Iter.NumRows()) + alloc = true + } + if isPtr { v = reflect.Append(v, vp) } else { @@ -174,10 +189,12 @@ func (iter *Iterx) scanAll(dest interface{}, structOnly bool) error { } } - // update dest - reflect.Indirect(value).Set(v) + // update dest if allocated slice + if alloc { + reflect.Indirect(value).Set(v) + } - return iter.err + return true } // StructScan is like gocql.Scan, but scans a single row into a single Struct. @@ -185,7 +202,7 @@ func (iter *Iterx) scanAll(dest interface{}, structOnly bool) error { // prohibitive. StructScan caches the reflect work of matching up column // positions to fields to avoid that overhead per scan, which means it is not // safe to run StructScan on the same Iterx instance with different struct -// types. +// types. If no rows were selected, ErrNotFound is returned. func (iter *Iterx) StructScan(dest interface{}) bool { if iter.query == nil { iter.err = errors.New("using released query") @@ -198,6 +215,11 @@ func (iter *Iterx) StructScan(dest interface{}) bool { return false } + if iter.Iter.NumRows() == 0 { + iter.err = gocql.ErrNotFound + return false + } + if !iter.started { columns := columnNames(iter.Iter.Columns()) m := iter.Mapper