Commit ee538ffb authored by gwenn's avatar gwenn

Fix custom aggregate function support.

parent f282917c
...@@ -347,7 +347,7 @@ func goXAuxDataDestroy(ad unsafe.Pointer) { ...@@ -347,7 +347,7 @@ func goXAuxDataDestroy(ad unsafe.Pointer) {
if c != nil { if c != nil {
delete(contexts, c.sc) delete(contexts, c.sc)
} }
//fmt.Printf("%v\n", contexts) // fmt.Printf("Contexts: %v\n", contexts)
} }
//export goXFunc //export goXFunc
...@@ -370,10 +370,20 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) { ...@@ -370,10 +370,20 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
//export goXStep //export goXStep
func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) { func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
udf := (*sqliteFunction)(udfp) udf := (*sqliteFunction)(udfp)
var c *Context var cp unsafe.Pointer
c = (*Context)(C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), C.int(unsafe.Sizeof(c)))) cp = C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), C.int(unsafe.Sizeof(cp)))
if c != nil { if cp != nil {
c.sc = (*C.sqlite3_context)(scp) var c *Context
p := *(*unsafe.Pointer)(cp)
if p == nil {
c = new(Context)
c.sc = (*C.sqlite3_context)(scp)
*(*unsafe.Pointer)(cp) = unsafe.Pointer(c)
// To make sure it is not cged
contexts[c.sc] = c
} else {
c = (*Context)(p)
}
c.argv = (**C.sqlite3_value)(argv) c.argv = (**C.sqlite3_value)(argv)
udf.funcOrStep(c, argc) udf.funcOrStep(c, argc)
...@@ -384,11 +394,17 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) { ...@@ -384,11 +394,17 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
//export goXFinal //export goXFinal
func goXFinal(scp, udfp unsafe.Pointer) { func goXFinal(scp, udfp unsafe.Pointer) {
udf := (*sqliteFunction)(udfp) udf := (*sqliteFunction)(udfp)
c := (*Context)(C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), 0)) cp := C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), 0)
if c != nil { if cp != nil {
//c.sc = (*C.sqlite3_context)(scp) p := *(*unsafe.Pointer)(cp)
udf.final(c) if p != nil {
c := (*Context)(p)
delete(contexts, c.sc)
c.sc = (*C.sqlite3_context)(scp)
udf.final(c)
}
} }
// fmt.Printf("Contexts: %v\n", contexts)
} }
//export goXDestroy //export goXDestroy
......
...@@ -88,23 +88,22 @@ func TestRegexpFunction(t *testing.T) { ...@@ -88,23 +88,22 @@ func TestRegexpFunction(t *testing.T) {
} }
} }
func sumStep(ctx *Context, nArg int) { func sumStep(ctx *Context, nArg int) {
nt := ctx.NumericType(0) nt := ctx.NumericType(0)
if nt == Integer || nt == Float { if nt == Integer || nt == Float {
var sum float64 var sum int64
var ok bool var ok bool
if sum, ok = (ctx.AggregateContext).(float64); !ok { if sum, ok = (ctx.AggregateContext).(int64); !ok {
sum = 0 sum = 0
} }
sum += ctx.Double(0) sum += ctx.Int64(0)
ctx.AggregateContext = sum ctx.AggregateContext = sum
} }
} }
func sumFinal(ctx *Context) { func sumFinal(ctx *Context) {
if sum, ok := (ctx.AggregateContext).(float64); ok { if sum, ok := (ctx.AggregateContext).(int64); ok {
ctx.ResultDouble(sum) ctx.ResultInt64(sum)
} else { } else {
ctx.ResultNull() ctx.ResultNull()
} }
...@@ -116,10 +115,10 @@ func TestSumFunction(t *testing.T) { ...@@ -116,10 +115,10 @@ func TestSumFunction(t *testing.T) {
defer db.Close() defer db.Close()
err = db.CreateAggregateFunction("mysum", 1, nil, sumStep, sumFinal, nil) err = db.CreateAggregateFunction("mysum", 1, nil, sumStep, sumFinal, nil)
checkNoError(t, err, "couldn't create function: %s") checkNoError(t, err, "couldn't create function: %s")
i, err := db.OneValue("select sum(i) from (select 2 as i union all select 2)") i, err := db.OneValue("select mysum(i) from (select 2 as i union all select 2)")
checkNoError(t, err, "couldn't execute statement: %s") checkNoError(t, err, "couldn't execute statement: %s")
if i != int64(4) { if i != int64(4) {
t.Errorf("Expected %d but got %d", 4, i) t.Errorf("Expected %d but got %v", 4, i)
} }
} }
......
...@@ -404,4 +404,4 @@ func TestScanNull(t *testing.T) { ...@@ -404,4 +404,4 @@ func TestScanNull(t *testing.T) {
} else if ps != nil { } else if ps != nil {
t.Errorf("Expected nil but got %p\n", ps) t.Errorf("Expected nil but got %p\n", ps)
} }
} }
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment