Commit 9c0804e7 authored by gwenn's avatar gwenn

Improve User function support.

parent 2469a2f1
......@@ -71,7 +71,8 @@ type Context struct {
}
type ScalarContext struct {
Context
ad map[int]interface{} // Function Auxiliary Data
ad map[int]interface{} // Function Auxiliary Data
udf *sqliteFunction
}
type AggregateContext struct {
Context
......@@ -308,25 +309,22 @@ type FinalFunction func(ctx *AggregateContext)
type DestroyFunctionData func(pApp interface{})
type sqliteFunction struct {
scalar ScalarFunction
step StepFunction
final FinalFunction
d DestroyFunctionData
pApp interface{}
contexts map[*C.sqlite3_context]*AggregateContext
scalar ScalarFunction
step StepFunction
final FinalFunction
d DestroyFunctionData
pApp interface{}
scalarCtxs map[*ScalarContext]bool
aggrCtxs map[*AggregateContext]bool
}
// To prevent Context from being gced
// TODO Retry to put this in the sqliteFunction
var contexts map[*C.sqlite3_context]*ScalarContext = make(map[*C.sqlite3_context]*ScalarContext)
//export goXAuxDataDestroy
func goXAuxDataDestroy(ad unsafe.Pointer) {
c := (*ScalarContext)(ad)
if c != nil {
delete(contexts, c.sc)
delete(c.udf.scalarCtxs, c)
}
// fmt.Printf("Contexts: %v\n", contexts)
// fmt.Printf("Contexts: %v\n", c.udf.scalarCtxs)
}
//export goXFunc
......@@ -337,9 +335,10 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
if c == nil {
c = new(ScalarContext)
c.sc = (*C.sqlite3_context)(scp)
c.udf = udf
C.goSqlite3SetAuxdata(c.sc, 0, unsafe.Pointer(c))
// To make sure it is not cged
contexts[c.sc] = c
udf.scalarCtxs[c] = true
}
c.argv = (**C.sqlite3_value)(argv)
udf.scalar(c, argc)
......@@ -359,7 +358,7 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
c.sc = (*C.sqlite3_context)(scp)
*(*unsafe.Pointer)(cp) = unsafe.Pointer(c)
// To make sure it is not cged
udf.contexts[c.sc] = c
udf.aggrCtxs[c] = true
} else {
c = (*AggregateContext)(p)
}
......@@ -378,12 +377,12 @@ func goXFinal(scp, udfp unsafe.Pointer) {
p := *(*unsafe.Pointer)(cp)
if p != nil {
c := (*AggregateContext)(p)
delete(udf.contexts, c.sc)
delete(udf.aggrCtxs, c)
c.sc = (*C.sqlite3_context)(scp)
udf.final(c)
}
}
// fmt.Printf("Contexts: %v\n", udf.contexts)
// fmt.Printf("Contexts: %v\n", udf.aggrCtxts)
}
//export goXDestroy
......@@ -407,7 +406,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
}
// To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, nil, nil, d, pApp, nil}
udf := &sqliteFunction{f, nil, nil, d, pApp, make(map[*ScalarContext]bool), nil}
if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction)
}
......@@ -429,7 +428,7 @@ func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp inter
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
}
// To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{nil, step, final, d, pApp, make(map[*C.sqlite3_context]*AggregateContext)}
udf := &sqliteFunction{nil, step, final, d, pApp, nil, make(map[*AggregateContext]bool)}
if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction)
}
......
......@@ -32,10 +32,13 @@ func TestScalarFunction(t *testing.T) {
checkNoError(t, err, "couldn't destroy function: %s")
}
var reused bool
func re(ctx *ScalarContext, nArg int) {
ad := ctx.GetAuxData(0)
var re *regexp.Regexp
if ad == nil {
reused = false
//println("Compile")
var err error
re, err = regexp.Compile(ctx.Text(0))
......@@ -45,6 +48,7 @@ func re(ctx *ScalarContext, nArg int) {
}
ctx.SetAuxData(0, re)
} else {
reused = true
//println("Reuse")
var ok bool
if re, ok = ad.(*regexp.Regexp); !ok {
......@@ -71,6 +75,7 @@ func TestRegexpFunction(t *testing.T) {
s, err := db.Prepare("select regexp('l.s[aeiouy]', name) from (select 'lisa' as name union all select 'bart')")
checkNoError(t, err, "couldn't prepare statement: %s")
defer s.Finalize()
if b := Must(s.Next()); !b {
t.Fatalf("No result")
}
......@@ -79,6 +84,10 @@ func TestRegexpFunction(t *testing.T) {
if i != 1 {
t.Errorf("Expected %d but got %d", 1, i)
}
if reused {
t.Errorf("unexpected reused state")
}
if b := Must(s.Next()); !b {
t.Fatalf("No result")
}
......@@ -87,6 +96,9 @@ func TestRegexpFunction(t *testing.T) {
if i != 0 {
t.Errorf("Expected %d but got %d", 0, i)
}
if !reused {
t.Errorf("unexpected reused state")
}
}
func sumStep(ctx *AggregateContext, nArg int) {
......@@ -158,8 +170,8 @@ func BenchmarkHalf(b *testing.B) {
db, _ := Open("")
defer db.Close()
randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where half(rank) > 20")
db.CreateScalarFunction("half", 1, nil, half, nil)
cs, _ := db.Prepare("SELECT count(1) FROM test where half(rank) > 20")
b.StartTimer()
for i := 0; i < b.N; i++ {
......@@ -173,8 +185,8 @@ func BenchmarkRegexp(b *testing.B) {
db, _ := Open("")
defer db.Close()
randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where name regexp '(?i)\\blisa\\b'")
db.CreateScalarFunction("regexp", 2, nil, re, reDestroy)
cs, _ := db.Prepare("SELECT count(1) FROM test where name regexp '(?i)\\blisa\\b'")
b.StartTimer()
for i := 0; i < b.N; i++ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment