Automated UDT support

This patch adds the power of GocqlX to UDTs.
Now you can make a struct be UDT compatible by adding a single line.

```
type FullName struct {
	gocqlx.UDT
	FirstName string
	LastName  string
}
```

Signed-off-by: Michał Matczuk <michal@scylladb.com>
This commit is contained in:
Michał Matczuk
2020-04-15 17:06:29 +02:00
committed by Michal Jan Matczuk
parent 2569c3dd8f
commit ab279e68ed
8 changed files with 284 additions and 86 deletions

View File

@@ -36,18 +36,13 @@ func (n *FullName) UnmarshalCQL(info gocql.TypeInfo, data []byte) error {
}
type FullNameUDT struct {
FirstName string
LastName string
gocqlx.UDT
FullName
}
func (n FullNameUDT) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
return gocql.Marshal(info, f.Interface())
}
func (n *FullNameUDT) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
f := gocqlx.DefaultMapper.FieldByName(reflect.ValueOf(n), name)
return gocql.Unmarshal(info, data, f.Addr().Interface())
type FullNamePtrUDT struct {
gocqlx.UDT
*FullName
}
func TestStruct(t *testing.T) {
@@ -75,7 +70,8 @@ func TestStruct(t *testing.T) {
testvarint varint,
testinet inet,
testcustom text,
testudt gocqlx_test.FullName
testudt gocqlx_test.FullName,
testptrudt gocqlx_test.FullName
)`); err != nil {
t.Fatal("create table:", err)
}
@@ -98,6 +94,7 @@ func TestStruct(t *testing.T) {
Testinet string
Testcustom FullName
Testudt FullNameUDT
Testptrudt FullNamePtrUDT
}
bigInt := new(big.Int)
@@ -122,10 +119,11 @@ func TestStruct(t *testing.T) {
Testvarint: bigInt,
Testinet: "213.212.2.19",
Testcustom: FullName{FirstName: "John", LastName: "Doe"},
Testudt: FullNameUDT{FirstName: "John", LastName: "Doe"},
Testudt: FullNameUDT{FullName: FullName{FirstName: "John", LastName: "Doe"}},
Testptrudt: FullNamePtrUDT{FullName: &FullName{FirstName: "John", LastName: "Doe"}},
}
if err := session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
if err := gocqlx.Query(session.Query(`INSERT INTO struct_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat,testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet, testcustom, testudt, testptrudt) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), nil).Bind(
m.Testuuid,
m.Testtimestamp,
m.Testvarchar,
@@ -142,7 +140,8 @@ func TestStruct(t *testing.T) {
m.Testvarint,
m.Testinet,
m.Testcustom,
m.Testudt).Exec(); err != nil {
m.Testudt,
m.Testptrudt).Exec(); err != nil {
t.Fatal("insert:", err)
}
@@ -186,6 +185,28 @@ func TestStruct(t *testing.T) {
t.Fatal("not equals")
}
})
t.Run("struct scan", func(t *testing.T) {
var (
v StructTable
n int
)
i := gocqlx.Query(session.Query(`SELECT * FROM struct_table`), nil).Iter()
for i.StructScan(&v) {
n++
}
if err := i.Close(); err != nil {
t.Fatal("struct scan failed", err)
}
if n != 1 {
t.Fatal("struct scan unexpected number of rows", n)
}
if !reflect.DeepEqual(m, v) {
t.Fatal("not equals")
}
})
}
func TestScannable(t *testing.T) {
@@ -320,7 +341,12 @@ func TestStructOnlyUDT(t *testing.T) {
t.Fatal("create table:", err)
}
m := FullNameUDT{"John", "Doe"}
m := FullNameUDT{
FullName: FullName{
FirstName: "John",
LastName: "Doe",
},
}
if err := session.Query(`INSERT INTO struct_only_udt_table (first_name, last_name) values (?, ?)`, m.FirstName, m.LastName).Exec(); err != nil {
t.Fatal("insert:", err)
@@ -401,7 +427,7 @@ func TestUnsafe(t *testing.T) {
t.Run("safe get", func(t *testing.T) {
var v UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
if err := i.Get(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err)
}
})
@@ -409,7 +435,7 @@ func TestUnsafe(t *testing.T) {
t.Run("safe select", func(t *testing.T) {
var v []UnsafeTable
i := gocqlx.Iter(session.Query(`SELECT * FROM unsafe_table`))
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in *gocqlx_test.UnsafeTable" {
if err := i.Select(&v); err == nil || err.Error() != "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable" {
t.Fatal("expected ErrNotFound", "got", err)
}
if cap(v) > 0 {
@@ -510,6 +536,40 @@ func TestNotFound(t *testing.T) {
})
}
func TestErrorOnNil(t *testing.T) {
session := CreateSession(t)
defer session.Close()
if err := ExecStmt(session, `CREATE TABLE gocqlx_test.nil_table (testtext text PRIMARY KEY)`); err != nil {
t.Fatal("create table:", err)
}
const (
stmt = "SELECT * FROM not_found_table WRONG"
golden = "expected a pointer but got <nil>"
)
t.Run("get", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Get(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Get()=%q expected %q error", err, golden)
}
})
t.Run("select", func(t *testing.T) {
err := gocqlx.Iter(session.Query(stmt)).Select(nil)
if err == nil || err.Error() != golden {
t.Fatalf("Select()=%q expected %q error", err, golden)
}
})
t.Run("struct scan", func(t *testing.T) {
i := gocqlx.Iter(session.Query(stmt))
i.StructScan(nil)
err := i.Close()
if err == nil || err.Error() != golden {
t.Fatalf("StructScan()=%q expected %q error", err, golden)
}
})
}
func TestPaging(t *testing.T) {
session := CreateSession(t)
defer session.Close()