2017-07-25 11:10:19 +02:00
|
|
|
package gocqlx
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"errors"
|
2017-07-25 12:25:59 +02:00
|
|
|
"fmt"
|
|
|
|
|
"reflect"
|
2017-07-25 11:10:19 +02:00
|
|
|
"strconv"
|
2017-07-25 12:25:59 +02:00
|
|
|
|
|
|
|
|
"github.com/gocql/gocql"
|
|
|
|
|
"github.com/jmoiron/sqlx/reflectx"
|
2017-07-25 11:10:19 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// CompileNamedQuery compiles a named query into an unbound query using the
|
|
|
|
|
// '?' bindvar and a list of names.
|
2017-07-25 11:13:34 +02:00
|
|
|
func CompileNamedQuery(qs []byte) (stmt string, names []string, err error) {
|
2017-07-25 11:10:19 +02:00
|
|
|
// guess number of names
|
|
|
|
|
n := bytes.Count(qs, []byte(":"))
|
|
|
|
|
if n == 0 {
|
|
|
|
|
return "", nil, errors.New("expected a named query")
|
|
|
|
|
}
|
|
|
|
|
names = make([]string, 0, n)
|
|
|
|
|
rebound := make([]byte, 0, len(qs))
|
|
|
|
|
|
|
|
|
|
inName := false
|
|
|
|
|
last := len(qs) - 1
|
|
|
|
|
name := make([]byte, 0, 10)
|
|
|
|
|
|
|
|
|
|
for i, b := range qs {
|
|
|
|
|
// a ':' while we're in a name is an error
|
|
|
|
|
if b == ':' {
|
|
|
|
|
// if this is the second ':' in a '::' escape sequence, append a ':'
|
|
|
|
|
if inName && i > 0 && qs[i-1] == ':' {
|
|
|
|
|
rebound = append(rebound, ':')
|
|
|
|
|
inName = false
|
|
|
|
|
continue
|
|
|
|
|
} else if inName {
|
|
|
|
|
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
|
2017-07-25 11:13:34 +02:00
|
|
|
return stmt, names, err
|
2017-07-25 11:10:19 +02:00
|
|
|
}
|
|
|
|
|
inName = true
|
|
|
|
|
name = []byte{}
|
|
|
|
|
// if we're in a name, and this is an allowed character, continue
|
2017-07-26 09:10:35 +02:00
|
|
|
} else if inName && (allowedBindRune(b) || b == '_' || b == '.') && i != last {
|
2017-07-25 11:10:19 +02:00
|
|
|
// append the byte to the name if we are in a name and not on the last byte
|
|
|
|
|
name = append(name, b)
|
|
|
|
|
// if we're in a name and it's not an allowed character, the name is done
|
|
|
|
|
} else if inName {
|
|
|
|
|
inName = false
|
|
|
|
|
// if this is the final byte of the string and it is part of the name, then
|
|
|
|
|
// make sure to add it to the name
|
2017-07-26 09:10:35 +02:00
|
|
|
if i == last && allowedBindRune(b) {
|
2017-07-25 11:10:19 +02:00
|
|
|
name = append(name, b)
|
|
|
|
|
}
|
|
|
|
|
// add the string representation to the names list
|
|
|
|
|
names = append(names, string(name))
|
|
|
|
|
// add a proper bindvar for the bindType
|
|
|
|
|
rebound = append(rebound, '?')
|
|
|
|
|
// add this byte to string unless it was not part of the name
|
|
|
|
|
if i != last {
|
|
|
|
|
rebound = append(rebound, b)
|
2017-07-26 09:10:35 +02:00
|
|
|
} else if !allowedBindRune(b) {
|
2017-07-25 11:10:19 +02:00
|
|
|
rebound = append(rebound, b)
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// this is a normal byte and should just go onto the rebound query
|
|
|
|
|
rebound = append(rebound, b)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return string(rebound), names, err
|
|
|
|
|
}
|
2017-07-25 12:25:59 +02:00
|
|
|
|
2017-07-26 09:10:35 +02:00
|
|
|
// allowedBindRune reports whether b is an ASCII letter or an ASCII digit.
func allowedBindRune(b byte) bool {
	return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
|
|
|
|
|
|
2017-07-25 12:25:59 +02:00
|
|
|
// Queryx is a wrapper around gocql.Query which adds struct binding capabilities.
|
|
|
|
|
type Queryx struct {
|
|
|
|
|
*gocql.Query
|
|
|
|
|
Names []string
|
|
|
|
|
Mapper *reflectx.Mapper
|
|
|
|
|
}
|
|
|
|
|
|
2017-07-28 10:18:38 +02:00
|
|
|
// Query creates a new Queryx from gocql.Query using a default mapper.
|
|
|
|
|
func Query(q *gocql.Query, names []string) Queryx {
|
|
|
|
|
return Queryx{
|
|
|
|
|
Query: q,
|
|
|
|
|
Names: names,
|
|
|
|
|
Mapper: DefaultMapper,
|
2017-07-25 12:25:59 +02:00
|
|
|
}
|
2017-07-28 10:18:38 +02:00
|
|
|
}
|
2017-07-25 12:25:59 +02:00
|
|
|
|
2017-08-01 13:29:52 +02:00
|
|
|
// BindStruct binds query named parameters to values from arg using mapper. If
|
|
|
|
|
// value cannot be found error is reported.
|
2017-07-28 10:18:38 +02:00
|
|
|
func (q Queryx) BindStruct(arg interface{}) error {
|
2017-08-01 13:29:52 +02:00
|
|
|
arglist, err := bindStructArgs(q.Names, arg, nil, q.Mapper)
|
2017-07-25 12:25:59 +02:00
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
q.Bind(arglist...)
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2017-08-01 13:29:52 +02:00
|
|
|
// BindStructMap binds query named parameters to values from arg0 and arg1
|
|
|
|
|
// using a mapper. If value cannot be found in arg0 it's looked up in arg1
|
|
|
|
|
// before reporting an error.
|
|
|
|
|
func (q Queryx) BindStructMap(arg0 interface{}, arg1 map[string]interface{}) error {
|
|
|
|
|
arglist, err := bindStructArgs(q.Names, arg0, arg1, q.Mapper)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
q.Bind(arglist...)
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func bindStructArgs(names []string, arg0 interface{}, arg1 map[string]interface{}, m *reflectx.Mapper) ([]interface{}, error) {
|
2017-07-25 12:25:59 +02:00
|
|
|
arglist := make([]interface{}, 0, len(names))
|
|
|
|
|
|
|
|
|
|
// grab the indirected value of arg
|
2017-08-01 13:29:52 +02:00
|
|
|
v := reflect.ValueOf(arg0)
|
|
|
|
|
for v = reflect.ValueOf(arg0); v.Kind() == reflect.Ptr; {
|
2017-07-25 12:25:59 +02:00
|
|
|
v = v.Elem()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fields := m.TraversalsByName(v.Type(), names)
|
|
|
|
|
for i, t := range fields {
|
2017-08-01 13:29:52 +02:00
|
|
|
if len(t) != 0 {
|
|
|
|
|
val := reflectx.FieldByIndexesReadOnly(v, t)
|
|
|
|
|
arglist = append(arglist, val.Interface())
|
|
|
|
|
} else {
|
|
|
|
|
val, ok := arg1[names[i]]
|
|
|
|
|
if !ok {
|
|
|
|
|
return arglist, fmt.Errorf("could not find name %s in %#v and %#v", names[i], arg0, arg1)
|
|
|
|
|
}
|
|
|
|
|
arglist = append(arglist, val)
|
2017-07-25 12:25:59 +02:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return arglist, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BindMap binds query named parameters using map.
|
|
|
|
|
func (q Queryx) BindMap(arg map[string]interface{}) error {
|
|
|
|
|
arglist, err := bindMapArgs(q.Names, arg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
q.Bind(arglist...)
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// bindMapArgs returns the values stored in arg under each of names, in the
// order of names. It returns an error for the first name that is not a key of
// arg; the returned slice is nil in that case.
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
	arglist := make([]interface{}, 0, len(names))

	for _, name := range names {
		val, ok := arg[name]
		if !ok {
			return nil, fmt.Errorf("could not find name %s in %#v", name, arg)
		}
		arglist = append(arglist, val)
	}
	return arglist, nil
}
|
2017-07-28 10:18:38 +02:00
|
|
|
|
|
|
|
|
// QueryFunc creates Queryx from qb.Builder.ToCql() output.
|
|
|
|
|
type QueryFunc func(stmt string, names []string) Queryx
|
|
|
|
|
|
|
|
|
|
// SessionQuery creates QueryFunc that's session aware.
|
|
|
|
|
func SessionQuery(session *gocql.Session) QueryFunc {
|
|
|
|
|
return func(stmt string, names []string) Queryx {
|
|
|
|
|
return Query(session.Query(stmt), names)
|
|
|
|
|
}
|
|
|
|
|
}
|