Commit 3f258c0f authored by gwenn's avatar gwenn

First draft of aggregation function.

parent ab027b97
......@@ -52,17 +52,27 @@ static void goSqlite3SetAuxdata(sqlite3_context *ctx, int N, void *ad) {
sqlite3_set_auxdata(ctx, N, ad, goXAuxDataDestroy);
}
extern void goXFunc(sqlite3_context *ctx, void *udf, void *goctx, int argc, sqlite3_value **argv);
extern void goXFuncOrStep(sqlite3_context *ctx, void *udf, void *goctx, int argc, sqlite3_value **argv);
extern void goXFinal(void *udf, void *goctx);
extern void goXDestroy(void *pApp);
static void cXFunc(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
static void cXFuncOrStep(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
void *udf = sqlite3_user_data(ctx);
void *goctx = sqlite3_get_auxdata(ctx, 0);
goXFunc(ctx, udf, goctx, argc, argv);
goXFuncOrStep(ctx, udf, goctx, argc, argv);
}
static void cXFinal(sqlite3_context *ctx) {
void *udf = sqlite3_user_data(ctx);
void *goctx = sqlite3_get_auxdata(ctx, 0);
goXFinal(udf, goctx);
}
static int goSqlite3CreateScalarFunction(sqlite3 *db, const char *zFunctionName, int nArg, int eTextRep, void *pApp) {
return sqlite3_create_function_v2(db, zFunctionName, nArg, eTextRep, pApp, cXFunc, NULL, NULL, goXDestroy);
return sqlite3_create_function_v2(db, zFunctionName, nArg, eTextRep, pApp, cXFuncOrStep, NULL, NULL, goXDestroy);
}
static int goSqlite3CreateAggregateFunction(sqlite3 *db, const char *zFunctionName, int nArg, int eTextRep, void *pApp) {
return sqlite3_create_function_v2(db, zFunctionName, nArg, eTextRep, pApp, NULL, cXFuncOrStep, cXFinal, goXDestroy);
}
*/
import "C"
......@@ -81,9 +91,10 @@ sqlite3 *sqlite3_context_db_handle(sqlite3_context*);
*/
type Context struct {
sc *C.sqlite3_context
argv **C.sqlite3_value
ad map[int]interface{} // Function Auxiliary Data
sc *C.sqlite3_context
argv **C.sqlite3_value
ad map[int]interface{} // Function Auxiliary Data
AggregateContext interface{} // Aggregate Function Context
}
func (c *Context) Result(r interface{}) {
......@@ -315,14 +326,14 @@ type FinalFunction func(ctx *Context)
type DestroyFunctionData func(pApp interface{})
/*
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
*/
type sqliteFunction struct {
f ScalarFunction
d DestroyFunctionData
pApp interface{}
funcOrStep ScalarFunction
final FinalFunction
d DestroyFunctionData
pApp interface{}
}
// To prevent Context from being gced
......@@ -338,8 +349,8 @@ func goXAuxDataDestroy(ad unsafe.Pointer) {
//fmt.Printf("%v\n", contexts)
}
//export goXFunc
func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
//export goXFuncOrStep
func goXFuncOrStep(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
udf := (*sqliteFunction)(udfp)
// To avoid the creation of a Context at each call, just put it in auxdata
c := (*Context)(ctxp)
......@@ -351,7 +362,15 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
contexts[c.sc] = c
}
c.argv = (**C.sqlite3_value)(argv)
udf.f(c, argc)
udf.funcOrStep(c, argc)
c.argv = nil
}
//export goXFinal
func goXFinal(udfp, ctxp unsafe.Pointer) {
udf := (*sqliteFunction)(udfp)
c := (*Context)(ctxp)
udf.final(c)
}
//export goXDestroy
......@@ -375,7 +394,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
}
// To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, d, pApp}
udf := &sqliteFunction{f, nil, d, pApp}
if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction)
}
......@@ -383,29 +402,31 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf)))
}
// Obtain Aggregate Function Context
/*
// Calls http://sqlite.org/c3ref/aggregate_context.html
func (c *Context) AggregateContext(nBytes int) interface{} {
return C.sqlite3_aggregate_context(c.sc, C.int(nBytes))
}
*/
// Create or redefine SQL functions
// TODO Make possible to specify the preferred encoding
// Calls http://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{}, f ScalarFunction, d DestroyFunctionData) error {
func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{},
step ScalarFunction, final FinalFunction, d DestroyFunctionData) error {
fname := C.CString(functionName)
defer C.free(unsafe.Pointer(fname))
if f == nil {
if step == nil {
if len(c.udfs) > 0 {
delete(c.udfs, functionName)
}
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
}
// To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, d, pApp}
udf := &sqliteFunction{step, final, d, pApp}
if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction)
}
c.udfs[functionName] = udf // FIXME same function name with different args is not supported
return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf)))
return c.error(C.goSqlite3CreateAggregateFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf)))
}
......@@ -80,6 +80,7 @@ func TestRegexpFunction(t *testing.T) {
if err != nil {
t.Fatalf("couldn't prepare statement: %s", err)
}
defer s.Finalize()
if b := Must(s.Next()); !b {
t.Fatalf("No result")
}
......@@ -100,15 +101,54 @@ func TestRegexpFunction(t *testing.T) {
if i != 0 {
t.Errorf("Expected %d but got %d", 0, i)
}
if err = s.Finalize(); err != nil {
t.Fatalf("couldn't finalize statement: %s", err)
}
func sumStep(ctx *Context, nArg int) {
nt := ctx.NumericType(0)
if nt == Integer || nt == Float {
var sum float64
var ok bool
if sum, ok = (ctx.AggregateContext).(float64); !ok {
sum = 0
}
sum += ctx.Double(0)
ctx.AggregateContext = sum
}
}
func sumFinal(ctx *Context) {
if sum, ok := (ctx.AggregateContext).(float64); ok {
ctx.ResultDouble(sum)
} else {
ctx.ResultNull()
}
}
/*
func TestSumFunction(t *testing.T) {
db, err := Open("")
if err != nil {
t.Fatalf("couldn't open database file: %s", err)
}
defer db.Close()
if err = db.CreateAggregateFunction("sum", 1, nil, sumStep, sumFinal, nil); err != nil {
t.Fatalf("couldn't create function: %s", err)
}
i, err := db.OneValue("select sum(i) from (select 2 as i union all select 2 as i)")
if err != nil {
t.Fatalf("couldn't execute statement: %s", err)
}
if i != 4 {
t.Errorf("Expected %d but got %d", 4, i)
}
}
*/
func randomFill(db *Conn, n int) {
db.Exec("DROP TABLE IF EXISTS test")
db.Exec("CREATE TABLE test (name TEXT, rank int)")
s, _ := db.Prepare("INSERT INTO test (name, rank) VALUES (?, ?)")
defer s.Finalize()
names := []string{"Bart", "Homer", "Lisa", "Maggie", "Marge"}
......@@ -116,7 +156,6 @@ func randomFill(db *Conn, n int) {
for i := 0; i < n; i++ {
s.Exec(names[rand.Intn(len(names))], rand.Intn(100))
}
s.Finalize()
db.Commit()
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment