Commit e12f6b1b authored by gwenn's avatar gwenn

Add NewDriver to hook connection creation.

parent 7c9d6542
......@@ -17,7 +17,7 @@ import (
)
func init() {
sql.Register("sqlite3", &impl{})
sql.Register("sqlite3", &impl{open: defaultOpen})
if os.Getenv("SQLITE_LOG") != "" {
ConfigLog(func(d interface{}, err error, msg string) {
log.Printf("%s: %s, %s\n", d, err, msg)
......@@ -28,6 +28,8 @@ func init() {
// impl is an adapter to database/sql/driver
type impl struct {
open func(name string) (*Conn, error)
configure func(*Conn) error
}
type conn struct {
c *Conn
......@@ -42,17 +44,42 @@ type rowsImpl struct {
columnNames []string // cache
}
// Open opens a new database connection.
// ":memory:" for memory db,
// "" for temp file db
func (d *impl) Open(name string) (driver.Conn, error) {
// NewDriver creates a new driver with specialized connection creation/configuration.
// NewDriver(customOpen, nil) // no post-creation hook
// NewDriver(nil, customConfigure) // default connection creation but specific configuration step
func NewDriver(open func(name string) (*Conn, error), configure func(*Conn) error) driver.Driver {
if open == nil {
open = defaultOpen
}
return &impl{open: open, configure: configure}
}
var defaultOpen = func(name string) (*Conn, error) {
// OpenNoMutex == multi-thread mode (http://sqlite.org/compile.html#threadsafe and http://sqlite.org/threadsafe.html)
c, err := Open(name, OpenUri, OpenNoMutex, OpenReadWrite, OpenCreate)
if err != nil {
return nil, err
}
c.BusyTimeout(10 * time.Second)
//c.DefaultTimeLayout = "2006-01-02 15:04:05.999999999"
c.ScanNumericalAsTime = true
return c, nil
}
// Open opens a new database connection.
// ":memory:" for memory db,
// "" for temp file db
func (d *impl) Open(name string) (driver.Conn, error) {
c, err := d.open(name)
if err != nil {
return nil, err
}
if d.configure != nil {
if err = d.configure(c); err != nil {
_ = c.Close()
return nil, err
}
}
return &conn{c}, nil
}
......
......@@ -7,6 +7,7 @@ package sqlite_test
import (
"database/sql"
"testing"
"time"
"github.com/bmizerany/assert"
"github.com/gwenn/gosqlite"
......@@ -167,8 +168,28 @@ func TestRowsWithStmtClosed(t *testing.T) {
func TestUnwrap(t *testing.T) {
db := sqlOpen(t)
defer checkSqlDbClose(db, t)
conn := sqlite.Unwrap(db)
assert.Tf(t, conn != nil, "got %#v; want *sqlite.Conn", conn)
// fmt.Printf("%#v\n", conn)
conn.TotalChanges()
}
func TestCustomRegister(t *testing.T) {
sql.Register("sqlite3ReadOnly", sqlite.NewDriver(func(name string) (*sqlite.Conn, error) {
c, err := sqlite.Open(name, sqlite.OpenUri, sqlite.OpenNoMutex, sqlite.OpenReadOnly)
if err != nil {
return nil, err
}
c.BusyTimeout(10 * time.Second)
return c, nil
}, nil))
// readlonly memory db is useless but...
db, err := sql.Open("sqlite3ReadOnly", ":memory:")
checkNoError(t, err, "Error while opening customized db: %s")
defer checkSqlDbClose(db, t)
conn := sqlite.Unwrap(db)
ro, err := conn.Readonly("main")
checkNoError(t, err, "Error while setting reverse_unordered_selects status: %s")
assert.Tf(t, ro, "readonly = %t; want %t", ro, true)
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment