Commit d5bb2294 authored by gwenn's avatar gwenn

Upgrade CreateScalarFunction signature to declare deterministic functions (SQLite 3.8.3)

parent 7dba26ee
...@@ -5,18 +5,19 @@ ...@@ -5,18 +5,19 @@
package sqlite_test package sqlite_test
import ( import (
"github.com/bmizerany/assert"
. "github.com/gwenn/gosqlite"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/bmizerany/assert"
. "github.com/gwenn/gosqlite"
) )
func TestInterrupt(t *testing.T) { func TestInterrupt(t *testing.T) {
db := open(t) db := open(t)
defer checkClose(db, t) defer checkClose(db, t)
db.CreateScalarFunction("interrupt", 0, nil, func(ctx *ScalarContext, nArg int) { db.CreateScalarFunction("interrupt", 0, false, nil, func(ctx *ScalarContext, nArg int) {
db.Interrupt() db.Interrupt()
ctx.ResultText("ok") ctx.ResultText("ok")
}, nil) }, nil)
......
...@@ -329,14 +329,14 @@ type StepFunction func(ctx *AggregateContext, nArg int) ...@@ -329,14 +329,14 @@ type StepFunction func(ctx *AggregateContext, nArg int)
// FinalFunction is the expected signature of final function implemented in Go // FinalFunction is the expected signature of final function implemented in Go
type FinalFunction func(ctx *AggregateContext) type FinalFunction func(ctx *AggregateContext)
// DestroyFunctionData is the expected signature of function used to finalize user data. // DestroyDataFunction is the expected signature of function used to finalize user data.
type DestroyFunctionData func(pApp interface{}) type DestroyDataFunction func(pApp interface{})
type sqliteFunction struct { type sqliteFunction struct {
scalar ScalarFunction scalar ScalarFunction
step StepFunction step StepFunction
final FinalFunction final FinalFunction
d DestroyFunctionData d DestroyDataFunction
pApp interface{} pApp interface{}
scalarCtxs map[*ScalarContext]bool scalarCtxs map[*ScalarContext]bool
aggrCtxs map[*AggregateContext]bool aggrCtxs map[*AggregateContext]bool
...@@ -417,18 +417,24 @@ func goXDestroy(pApp unsafe.Pointer) { ...@@ -417,18 +417,24 @@ func goXDestroy(pApp unsafe.Pointer) {
} }
} }
const sqliteDeterministic = 0x800 // C.SQLITE_DETERMINISTIC
// CreateScalarFunction creates or redefines SQL scalar functions. // CreateScalarFunction creates or redefines SQL scalar functions.
// TODO Make possible to specify the preferred encoding // TODO Make possible to specify the preferred encoding
// (See http://sqlite.org/c3ref/create_function.html) // (See http://sqlite.org/c3ref/create_function.html)
func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interface{}, func (c *Conn) CreateScalarFunction(functionName string, nArg int, deterministic bool, pApp interface{},
f ScalarFunction, d DestroyFunctionData) error { f ScalarFunction, d DestroyDataFunction) error {
var eTextRep C.int = C.SQLITE_UTF8
if deterministic {
eTextRep = eTextRep | sqliteDeterministic
}
fname := C.CString(functionName) fname := C.CString(functionName)
defer C.free(unsafe.Pointer(fname)) defer C.free(unsafe.Pointer(fname))
if f == nil { if f == nil {
if len(c.udfs) > 0 { if len(c.udfs) > 0 {
delete(c.udfs, functionName) delete(c.udfs, functionName)
} }
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil), return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), eTextRep, nil, nil, nil, nil, nil),
fmt.Sprintf("<Conn.CreateScalarFunction(%q)", functionName)) fmt.Sprintf("<Conn.CreateScalarFunction(%q)", functionName))
} }
// To make sure it is not gced, keep a reference in the connection. // To make sure it is not gced, keep a reference in the connection.
...@@ -437,7 +443,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac ...@@ -437,7 +443,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
c.udfs = make(map[string]*sqliteFunction) c.udfs = make(map[string]*sqliteFunction)
} }
c.udfs[functionName] = udf // FIXME same function name with different args is not supported c.udfs[functionName] = udf // FIXME same function name with different args is not supported
return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf)), return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), eTextRep, unsafe.Pointer(udf)),
fmt.Sprintf("Conn.CreateScalarFunction(%q)", functionName)) fmt.Sprintf("Conn.CreateScalarFunction(%q)", functionName))
} }
...@@ -445,7 +451,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac ...@@ -445,7 +451,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
// TODO Make possible to specify the preferred encoding // TODO Make possible to specify the preferred encoding
// (See http://sqlite.org/c3ref/create_function.html) // (See http://sqlite.org/c3ref/create_function.html)
func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{}, func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{},
step StepFunction, final FinalFunction, d DestroyFunctionData) error { step StepFunction, final FinalFunction, d DestroyDataFunction) error {
fname := C.CString(functionName) fname := C.CString(functionName)
defer C.free(unsafe.Pointer(fname)) defer C.free(unsafe.Pointer(fname))
if step == nil { if step == nil {
......
...@@ -5,12 +5,13 @@ ...@@ -5,12 +5,13 @@
package sqlite_test package sqlite_test
import ( import (
"github.com/bmizerany/assert"
. "github.com/gwenn/gosqlite"
"math/rand" "math/rand"
"os" "os"
"regexp" "regexp"
"testing" "testing"
"github.com/bmizerany/assert"
. "github.com/gwenn/gosqlite"
) )
func half(ctx *ScalarContext, nArg int) { func half(ctx *ScalarContext, nArg int) {
...@@ -25,13 +26,13 @@ func half(ctx *ScalarContext, nArg int) { ...@@ -25,13 +26,13 @@ func half(ctx *ScalarContext, nArg int) {
func TestScalarFunction(t *testing.T) { func TestScalarFunction(t *testing.T) {
db := open(t) db := open(t)
defer checkClose(db, t) defer checkClose(db, t)
err := db.CreateScalarFunction("half", 1, nil, half, nil) err := db.CreateScalarFunction("half", 1, true, nil, half, nil)
checkNoError(t, err, "couldn't create function: %s") checkNoError(t, err, "couldn't create function: %s")
var d float64 var d float64
err = db.OneValue("SELECT half(6)", &d) err = db.OneValue("SELECT half(6)", &d)
checkNoError(t, err, "couldn't retrieve result: %s") checkNoError(t, err, "couldn't retrieve result: %s")
assert.Equal(t, 3.0, d) assert.Equal(t, 3.0, d)
err = db.CreateScalarFunction("half", 1, nil, nil, nil) err = db.CreateScalarFunction("half", 1, true, nil, nil, nil)
checkNoError(t, err, "couldn't destroy function: %s") checkNoError(t, err, "couldn't destroy function: %s")
} }
...@@ -72,7 +73,7 @@ func reDestroy(ad interface{}) { ...@@ -72,7 +73,7 @@ func reDestroy(ad interface{}) {
func TestRegexpFunction(t *testing.T) { func TestRegexpFunction(t *testing.T) {
db := open(t) db := open(t)
defer checkClose(db, t) defer checkClose(db, t)
err := db.CreateScalarFunction("regexp", 2, nil, re, reDestroy) err := db.CreateScalarFunction("regexp", 2, true, nil, re, reDestroy)
checkNoError(t, err, "couldn't create function: %s") checkNoError(t, err, "couldn't create function: %s")
s, err := db.Prepare("SELECT regexp('l.s[aeiouy]', name) from (SELECT 'lisa' AS name UNION ALL SELECT 'bart')") s, err := db.Prepare("SELECT regexp('l.s[aeiouy]', name) from (SELECT 'lisa' AS name UNION ALL SELECT 'bart')")
checkNoError(t, err, "couldn't prepare statement: %s") checkNoError(t, err, "couldn't prepare statement: %s")
...@@ -106,13 +107,13 @@ func user(ctx *ScalarContext, nArg int) { ...@@ -106,13 +107,13 @@ func user(ctx *ScalarContext, nArg int) {
func TestUserFunction(t *testing.T) { func TestUserFunction(t *testing.T) {
db := open(t) db := open(t)
defer checkClose(db, t) defer checkClose(db, t)
err := db.CreateScalarFunction("user", 0, nil, user, nil) err := db.CreateScalarFunction("user", 0, false, nil, user, nil)
checkNoError(t, err, "couldn't create function: %s") checkNoError(t, err, "couldn't create function: %s")
var name string var name string
err = db.OneValue("SELECT user()", &name) err = db.OneValue("SELECT user()", &name)
checkNoError(t, err, "couldn't retrieve result: %s") checkNoError(t, err, "couldn't retrieve result: %s")
assert.Tf(t, len(name) > 0, "unexpected user name: %q", name) assert.Tf(t, len(name) > 0, "unexpected user name: %q", name)
err = db.CreateScalarFunction("user", 1, nil, nil, nil) err = db.CreateScalarFunction("user", 1, false, nil, nil, nil)
checkNoError(t, err, "couldn't destroy function: %s") checkNoError(t, err, "couldn't destroy function: %s")
} }
...@@ -186,7 +187,7 @@ func BenchmarkHalf(b *testing.B) { ...@@ -186,7 +187,7 @@ func BenchmarkHalf(b *testing.B) {
db, _ := Open(":memory:") db, _ := Open(":memory:")
defer db.Close() defer db.Close()
randomFill(db, 1) randomFill(db, 1)
db.CreateScalarFunction("half", 1, nil, half, nil) db.CreateScalarFunction("half", 1, true, nil, half, nil)
cs, _ := db.Prepare("SELECT count(1) FROM test WHERE half(rank) > 20") cs, _ := db.Prepare("SELECT count(1) FROM test WHERE half(rank) > 20")
defer cs.Finalize() defer cs.Finalize()
...@@ -202,7 +203,7 @@ func BenchmarkRegexp(b *testing.B) { ...@@ -202,7 +203,7 @@ func BenchmarkRegexp(b *testing.B) {
db, _ := Open(":memory:") db, _ := Open(":memory:")
defer db.Close() defer db.Close()
randomFill(db, 1) randomFill(db, 1)
db.CreateScalarFunction("regexp", 2, nil, re, reDestroy) db.CreateScalarFunction("regexp", 2, true, nil, re, reDestroy)
cs, _ := db.Prepare("SELECT count(1) FROM test WHERE name regexp '(?i)\\blisa\\b'") cs, _ := db.Prepare("SELECT count(1) FROM test WHERE name regexp '(?i)\\blisa\\b'")
defer cs.Finalize() defer cs.Finalize()
......
...@@ -299,6 +299,8 @@ const ( ...@@ -299,6 +299,8 @@ const (
StmtStatusFullScanStep StmtStatus = C.SQLITE_STMTSTATUS_FULLSCAN_STEP StmtStatusFullScanStep StmtStatus = C.SQLITE_STMTSTATUS_FULLSCAN_STEP
StmtStatusSort StmtStatus = C.SQLITE_STMTSTATUS_SORT StmtStatusSort StmtStatus = C.SQLITE_STMTSTATUS_SORT
StmtStatusAutoIndex StmtStatus = C.SQLITE_STMTSTATUS_AUTOINDEX StmtStatusAutoIndex StmtStatus = C.SQLITE_STMTSTATUS_AUTOINDEX
// StmtStatusVmStep StmtStatus = C.SQLITE_STMTSTATUS_VM_STEP
) )
// Status returns the value of a status counter for a prepared statement. // Status returns the value of a status counter for a prepared statement.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment