Commit 65a5bb7a authored by Jeremy Hylton's avatar Jeremy Hylton

A bunch of small fixes.

Make txn_factory an attribute of the base class.

Raise an exception when prepare() returns False, rather than
automatically aborting.

Pass transaction object to Rollback() so that rollback() method can
check state of transaction.

Add IllegalStateError calls to prevent assertions from failing.
XXX Should the manager duplicate these checks?

Add suspend() and resume() to non-threaded txn manager.

Fix bug that caused threaded suspend() to fail with KeyError for
thread with no current transaction.
parent 38c7ec78
......@@ -12,6 +12,12 @@ class AbstractTransactionManager(object):
# base class to provide commit logic
# concrete class must provide logger attribute
txn_factory = Transaction
# XXX the methods below use assertions, but perhaps they should
# check errors. on the other hand, the transaction instances
# do raise exceptions.
def commit(self, txn):
# commit calls _finishCommit() or abort()
assert txn._status is Status.ACTIVE
......@@ -21,16 +27,14 @@ class AbstractTransactionManager(object):
try:
for r in txn._resources:
if prepare_ok and not r.prepare(txn):
prepare_ok = False
raise AbortError(r)
except:
txn._status = Status.FAILED
raise
txn._status = Status.PREPARED
# XXX An error below is intolerable. What state to use?
if prepare_ok:
# Need code to handle this case.
self._finishCommit(txn)
else:
self.abort(txn)
def _finishCommit(self, txn):
self.logger.debug("%s: commit", txn)
......@@ -48,18 +52,18 @@ class AbstractTransactionManager(object):
txn._status = Status.ABORTED
def savepoint(self, txn):
assert txn._status == Status.ACTIVE
self.logger.debug("%s: savepoint", txn)
return Rollback([r.savepoint(txn) for r in txn._resources])
return Rollback(txn, [r.savepoint(txn) for r in txn._resources])
class TransactionManager(AbstractTransactionManager):
txn_factory = Transaction
__implements__ = ITransactionManager
def __init__(self):
self.logger = logging.getLogger("txn")
self._current = None
self._suspended = Set()
def get(self):
if self._current is None:
......@@ -67,9 +71,11 @@ class TransactionManager(AbstractTransactionManager):
return self._current
def begin(self):
txn = self.txn_factory(self)
self.logger.debug("%s: begin", txn)
return txn
if self._current is not None:
self._current.abort()
self._current = self.txn_factory(self)
self.logger.debug("%s: begin", self._current)
return self._current
def commit(self, txn):
super(TransactionManager, self).commit(txn)
......@@ -79,16 +85,31 @@ class TransactionManager(AbstractTransactionManager):
super(TransactionManager, self).abort(txn)
self._current = None
# XXX need suspend and resume
def suspend(self, txn):
if self._current != txn:
raise TransactionError("Can't suspend transaction because "
"it is not active")
self._suspended.add(txn)
self._current = None
def resume(self, txn):
if self._current is not None:
raise TransactionError("Can't resume while other "
"transaction is active")
self._suspended.remove(txn)
self._current = txn
class Rollback(object):
__implements__ = IRollback
def __init__(self, resources):
def __init__(self, txn, resources):
self._txn = txn
self._resources = resources
def rollback(self):
if self._txn.status() != Status.ACTIVE:
raise IllegalStateError("rollback", self._txn.status())
for r in self._resources:
r.rollback()
......@@ -150,7 +171,7 @@ class ThreadedTransactionManager(AbstractTransactionManager):
def suspend(self, txn):
tid = thread.get_ident()
if self._pool[tid] is txn:
if self._pool.get(tid) is txn:
self._suspend.add(txn)
del self._pool[tid]
else:
......@@ -164,5 +185,5 @@ class ThreadedTransactionManager(AbstractTransactionManager):
tid)
if txn not in self._suspend:
raise TransactionError("unknown transaction: %s" % txn)
del self._suspend[txn]
self._suspend.remove(txn)
self._pool[tid] = txn
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment