Commit c1ec610f authored by Bryton Lacquement's avatar Bryton Lacquement 🚪

Refactor the sqlite3-related code

parent 186b2cb6
from lib2to3.tests.test_fixers import FixerTestCase as lib2to3FixerTestCase from lib2to3.tests.test_fixers import FixerTestCase as lib2to3FixerTestCase
import sqlite3 import sqlite3
from my2to3.trace import database, get_data, tracing_functions from my2to3.trace import conn, get_data, tracing_functions
class FixerTestCase(lib2to3FixerTestCase): class FixerTestCase(lib2to3FixerTestCase):
...@@ -10,16 +10,9 @@ class FixerTestCase(lib2to3FixerTestCase): ...@@ -10,16 +10,9 @@ class FixerTestCase(lib2to3FixerTestCase):
super(FixerTestCase, self).setUp(fix_list, fixer_pkg, options) super(FixerTestCase, self).setUp(fix_list, fixer_pkg, options)
# Clear the database # Clear the database
# XXX: a better way probably exists. with conn:
# TODO: for table in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
# - refactor conn.execute("DELETE FROM %s" % table)
# - optimize
conn = sqlite3.connect(database)
c = conn.cursor()
for table in c.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
c.execute("DELETE FROM %s" % table)
conn.commit()
conn.close()
def assertDataEqual(self, table, data): def assertDataEqual(self, table, data):
self.assertEqual(get_data(table), data) self.assertEqual(get_data(table), data)
......
...@@ -4,6 +4,7 @@ from lib2to3.pgen2 import tokenize ...@@ -4,6 +4,7 @@ from lib2to3.pgen2 import tokenize
from lib2to3.refactor import get_fixers_from_package, RefactoringTool from lib2to3.refactor import get_fixers_from_package, RefactoringTool
database = "traces.db" database = "traces.db"
conn = sqlite3.connect(database)
tracing_functions = [] tracing_functions = []
...@@ -13,28 +14,22 @@ def register_tracing_function(f): ...@@ -13,28 +14,22 @@ def register_tracing_function(f):
def create_table(table, *columns): def create_table(table, *columns):
# TODO: with conn:
# - refactor
# - optimize
conn = sqlite3.connect(database)
c = conn.cursor()
v = ', '.join(columns) v = ', '.join(columns)
c.execute("CREATE TABLE IF NOT EXISTS %s (%s, UNIQUE (%s))" % (table, v, v)) conn.execute(
conn.commit() "CREATE TABLE IF NOT EXISTS %s (%s, UNIQUE (%s))" % (table, v, v)
conn.close() )
def insert_unique(*values): def insert_unique(*values):
conn = sqlite3.connect(database) with conn:
c = conn.cursor()
try: try:
c.execute('INSERT INTO %s VALUES (%s)' % ( conn.execute(
table, ', '.join(['?'] * len(values))), values) 'INSERT INTO %s VALUES (%s)' % (table, ', '.join('?' * len(values))),
values
)
except sqlite3.IntegrityError as e: except sqlite3.IntegrityError as e:
if not str(e).startswith('UNIQUE constraint failed:'): if not str(e).startswith('UNIQUE constraint failed:'):
raise raise
else:
conn.commit()
conn.close()
return insert_unique return insert_unique
...@@ -45,21 +40,14 @@ def get_fixers(): ...@@ -45,21 +40,14 @@ def get_fixers():
def get_data(table, columns_to_select='*', conditions={}): def get_data(table, columns_to_select='*', conditions={}):
# TODO:
# - refactor
# - optimize
conn = sqlite3.connect(database)
c = conn.cursor()
query = "SELECT %s FROM %s" % ( query = "SELECT %s FROM %s" % (
', '.join(columns_to_select), ', '.join(columns_to_select),
table, table,
) )
if conditions: if conditions:
query += " WHERE " + ' AND '.join(k + " = " + v for k, v in conditions.items()) query += " WHERE " + ' AND '.join(k + " = :" + k for k in conditions.keys())
try: with conn:
return c.execute(query).fetchall() return conn.execute(query, conditions).fetchall()
finally:
conn.close()
def apply_fixers(string, name): def apply_fixers(string, name):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment