Commit 69eefa95 authored by gwenn's avatar gwenn

Add Transaction method like in tclsqlite.

parent 506077b9
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
"io" "io"
"os" "os"
"reflect" "reflect"
"strconv"
"time" "time"
"unsafe" "unsafe"
) )
...@@ -180,6 +181,7 @@ type Conn struct { ...@@ -180,6 +181,7 @@ type Conn struct {
udfs map[string]*sqliteFunction udfs map[string]*sqliteFunction
modules map[string]*sqliteModule modules map[string]*sqliteModule
timeUsed time.Time timeUsed time.Time
nTransaction uint8
} }
// Version returns the run-time library version number // Version returns the run-time library version number
...@@ -465,6 +467,45 @@ func (c *Conn) Rollback() error { ...@@ -465,6 +467,45 @@ func (c *Conn) Rollback() error {
return c.exec("ROLLBACK") return c.exec("ROLLBACK")
} }
// Transaction is used to execute a function inside an SQLite database transaction.
// The transaction is committed when the function completes (with no error),
// or it rolls back if the function fails.
// If the transaction occurs within another transaction (only one that is started using this method) a Savepoint is created.
// Two errors may be returned: the first is the one returned by the f function,
// the second is the one returned by begin/commit/rollback.
func (c *Conn) Transaction(t TransactionType, f func(c *Conn) error) (gerr error, serr error) {
if c.nTransaction == 0 {
serr = c.BeginTransaction(t)
} else {
serr = c.Savepoint(strconv.Itoa(int(c.nTransaction)))
}
if serr != nil {
return
}
c.nTransaction++
defer func() {
c.nTransaction--
if gerr != nil {
if c.nTransaction == 0 {
serr = c.Rollback()
} else {
serr = c.RollbackSavepoint(strconv.Itoa(int(c.nTransaction)))
}
} else {
if c.nTransaction == 0 {
serr = c.Commit()
} else {
serr = c.ReleaseSavepoint(strconv.Itoa(int(c.nTransaction)))
}
if serr != nil {
c.Rollback()
}
}
}()
gerr = f(c)
return
}
// Savepoint starts a new transaction with a name. // Savepoint starts a new transaction with a name.
// (See http://sqlite.org/lang_savepoint.html) // (See http://sqlite.org/lang_savepoint.html)
func (c *Conn) Savepoint(name string) error { func (c *Conn) Savepoint(name string) error {
......
...@@ -93,7 +93,7 @@ func TestCreateTable(t *testing.T) { ...@@ -93,7 +93,7 @@ func TestCreateTable(t *testing.T) {
createTable(db, t) createTable(db, t)
} }
func TestTransaction(t *testing.T) { func TestManualTransaction(t *testing.T) {
db := open(t) db := open(t)
defer checkClose(db, t) defer checkClose(db, t)
checkNoError(t, db.Begin(), "Error while beginning transaction: %s") checkNoError(t, db.Begin(), "Error while beginning transaction: %s")
...@@ -243,6 +243,22 @@ func TestExecMisuse(t *testing.T) { ...@@ -243,6 +243,22 @@ func TestExecMisuse(t *testing.T) {
assert(t, "exec misuse expected", err != nil) assert(t, "exec misuse expected", err != nil)
} }
func TestTransaction(t *testing.T) {
db := open(t)
defer checkClose(db, t)
createTable(db, t)
gerr, serr := db.Transaction(Immediate, func(_ *Conn) error {
err, nerr := db.Transaction(Immediate, func(__ *Conn) error {
return db.Exec("INSERT INTO test VALUES (?, ?, ?, ?)", 0, 273.1, 1, "test")
})
checkNoError(t, err, "Applicative error: %s")
checkNoError(t, nerr, "SQLite error: %s")
return err
})
checkNoError(t, gerr, "Applicative error: %s")
checkNoError(t, serr, "SQLite error: %s")
}
func assertEquals(t *testing.T, format string, expected, actual interface{}) { func assertEquals(t *testing.T, format string, expected, actual interface{}) {
if expected != actual { if expected != actual {
t.Errorf(format, expected, actual) t.Errorf(format, expected, actual)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment