Commit 65a5bb7a authored by Jeremy Hylton's avatar Jeremy Hylton

A bunch of small fixes.

Make txn_factory an attribute of the base class.

Raise an exception when prepare() returns False, rather than
automatically aborting.

Pass transaction object to Rollback() so that rollback() method can
check state of transaction.

Add IllegalStateError calls to prevent assertions from failing.
XXX Should the manager duplicate these checks?

Add suspend() and resume() to non-threaded txn manager.

Fix bug that caused threaded suspend() to fail with KeyError for
thread with no current transaction.
parent 38c7ec78
...@@ -12,6 +12,12 @@ class AbstractTransactionManager(object): ...@@ -12,6 +12,12 @@ class AbstractTransactionManager(object):
# base class to provide commit logic # base class to provide commit logic
# concrete class must provide logger attribute # concrete class must provide logger attribute
txn_factory = Transaction
# XXX the methods below use assertions, but perhaps they should
# check errors. on the other hand, the transaction instances
# do raise exceptions.
def commit(self, txn): def commit(self, txn):
# commit calls _finishCommit() or abort() # commit calls _finishCommit() or abort()
assert txn._status is Status.ACTIVE assert txn._status is Status.ACTIVE
...@@ -21,16 +27,14 @@ class AbstractTransactionManager(object): ...@@ -21,16 +27,14 @@ class AbstractTransactionManager(object):
try: try:
for r in txn._resources: for r in txn._resources:
if prepare_ok and not r.prepare(txn): if prepare_ok and not r.prepare(txn):
prepare_ok = False raise AbortError(r)
except: except:
txn._status = Status.FAILED txn._status = Status.FAILED
raise raise
txn._status = Status.PREPARED txn._status = Status.PREPARED
# XXX An error below is intolerable. What state to use? # XXX An error below is intolerable. What state to use?
if prepare_ok: # Need code to handle this case.
self._finishCommit(txn) self._finishCommit(txn)
else:
self.abort(txn)
def _finishCommit(self, txn): def _finishCommit(self, txn):
self.logger.debug("%s: commit", txn) self.logger.debug("%s: commit", txn)
...@@ -48,18 +52,18 @@ class AbstractTransactionManager(object): ...@@ -48,18 +52,18 @@ class AbstractTransactionManager(object):
txn._status = Status.ABORTED txn._status = Status.ABORTED
def savepoint(self, txn): def savepoint(self, txn):
assert txn._status == Status.ACTIVE
self.logger.debug("%s: savepoint", txn) self.logger.debug("%s: savepoint", txn)
return Rollback([r.savepoint(txn) for r in txn._resources]) return Rollback(txn, [r.savepoint(txn) for r in txn._resources])
class TransactionManager(AbstractTransactionManager): class TransactionManager(AbstractTransactionManager):
txn_factory = Transaction
__implements__ = ITransactionManager __implements__ = ITransactionManager
def __init__(self): def __init__(self):
self.logger = logging.getLogger("txn") self.logger = logging.getLogger("txn")
self._current = None self._current = None
self._suspended = Set()
def get(self): def get(self):
if self._current is None: if self._current is None:
...@@ -67,9 +71,11 @@ class TransactionManager(AbstractTransactionManager): ...@@ -67,9 +71,11 @@ class TransactionManager(AbstractTransactionManager):
return self._current return self._current
def begin(self): def begin(self):
txn = self.txn_factory(self) if self._current is not None:
self.logger.debug("%s: begin", txn) self._current.abort()
return txn self._current = self.txn_factory(self)
self.logger.debug("%s: begin", self._current)
return self._current
def commit(self, txn): def commit(self, txn):
super(TransactionManager, self).commit(txn) super(TransactionManager, self).commit(txn)
...@@ -79,16 +85,31 @@ class TransactionManager(AbstractTransactionManager): ...@@ -79,16 +85,31 @@ class TransactionManager(AbstractTransactionManager):
super(TransactionManager, self).abort(txn) super(TransactionManager, self).abort(txn)
self._current = None self._current = None
# XXX need suspend and resume def suspend(self, txn):
if self._current != txn:
raise TransactionError("Can't suspend transaction because "
"it is not active")
self._suspended.add(txn)
self._current = None
def resume(self, txn):
if self._current is not None:
raise TransactionError("Can't resume while other "
"transaction is active")
self._suspended.remove(txn)
self._current = txn
class Rollback(object): class Rollback(object):
__implements__ = IRollback __implements__ = IRollback
def __init__(self, resources): def __init__(self, txn, resources):
self._txn = txn
self._resources = resources self._resources = resources
def rollback(self): def rollback(self):
if self._txn.status() != Status.ACTIVE:
raise IllegalStateError("rollback", self._txn.status())
for r in self._resources: for r in self._resources:
r.rollback() r.rollback()
...@@ -150,7 +171,7 @@ class ThreadedTransactionManager(AbstractTransactionManager): ...@@ -150,7 +171,7 @@ class ThreadedTransactionManager(AbstractTransactionManager):
def suspend(self, txn): def suspend(self, txn):
tid = thread.get_ident() tid = thread.get_ident()
if self._pool[tid] is txn: if self._pool.get(tid) is txn:
self._suspend.add(txn) self._suspend.add(txn)
del self._pool[tid] del self._pool[tid]
else: else:
...@@ -164,5 +185,5 @@ class ThreadedTransactionManager(AbstractTransactionManager): ...@@ -164,5 +185,5 @@ class ThreadedTransactionManager(AbstractTransactionManager):
tid) tid)
if txn not in self._suspend: if txn not in self._suspend:
raise TransactionError("unknown transaction: %s" % txn) raise TransactionError("unknown transaction: %s" % txn)
del self._suspend[txn] self._suspend.remove(txn)
self._pool[tid] = txn self._pool[tid] = txn
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment