Commit 7c426872 authored by gwenn's avatar gwenn

Renamed ScanColumn to ScanByIndex.

Renamed NamedScanColumn to ScanByName.
Refactored BindParamaterIndex and BindParamaterName to return error on invalid param.
Added NamedBind.
parent fc80494a
...@@ -11,7 +11,7 @@ Stmt#Bind uses native sqlite3_bind_x methods and failed if unsupported type. ...@@ -11,7 +11,7 @@ Stmt#Bind uses native sqlite3_bind_x methods and failed if unsupported type.
Stmt#Next returns a (bool, os.Error) couple like Reader#Read. Stmt#Next returns a (bool, os.Error) couple like Reader#Read.
Stmt#Scan uses native sqlite3_column_x methods. Stmt#Scan uses native sqlite3_column_x methods.
Stmt#NamedScan is added. It's compliant with [go-dbi](https://github.com/thomaslee/go-dbi/) API but I think its signature should be improved/modified. Stmt#NamedScan is added. It's compliant with [go-dbi](https://github.com/thomaslee/go-dbi/) API but I think its signature should be improved/modified.
Stmt#ScanColumn/NamedScanColumn are added to test NULL value. Stmt#ScanByIndex/ScanByName are added to test NULL value.
Currently, the weak point of the binding is the *Scan methods: Currently, the weak point of the binding is the *Scan methods:
The original implementation is using this strategy: The original implementation is using this strategy:
......
...@@ -361,6 +361,7 @@ type Stmt struct { ...@@ -361,6 +361,7 @@ type Stmt struct {
stmt *C.sqlite3_stmt stmt *C.sqlite3_stmt
tail string tail string
cols map[string]int // cached columns index by name cols map[string]int // cached columns index by name
params map[string]int // cached parameter index by name
// Enable NULL value check in Scan methods // Enable NULL value check in Scan methods
CheckNull bool CheckNull bool
// Enable type check in Scan methods // Enable type check in Scan methods
...@@ -433,16 +434,58 @@ func (s *Stmt) BindParameterCount() int { ...@@ -433,16 +434,58 @@ func (s *Stmt) BindParameterCount() int {
} }
// Calls http://sqlite.org/c3ref/bind_parameter_index.html // Calls http://sqlite.org/c3ref/bind_parameter_index.html
func (s *Stmt) BindParameterIndex(name string) int { func (s *Stmt) BindParameterIndex(name string) (int, error) {
if s.params == nil {
count := s.BindParameterCount()
s.params = make(map[string]int, count)
}
index, ok := s.params[name]
if ok {
return index, nil
}
cname := C.CString(name) cname := C.CString(name)
defer C.free(unsafe.Pointer(cname)) defer C.free(unsafe.Pointer(cname))
return int(C.sqlite3_bind_parameter_index(s.stmt, cname)) index = int(C.sqlite3_bind_parameter_index(s.stmt, cname))
if index == 0 {
return -1, errors.New("invalid parameter name: " + name)
}
s.params[name] = index
return index, nil
} }
// The first host parameter has an index of 1, not 0. // The first host parameter has an index of 1, not 0.
// Calls http://sqlite.org/c3ref/bind_parameter_name.html // Calls http://sqlite.org/c3ref/bind_parameter_name.html
func (s *Stmt) BindParameterName(i int) string { func (s *Stmt) BindParameterName(i int) (string, error) {
return C.GoString(C.sqlite3_bind_parameter_name(s.stmt, C.int(i))) name := C.sqlite3_bind_parameter_name(s.stmt, C.int(i))
if name == nil {
return "", errors.New(fmt.Sprintf("invalid parameter index: %d", i))
}
return C.GoString(name), nil
}
func (s *Stmt) NamedBind(args ...interface{}) error {
err := s.Reset() // TODO sqlite3_clear_bindings?
if err != nil {
return err
}
if len(args)%2 != 0 {
return errors.New("Expected an even number of arguments")
}
for i := 0; i < len(args); i += 2 {
name, ok := args[i].(string)
if !ok {
return errors.New("non-string param name")
}
index, err := s.BindParameterIndex(name) // How to look up only once for one statement ?
if err != nil {
return err
}
err = s.BindByIndex(index, args[i+1])
if err != nil {
return err
}
}
return nil
} }
// Calls sqlite3_bind_parameter_count and sqlite3_bind_(blob|double|int|int64|null|text) depending on args type. // Calls sqlite3_bind_parameter_count and sqlite3_bind_(blob|double|int|int64|null|text) depending on args type.
...@@ -459,42 +502,51 @@ func (s *Stmt) Bind(args ...interface{}) error { ...@@ -459,42 +502,51 @@ func (s *Stmt) Bind(args ...interface{}) error {
} }
for i, v := range args { for i, v := range args {
err = s.BindByIndex(i+1, v)
if err != nil {
return err
}
}
return nil
}
// The leftmost SQL parameter has an index of 1.
func (s *Stmt) BindByIndex(index int, value interface{}) error {
i := C.int(index)
var rv C.int var rv C.int
index := C.int(i + 1) switch value := value.(type) {
switch v := v.(type) {
case nil: case nil:
rv = C.sqlite3_bind_null(s.stmt, index) rv = C.sqlite3_bind_null(s.stmt, i)
case string: case string:
cstr := C.CString(v) cstr := C.CString(value)
rv = C.my_bind_text(s.stmt, index, cstr, C.int(len(v))) rv = C.my_bind_text(s.stmt, i, cstr, C.int(len(value)))
C.free(unsafe.Pointer(cstr)) C.free(unsafe.Pointer(cstr))
case int: case int:
rv = C.sqlite3_bind_int(s.stmt, index, C.int(v)) rv = C.sqlite3_bind_int(s.stmt, i, C.int(value))
case int64: case int64:
rv = C.sqlite3_bind_int64(s.stmt, index, C.sqlite3_int64(v)) rv = C.sqlite3_bind_int64(s.stmt, i, C.sqlite3_int64(value))
case byte: case byte:
rv = C.sqlite3_bind_int(s.stmt, index, C.int(v)) rv = C.sqlite3_bind_int(s.stmt, i, C.int(value))
case bool: case bool:
rv = C.sqlite3_bind_int(s.stmt, index, btocint(v)) rv = C.sqlite3_bind_int(s.stmt, i, btocint(value))
case float32: case float32:
rv = C.sqlite3_bind_double(s.stmt, index, C.double(v)) rv = C.sqlite3_bind_double(s.stmt, i, C.double(value))
case float64: case float64:
rv = C.sqlite3_bind_double(s.stmt, index, C.double(v)) rv = C.sqlite3_bind_double(s.stmt, i, C.double(value))
case []byte: case []byte:
var p *byte var p *byte
if len(v) > 0 { if len(value) > 0 {
p = &v[0] p = &value[0]
} }
rv = C.my_bind_blob(s.stmt, index, unsafe.Pointer(p), C.int(len(v))) rv = C.my_bind_blob(s.stmt, i, unsafe.Pointer(p), C.int(len(value)))
case ZeroBlobLength: case ZeroBlobLength:
rv = C.sqlite3_bind_zeroblob(s.stmt, index, C.int(v)) rv = C.sqlite3_bind_zeroblob(s.stmt, i, C.int(value))
default: default:
return errors.New("unsupported type in Bind: " + reflect.TypeOf(v).String()) return errors.New("unsupported type in Bind: " + reflect.TypeOf(value).String())
} }
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return s.c.error(rv) return s.c.error(rv)
} }
}
return nil return nil
} }
...@@ -621,14 +673,14 @@ func (s *Stmt) NamedScan(args ...interface{}) error { ...@@ -621,14 +673,14 @@ func (s *Stmt) NamedScan(args ...interface{}) error {
for i := 0; i < len(args); i += 2 { for i := 0; i < len(args); i += 2 {
name, ok := args[i].(string) name, ok := args[i].(string)
if !ok { if !ok {
return errors.New("non-string field name field") return errors.New("non-string field name")
} }
index, err := s.ColumnIndex(name) // How to look up only once for one statement ? index, err := s.ColumnIndex(name) // How to look up only once for one statement ?
if err != nil { if err != nil {
return err return err
} }
ptr := args[i+1] ptr := args[i+1]
_, err = s.ScanColumn(index, ptr) _, err = s.ScanByIndex(index, ptr)
if err != nil { if err != nil {
return err return err
} }
...@@ -658,7 +710,7 @@ func (s *Stmt) Scan(args ...interface{}) error { ...@@ -658,7 +710,7 @@ func (s *Stmt) Scan(args ...interface{}) error {
} }
for i, v := range args { for i, v := range args {
_, err := s.ScanColumn(i, v) _, err := s.ScanByIndex(i, v)
if err != nil { if err != nil {
return err return err
} }
...@@ -689,23 +741,31 @@ func (s *Stmt) ColumnIndex(name string) (int, error) { ...@@ -689,23 +741,31 @@ func (s *Stmt) ColumnIndex(name string) (int, error) {
return 0, errors.New("invalid column name: " + name) return 0, errors.New("invalid column name: " + name)
} }
// Set nullable to false to skip NULL type test.
// Returns true when column is null and Stmt.CheckNull is activated. // Returns true when column is null and Stmt.CheckNull is activated.
// Calls sqlite3_column_count, sqlite3_column_name and sqlite3_column_(blob|double|int|int64|text) depending on arg type. // Calls sqlite3_column_count, sqlite3_column_name and sqlite3_column_(blob|double|int|int64|text) depending on arg type.
// http://sqlite.org/c3ref/column_blob.html // http://sqlite.org/c3ref/column_blob.html
func (s *Stmt) NamedScanColumn(name string, value interface{}) (bool, error) { func (s *Stmt) ScanByName(name string, value interface{}) (bool, error) {
index, err := s.ColumnIndex(name) index, err := s.ColumnIndex(name)
if err != nil { if err != nil {
return false, err return false, err
} }
return s.ScanColumn(index, value) return s.ScanByIndex(index, value)
} }
// The leftmost column/index is number 0. // The leftmost column/index is number 0.
//
// The value must be of one of the following types:
// *string
// *int, *int64, *byte,
// *bool
// *float64
// *[]byte
// *interface{}
//
// Returns true when column is null and Stmt.CheckNull is activated. // Returns true when column is null and Stmt.CheckNull is activated.
// Calls sqlite3_column_(blob|double|int|int64|text) depending on arg type. // Calls sqlite3_column_(blob|double|int|int64|text) depending on arg type.
// http://sqlite.org/c3ref/column_blob.html // http://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ScanColumn(index int, value interface{}) (bool, error) { func (s *Stmt) ScanByIndex(index int, value interface{}) (bool, error) {
var isNull bool var isNull bool
var err error var err error
switch value := value.(type) { switch value := value.(type) {
...@@ -734,6 +794,14 @@ func (s *Stmt) ScanColumn(index int, value interface{}) (bool, error) { ...@@ -734,6 +794,14 @@ func (s *Stmt) ScanColumn(index int, value interface{}) (bool, error) {
} }
// The leftmost column/index is number 0. // The leftmost column/index is number 0.
//
// The returned value will be of one of the following types:
// nil
// string
// int64
// float64
// []byte
//
// Calls sqlite3_column_(blob|double|int|int64|text) depending on columns type. // Calls sqlite3_column_(blob|double|int|int64|text) depending on columns type.
// http://sqlite.org/c3ref/column_blob.html // http://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ScanValue(index int) (value interface{}) { func (s *Stmt) ScanValue(index int) (value interface{}) {
......
...@@ -155,13 +155,13 @@ func TestInsertWithStatement(t *testing.T) { ...@@ -155,13 +155,13 @@ func TestInsertWithStatement(t *testing.T) {
if paramCount != 3 { if paramCount != 3 {
t.Errorf("bind parameter count error: %d <> 3", paramCount) t.Errorf("bind parameter count error: %d <> 3", paramCount)
} }
firstParamName := s.BindParameterName(1) firstParamName, berr := s.BindParameterName(1)
if firstParamName != ":f" { if firstParamName != ":f" {
t.Errorf("bind parameter name error: %s <> ':f'", firstParamName) t.Errorf("bind parameter name error: %s <> ':f' (%s)", firstParamName, berr)
} }
lastParamIndex := s.BindParameterIndex(":s") lastParamIndex, berr := s.BindParameterIndex(":s")
if lastParamIndex != 3 { if lastParamIndex != 3 {
t.Errorf("bind parameter name error: %d <> 3", lastParamIndex) t.Errorf("bind parameter name error: %d <> 3 (%s)", lastParamIndex, berr)
} }
db.Begin() db.Begin()
...@@ -315,19 +315,19 @@ func TestScanColumn(t *testing.T) { ...@@ -315,19 +315,19 @@ func TestScanColumn(t *testing.T) {
t.Fatal("no result") t.Fatal("no result")
} }
var i1, i2, i3 int var i1, i2, i3 int
null := Must(s.ScanColumn(0, &i1 /*, true*/ )) null := Must(s.ScanByIndex(0, &i1 /*, true*/ ))
if null { if null {
t.Errorf("Expected not null value") t.Errorf("Expected not null value")
} else if i1 != 1 { } else if i1 != 1 {
t.Errorf("Expected 1 <> %d\n", i1) t.Errorf("Expected 1 <> %d\n", i1)
} }
null = Must(s.ScanColumn(1, &i2 /*, true*/ )) null = Must(s.ScanByIndex(1, &i2 /*, true*/ ))
if !null { if !null {
t.Errorf("Expected null value") t.Errorf("Expected null value")
} else if i2 != 0 { } else if i2 != 0 {
t.Errorf("Expected 0 <> %d\n", i2) t.Errorf("Expected 0 <> %d\n", i2)
} }
null = Must(s.ScanColumn(2, &i3 /*, true*/ )) null = Must(s.ScanByIndex(2, &i3 /*, true*/ ))
if null { if null {
t.Errorf("Expected not null value") t.Errorf("Expected not null value")
} else if i3 != 0 { } else if i3 != 0 {
...@@ -348,19 +348,19 @@ func TestNamedScanColumn(t *testing.T) { ...@@ -348,19 +348,19 @@ func TestNamedScanColumn(t *testing.T) {
t.Fatal("no result") t.Fatal("no result")
} }
var i1, i2, i3 int var i1, i2, i3 int
null := Must(s.NamedScanColumn("i1", &i1 /*, true*/ )) null := Must(s.ScanByName("i1", &i1 /*, true*/ ))
if null { if null {
t.Errorf("Expected not null value") t.Errorf("Expected not null value")
} else if i1 != 1 { } else if i1 != 1 {
t.Errorf("Expected 1 <> %d\n", i1) t.Errorf("Expected 1 <> %d\n", i1)
} }
null = Must(s.NamedScanColumn("i2", &i2 /*, true*/ )) null = Must(s.ScanByName("i2", &i2 /*, true*/ ))
if !null { if !null {
t.Errorf("Expected null value") t.Errorf("Expected null value")
} else if i2 != 0 { } else if i2 != 0 {
t.Errorf("Expected 0 <> %d\n", i2) t.Errorf("Expected 0 <> %d\n", i2)
} }
null = Must(s.NamedScanColumn("i3", &i3 /*, true*/ )) null = Must(s.ScanByName("i3", &i3 /*, true*/ ))
if null { if null {
t.Errorf("Expected not null value") t.Errorf("Expected not null value")
} else if i3 != 0 { } else if i3 != 0 {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment