Commit ee538ffb authored by gwenn's avatar gwenn

Fix custom aggregate function support.

parent f282917c
......@@ -347,7 +347,7 @@ func goXAuxDataDestroy(ad unsafe.Pointer) {
if c != nil {
delete(contexts, c.sc)
}
//fmt.Printf("%v\n", contexts)
// fmt.Printf("Contexts: %v\n", contexts)
}
//export goXFunc
......@@ -370,10 +370,20 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
//export goXStep
func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
udf := (*sqliteFunction)(udfp)
var c *Context
c = (*Context)(C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), C.int(unsafe.Sizeof(c))))
if c != nil {
c.sc = (*C.sqlite3_context)(scp)
var cp unsafe.Pointer
cp = C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), C.int(unsafe.Sizeof(cp)))
if cp != nil {
var c *Context
p := *(*unsafe.Pointer)(cp)
if p == nil {
c = new(Context)
c.sc = (*C.sqlite3_context)(scp)
*(*unsafe.Pointer)(cp) = unsafe.Pointer(c)
// To make sure it is not cged
contexts[c.sc] = c
} else {
c = (*Context)(p)
}
c.argv = (**C.sqlite3_value)(argv)
udf.funcOrStep(c, argc)
......@@ -384,11 +394,17 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
//export goXFinal
func goXFinal(scp, udfp unsafe.Pointer) {
udf := (*sqliteFunction)(udfp)
c := (*Context)(C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), 0))
if c != nil {
//c.sc = (*C.sqlite3_context)(scp)
udf.final(c)
cp := C.sqlite3_aggregate_context((*C.sqlite3_context)(scp), 0)
if cp != nil {
p := *(*unsafe.Pointer)(cp)
if p != nil {
c := (*Context)(p)
delete(contexts, c.sc)
c.sc = (*C.sqlite3_context)(scp)
udf.final(c)
}
}
// fmt.Printf("Contexts: %v\n", contexts)
}
//export goXDestroy
......
......@@ -88,23 +88,22 @@ func TestRegexpFunction(t *testing.T) {
}
}
func sumStep(ctx *Context, nArg int) {
nt := ctx.NumericType(0)
if nt == Integer || nt == Float {
var sum float64
var sum int64
var ok bool
if sum, ok = (ctx.AggregateContext).(float64); !ok {
if sum, ok = (ctx.AggregateContext).(int64); !ok {
sum = 0
}
sum += ctx.Double(0)
sum += ctx.Int64(0)
ctx.AggregateContext = sum
}
}
func sumFinal(ctx *Context) {
if sum, ok := (ctx.AggregateContext).(float64); ok {
ctx.ResultDouble(sum)
if sum, ok := (ctx.AggregateContext).(int64); ok {
ctx.ResultInt64(sum)
} else {
ctx.ResultNull()
}
......@@ -116,10 +115,10 @@ func TestSumFunction(t *testing.T) {
defer db.Close()
err = db.CreateAggregateFunction("mysum", 1, nil, sumStep, sumFinal, nil)
checkNoError(t, err, "couldn't create function: %s")
i, err := db.OneValue("select sum(i) from (select 2 as i union all select 2)")
i, err := db.OneValue("select mysum(i) from (select 2 as i union all select 2)")
checkNoError(t, err, "couldn't execute statement: %s")
if i != int64(4) {
t.Errorf("Expected %d but got %d", 4, i)
t.Errorf("Expected %d but got %v", 4, i)
}
}
......
......@@ -404,4 +404,4 @@ func TestScanNull(t *testing.T) {
} else if ps != nil {
t.Errorf("Expected nil but got %p\n", ps)
}
}
\ No newline at end of file
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment