Commit 49b47c41 authored by gwenn's avatar gwenn

Bind empty string and zero time to null by default.

parent 60079c93
...@@ -312,6 +312,12 @@ func (s *Stmt) Bind(args ...interface{}) error { ...@@ -312,6 +312,12 @@ func (s *Stmt) Bind(args ...interface{}) error {
return nil return nil
} }
// NullIfEmpty transforms empty string to null when true (true by default)
var NullIfEmptyString bool = true
// NullIfZeroTime transforms zero time (time.Time.IsZero) to null when true (true by default)
var NullIfZeroTime bool = true
// BindByIndex binds value to the specified host parameter of the prepared statement. // BindByIndex binds value to the specified host parameter of the prepared statement.
// The leftmost SQL parameter has an index of 1. // The leftmost SQL parameter has an index of 1.
func (s *Stmt) BindByIndex(index int, value interface{}) error { func (s *Stmt) BindByIndex(index int, value interface{}) error {
...@@ -321,8 +327,12 @@ func (s *Stmt) BindByIndex(index int, value interface{}) error { ...@@ -321,8 +327,12 @@ func (s *Stmt) BindByIndex(index int, value interface{}) error {
case nil: case nil:
rv = C.sqlite3_bind_null(s.stmt, i) rv = C.sqlite3_bind_null(s.stmt, i)
case string: case string:
cs, l := cstring(value) if NullIfEmptyString && len(value) == 0 {
rv = C.my_bind_text(s.stmt, i, cs, l) rv = C.sqlite3_bind_null(s.stmt, i)
} else {
cs, l := cstring(value)
rv = C.my_bind_text(s.stmt, i, cs, l)
}
case int: case int:
rv = C.sqlite3_bind_int(s.stmt, i, C.int(value)) rv = C.sqlite3_bind_int(s.stmt, i, C.int(value))
case int64: case int64:
...@@ -342,7 +352,11 @@ func (s *Stmt) BindByIndex(index int, value interface{}) error { ...@@ -342,7 +352,11 @@ func (s *Stmt) BindByIndex(index int, value interface{}) error {
} }
rv = C.my_bind_blob(s.stmt, i, unsafe.Pointer(p), C.int(len(value))) rv = C.my_bind_blob(s.stmt, i, unsafe.Pointer(p), C.int(len(value)))
case time.Time: case time.Time:
rv = C.sqlite3_bind_int64(s.stmt, i, C.sqlite3_int64(value.Unix())) if NullIfZeroTime && value.IsZero() {
rv = C.sqlite3_bind_null(s.stmt, i)
} else {
rv = C.sqlite3_bind_int64(s.stmt, i, C.sqlite3_int64(value.Unix()))
}
case ZeroBlobLength: case ZeroBlobLength:
rv = C.sqlite3_bind_zeroblob(s.stmt, i, C.int(value)) rv = C.sqlite3_bind_zeroblob(s.stmt, i, C.int(value))
case driver.Valuer: case driver.Valuer:
...@@ -568,8 +582,9 @@ func (s *Stmt) ScanByName(name string, value interface{}) (bool, error) { ...@@ -568,8 +582,9 @@ func (s *Stmt) ScanByName(name string, value interface{}) (bool, error) {
// (*)*bool // (*)*bool
// (*)*float64 // (*)*float64
// (*)*[]byte // (*)*[]byte
// *time.Time
// sql.Scanner
// *interface{} // *interface{}
// func(interface{}) (bool, error)
// //
// Returns true when column is null. // Returns true when column is null.
// Calls sqlite3_column_(blob|double|int|int64|text) depending on arg type. // Calls sqlite3_column_(blob|double|int|int64|text) depending on arg type.
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
. "github.com/gwenn/gosqlite" . "github.com/gwenn/gosqlite"
"reflect" "reflect"
"testing" "testing"
"time"
) )
func checkFinalize(s *Stmt, t *testing.T) { func checkFinalize(s *Stmt, t *testing.T) {
...@@ -399,3 +400,56 @@ func TestScanBytes(t *testing.T) { ...@@ -399,3 +400,56 @@ func TestScanBytes(t *testing.T) {
blob, _ := s.ScanBlob(0) blob, _ := s.ScanBlob(0)
assertEquals(t, "expected %v but got %v", "test", string(blob)) assertEquals(t, "expected %v but got %v", "test", string(blob))
} }
func TestBindEmptyZero(t *testing.T) {
db := open(t)
defer checkClose(db, t)
var zero time.Time
s, err := db.Prepare("SELECT ?, ?", "", zero)
checkNoError(t, err, "prepare error: %s")
defer checkFinalize(s, t)
if !Must(s.Next()) {
t.Fatal("no result")
}
var ps *string
var zt time.Time
err = s.Scan(&ps, &zt)
checkNoError(t, err, "scan error: %s")
assert(t, "Null pointers expected", ps == nil && zt.IsZero())
_, null := s.ScanValue(0, false)
assert(t, "Null string expected", null)
_, null = s.ScanValue(1, false)
assert(t, "Null time expected", null)
}
func TestBindEmptyZeroNotTransformedToNull(t *testing.T) {
db := open(t)
defer checkClose(db, t)
NullIfEmptyString = false
NullIfZeroTime = false
defer func() {
NullIfEmptyString = true
NullIfZeroTime = true
}()
var zero time.Time
s, err := db.Prepare("SELECT ?, ?", "", zero)
checkNoError(t, err, "prepare error: %s")
defer checkFinalize(s, t)
if !Must(s.Next()) {
t.Fatal("no result")
}
var st string
var zt time.Time
err = s.Scan(&st, &zt)
checkNoError(t, err, "scan error: %s")
assert(t, "Null pointers expected", len(st) == 0 && zt.IsZero())
_, null := s.ScanValue(0, false)
assert(t, "Empty string expected", !null)
_, null = s.ScanValue(1, false)
assert(t, "Zero time expected", !null)
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment