Commit 2469a2f1 authored by gwenn's avatar gwenn

Make Scalar context distinct from Aggregate context.

parent 36bc9bfe
......@@ -11,7 +11,7 @@ import (
func TestInterrupt(t *testing.T) {
db := open(t)
defer db.Close()
db.CreateScalarFunction("interrupt", 0, nil, func(ctx *Context, nArg int) {
db.CreateScalarFunction("interrupt", 0, nil, func(ctx *ScalarContext, nArg int) {
db.Interrupt()
ctx.ResultText("ok")
}, nil)
......
......@@ -66,10 +66,16 @@ sqlite3 *sqlite3_context_db_handle(sqlite3_context*);
*/
type Context struct {
sc *C.sqlite3_context
argv **C.sqlite3_value
ad map[int]interface{} // Function Auxiliary Data
AggregateContext interface{} // Aggregate Function Context
sc *C.sqlite3_context
argv **C.sqlite3_value
}
type ScalarContext struct {
Context
ad map[int]interface{} // Function Auxiliary Data
}
type AggregateContext struct {
Context
Aggregate interface{}
}
func (c *Context) Result(r interface{}) {
......@@ -202,7 +208,7 @@ func (c *Context) UserData() interface{} {
// Function auxiliary data
// (See sqlite3_get_auxdata, http://sqlite.org/c3ref/get_auxdata.html)
func (c *Context) GetAuxData(n int) interface{} {
func (c *ScalarContext) GetAuxData(n int) interface{} {
if len(c.ad) == 0 {
return nil
}
......@@ -212,7 +218,7 @@ func (c *Context) GetAuxData(n int) interface{} {
// Function auxiliary data
// No destructor is needed a priori
// (See sqlite3_set_auxdata, http://sqlite.org/c3ref/get_auxdata.html)
func (c *Context) SetAuxData(n int, ad interface{}) {
func (c *ScalarContext) SetAuxData(n int, ad interface{}) {
if len(c.ad) == 0 {
c.ad = make(map[int]interface{})
}
......@@ -296,25 +302,27 @@ func (c *Context) Value(i int) (value interface{}) {
return
}
type FuncOrStep func(ctx *Context, nArg int)
type FinalFunction func(ctx *Context)
type ScalarFunction func(ctx *ScalarContext, nArg int)
type StepFunction func(ctx *AggregateContext, nArg int)
type FinalFunction func(ctx *AggregateContext)
type DestroyFunctionData func(pApp interface{})
type sqliteFunction struct {
funcOrStep FuncOrStep
final FinalFunction
d DestroyFunctionData
pApp interface{}
contexts map[*C.sqlite3_context]*Context
scalar ScalarFunction
step StepFunction
final FinalFunction
d DestroyFunctionData
pApp interface{}
contexts map[*C.sqlite3_context]*AggregateContext
}
// To prevent Context from being gced
// TODO Retry to put this in the sqliteFunction
var contexts map[*C.sqlite3_context]*Context = make(map[*C.sqlite3_context]*Context)
var contexts map[*C.sqlite3_context]*ScalarContext = make(map[*C.sqlite3_context]*ScalarContext)
//export goXAuxDataDestroy
func goXAuxDataDestroy(ad unsafe.Pointer) {
c := (*Context)(ad)
c := (*ScalarContext)(ad)
if c != nil {
delete(contexts, c.sc)
}
......@@ -325,16 +333,16 @@ func goXAuxDataDestroy(ad unsafe.Pointer) {
func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
udf := (*sqliteFunction)(udfp)
// To avoid the creation of a Context at each call, just put it in auxdata
c := (*Context)(ctxp)
c := (*ScalarContext)(ctxp)
if c == nil {
c = new(Context)
c = new(ScalarContext)
c.sc = (*C.sqlite3_context)(scp)
C.goSqlite3SetAuxdata(c.sc, 0, unsafe.Pointer(c))
// To make sure it is not cged
contexts[c.sc] = c
}
c.argv = (**C.sqlite3_value)(argv)
udf.funcOrStep(c, argc)
udf.scalar(c, argc)
c.argv = nil
}
......@@ -344,20 +352,20 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
var cp unsafe.Pointer
cp = C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), C.int(unsafe.Sizeof(cp)))
if cp != nil {
var c *Context
var c *AggregateContext
p := *(*unsafe.Pointer)(cp)
if p == nil {
c = new(Context)
c = new(AggregateContext)
c.sc = (*C.sqlite3_context)(scp)
*(*unsafe.Pointer)(cp) = unsafe.Pointer(c)
// To make sure it is not cged
udf.contexts[c.sc] = c
} else {
c = (*Context)(p)
c = (*AggregateContext)(p)
}
c.argv = (**C.sqlite3_value)(argv)
udf.funcOrStep(c, argc)
udf.step(c, argc)
c.argv = nil
}
}
......@@ -369,13 +377,13 @@ func goXFinal(scp, udfp unsafe.Pointer) {
if cp != nil {
p := *(*unsafe.Pointer)(cp)
if p != nil {
c := (*Context)(p)
c := (*AggregateContext)(p)
delete(udf.contexts, c.sc)
c.sc = (*C.sqlite3_context)(scp)
udf.final(c)
}
}
// fmt.Printf("Contexts: %v\n", contexts)
// fmt.Printf("Contexts: %v\n", udf.contexts)
}
//export goXDestroy
......@@ -389,7 +397,7 @@ func goXDestroy(pApp unsafe.Pointer) {
// Create or redefine SQL functions
// TODO Make possible to specify the preferred encoding
// (See http://sqlite.org/c3ref/create_function.html)
func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interface{}, f FuncOrStep, d DestroyFunctionData) error {
func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interface{}, f ScalarFunction, d DestroyFunctionData) error {
fname := C.CString(functionName)
defer C.free(unsafe.Pointer(fname))
if f == nil {
......@@ -399,7 +407,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
}
// To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, nil, d, pApp, nil}
udf := &sqliteFunction{f, nil, nil, d, pApp, nil}
if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction)
}
......@@ -411,7 +419,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
// TODO Make possible to specify the preferred encoding
// (See http://sqlite.org/c3ref/create_function.html)
func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{},
step FuncOrStep, final FinalFunction, d DestroyFunctionData) error {
step StepFunction, final FinalFunction, d DestroyFunctionData) error {
fname := C.CString(functionName)
defer C.free(unsafe.Pointer(fname))
if step == nil {
......@@ -421,7 +429,7 @@ func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp inter
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
}
// To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{step, final, d, pApp, make(map[*C.sqlite3_context]*Context)}
udf := &sqliteFunction{nil, step, final, d, pApp, make(map[*C.sqlite3_context]*AggregateContext)}
if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction)
}
......
......@@ -7,7 +7,7 @@ import (
"testing"
)
func half(ctx *Context, nArg int) {
func half(ctx *ScalarContext, nArg int) {
nt := ctx.NumericType(0)
if nt == Integer || nt == Float {
ctx.ResultDouble(ctx.Double(0) / 2)
......@@ -32,7 +32,7 @@ func TestScalarFunction(t *testing.T) {
checkNoError(t, err, "couldn't destroy function: %s")
}
func re(ctx *Context, nArg int) {
func re(ctx *ScalarContext, nArg int) {
ad := ctx.GetAuxData(0)
var re *regexp.Regexp
if ad == nil {
......@@ -89,21 +89,21 @@ func TestRegexpFunction(t *testing.T) {
}
}
func sumStep(ctx *Context, nArg int) {
func sumStep(ctx *AggregateContext, nArg int) {
nt := ctx.NumericType(0)
if nt == Integer || nt == Float {
var sum int64
var ok bool
if sum, ok = (ctx.AggregateContext).(int64); !ok {
if sum, ok = (ctx.Aggregate).(int64); !ok {
sum = 0
}
sum += ctx.Int64(0)
ctx.AggregateContext = sum
ctx.Aggregate = sum
}
}
func sumFinal(ctx *Context) {
if sum, ok := (ctx.AggregateContext).(int64); ok {
func sumFinal(ctx *AggregateContext) {
if sum, ok := (ctx.Aggregate).(int64); ok {
ctx.ResultInt64(sum)
} else {
ctx.ResultNull()
......@@ -143,13 +143,13 @@ func BenchmarkLike(b *testing.B) {
b.StopTimer()
db, _ := Open("")
defer db.Close()
randomFill(db, 1000)
randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where name like 'lisa'")
b.StartTimer()
for i := 0; i < b.N; i++ {
cs, _ := db.Prepare("SELECT count(1) FROM test where name like 'lisa'")
Must(cs.Next())
cs.Finalize()
cs.Reset()
}
}
......@@ -157,14 +157,14 @@ func BenchmarkHalf(b *testing.B) {
b.StopTimer()
db, _ := Open("")
defer db.Close()
randomFill(db, 1000)
randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where half(rank) > 20")
db.CreateScalarFunction("half", 1, nil, half, nil)
b.StartTimer()
for i := 0; i < b.N; i++ {
cs, _ := db.Prepare("SELECT count(1) FROM test where half(rank) > 20")
Must(cs.Next())
cs.Finalize()
cs.Reset()
}
}
......@@ -172,13 +172,13 @@ func BenchmarkRegexp(b *testing.B) {
b.StopTimer()
db, _ := Open("")
defer db.Close()
randomFill(db, 1000)
randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where name regexp '(?i)\\blisa\\b'")
db.CreateScalarFunction("regexp", 2, nil, re, reDestroy)
b.StartTimer()
for i := 0; i < b.N; i++ {
cs, _ := db.Prepare("SELECT count(1) FROM test where name regexp '(?i)\\blisa\\b'")
Must(cs.Next())
cs.Finalize()
cs.Reset()
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment