Commit da5312e1 authored by Jim Fulton's avatar Jim Fulton

Refactored storage server to support multiple client threads.

Changed ZEO undo protocol. (Undo is disabled with older clients.)
Now use one-way undoa.  Undone oids are now returned by (tpc_)vote for
ZEO. Undo no longer gets the commit lock.
parent a3c0cf67
...@@ -14,6 +14,10 @@ New Features ...@@ -14,6 +14,10 @@ New Features
database's undo method multiple times in the same transaction now database's undo method multiple times in the same transaction now
raises an exception. raises an exception.
- The ZEO protocol for undo has changed. The only user-visible
consequence of this is that ZODB 3.10 ZEO servers won't support
undo for older clients.
- The storage API (IStorage) has been tightened. Now, storages should - The storage API (IStorage) has been tightened. Now, storages should
raise a StorageTransactionError when invalid transactions are passed raise a StorageTransactionError when invalid transactions are passed
to tpc_begin, tpc_vote, or tpc_finish. to tpc_begin, tpc_vote, or tpc_finish.
......
...@@ -1198,14 +1198,19 @@ class ClientStorage(object): ...@@ -1198,14 +1198,19 @@ class ClientStorage(object):
if self._cache is None: if self._cache is None:
return return
for oid, data in self._tbuf: for oid, _ in self._seriald.iteritems():
self._cache.invalidate(oid, tid, False) self._cache.invalidate(oid, tid, False)
for oid, data in self._tbuf:
# If data is None, we just invalidate. # If data is None, we just invalidate.
if data is not None: if data is not None:
s = self._seriald[oid] s = self._seriald[oid]
if s != ResolvedSerial: if s != ResolvedSerial:
assert s == tid, (s, tid) assert s == tid, (s, tid)
self._cache.store(oid, s, None, data) self._cache.store(oid, s, None, data)
else:
# object deletion
self._cache.invalidate(oid, tid, False)
if self.fshelper is not None: if self.fshelper is not None:
blobs = self._tbuf.blobs blobs = self._tbuf.blobs
...@@ -1241,10 +1246,7 @@ class ClientStorage(object): ...@@ -1241,10 +1246,7 @@ class ClientStorage(object):
""" """
self._check_trans(txn) self._check_trans(txn)
tid, oids = self._server.undo(trans_id, id(txn)) self._server.undoa(trans_id, id(txn))
for oid in oids:
self._tbuf.invalidate(oid)
return tid, oids
def undoInfo(self, first=0, last=-20, specification=None): def undoInfo(self, first=0, last=-20, specification=None):
"""Storage API: return undo information.""" """Storage API: return undo information."""
......
...@@ -272,8 +272,8 @@ class StorageServer: ...@@ -272,8 +272,8 @@ class StorageServer:
def new_oid(self): def new_oid(self):
return self.rpc.call('new_oid') return self.rpc.call('new_oid')
def undo(self, trans_id, trans): def undoa(self, trans_id, trans):
return self.rpc.call('undo', trans_id, trans) self.rpc.callAsync('undoa', trans_id, trans)
def undoLog(self, first, last): def undoLog(self, first, last):
return self.rpc.call('undoLog', first, last) return self.rpc.call('undoLog', first, last)
......
This diff is collapsed.
...@@ -37,14 +37,11 @@ class TransUndoStorageWithCache: ...@@ -37,14 +37,11 @@ class TransUndoStorageWithCache:
# Now start an undo transaction # Now start an undo transaction
t = Transaction() t = Transaction()
t.note('undo1') t.note('undo1')
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
tid, oids = self._storage.undo(tid, t)
# Make sure this doesn't load invalid data into the cache # Make sure this doesn't load invalid data into the cache
self._storage.load(oid, '') self._storage.load(oid, '')
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
assert len(oids) == 1 assert len(oids) == 1
......
...@@ -181,64 +181,3 @@ class CommitLockVoteTests(CommitLockTests): ...@@ -181,64 +181,3 @@ class CommitLockVoteTests(CommitLockTests):
self._finish_threads() self._finish_threads()
self._cleanup() self._cleanup()
class CommitLockUndoTests(CommitLockTests):
def _get_trans_id(self):
self._dostore()
L = self._storage.undoInfo()
return L[0]['id']
def _begin_undo(self, trans_id, txn):
rpc = self._storage._server.rpc
return rpc._deferred_call('undo', trans_id, id(txn))
def _finish_undo(self, msgid):
return self._storage._server.rpc._deferred_wait(msgid)
def checkCommitLockUndoFinish(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
self._finish_undo(msgid)
self._storage.tpc_vote(txn)
self._storage.tpc_finish(txn)
self._storage.load(oid, '')
self._finish_threads()
self._dostore()
self._cleanup()
def checkCommitLockUndoAbort(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
self._finish_undo(msgid)
self._storage.tpc_vote(txn)
self._storage.tpc_abort(txn)
self._finish_threads()
self._dostore()
self._cleanup()
def checkCommitLockUndoClose(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
self._finish_undo(msgid)
self._storage.tpc_vote(txn)
self._storage.close()
self._finish_threads()
self._cleanup()
...@@ -318,9 +318,9 @@ class InvalidationTests: ...@@ -318,9 +318,9 @@ class InvalidationTests:
# tearDown then immediately, but if other threads are still # tearDown then immediately, but if other threads are still
# running that can lead to a cascade of spurious exceptions. # running that can lead to a cascade of spurious exceptions.
for t in threads: for t in threads:
t.join(10) t.join(30)
for t in threads: for t in threads:
t.cleanup() t.cleanup(10)
def checkConcurrentUpdates2Storages_emulated(self): def checkConcurrentUpdates2Storages_emulated(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
...@@ -378,6 +378,34 @@ class InvalidationTests: ...@@ -378,6 +378,34 @@ class InvalidationTests:
db1.close() db1.close()
db2.close() db2.close()
def checkConcurrentUpdates19Storages(self):
n = 19
dbs = [DB(self.openClientStorage()) for i in range(n)]
self._storage = dbs[0].storage
stop = threading.Event()
cn = dbs[0].open()
tree = cn.root()["tree"] = OOBTree()
transaction.commit()
cn.close()
# Run threads that update the BTree
cd = {}
threads = [self.StressThread(dbs[i], stop, i, cd, i, n)
for i in range(n)]
self.go(stop, cd, *threads)
while len(set(db.lastTransaction() for db in dbs)) > 1:
_ = [db._storage.sync() for db in dbs]
cn = dbs[0].open()
tree = cn.root()["tree"]
self._check_tree(cn, tree)
self._check_threads(tree, *threads)
cn.close()
_ = [db.close() for db in dbs]
def checkConcurrentUpdates1Storage(self): def checkConcurrentUpdates1Storage(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
db1 = DB(storage1) db1 = DB(storage1)
......
...@@ -58,3 +58,7 @@ class Connection: ...@@ -58,3 +58,7 @@ class Connection:
print self.name, 'callAsync', meth, repr(args) print self.name, 'callAsync', meth, repr(args)
callAsyncNoPoll = callAsync callAsyncNoPoll = callAsync
def call_from_thread(self, *args):
if args:
args[0](*args[1:])
...@@ -25,7 +25,6 @@ from ZODB.tests import StorageTestBase, BasicStorage, \ ...@@ -25,7 +25,6 @@ from ZODB.tests import StorageTestBase, BasicStorage, \
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle from ZODB.tests.StorageTestBase import zodb_unpickle
import asyncore
import doctest import doctest
import logging import logging
import os import os
...@@ -244,7 +243,6 @@ class GenericTests( ...@@ -244,7 +243,6 @@ class GenericTests(
class FullGenericTests( class FullGenericTests(
GenericTests, GenericTests,
Cache.TransUndoStorageWithCache, Cache.TransUndoStorageWithCache,
CommitLockTests.CommitLockUndoTests,
ConflictResolution.ConflictResolvingStorage, ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage, ConflictResolution.ConflictResolvingTransUndoStorage,
PackableStorage.PackableUndoStorage, PackableStorage.PackableUndoStorage,
...@@ -727,6 +725,10 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests): ...@@ -727,6 +725,10 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests):
blob_cache_dir = 'blobs' blob_cache_dir = 'blobs'
shared_blob_dir = True shared_blob_dir = True
class FauxConn:
addr = 'x'
peer_protocol_version = ZEO.zrpc.connection.Connection.current_protocol
class StorageServerClientWrapper: class StorageServerClientWrapper:
def __init__(self): def __init__(self):
...@@ -743,8 +745,8 @@ class StorageServerWrapper: ...@@ -743,8 +745,8 @@ class StorageServerWrapper:
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
self.storage_id = storage_id self.storage_id = storage_id
self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only) self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only)
self.server.notifyConnected(FauxConn())
self.server.register(storage_id, False) self.server.register(storage_id, False)
self.server._thunk = lambda : None
self.server.client = StorageServerClientWrapper() self.server.client = StorageServerClientWrapper()
def sortKey(self): def sortKey(self):
...@@ -766,8 +768,7 @@ class StorageServerWrapper: ...@@ -766,8 +768,7 @@ class StorageServerWrapper:
self.server.tpc_begin(id(transaction), '', '', {}, None, ' ') self.server.tpc_begin(id(transaction), '', '', {}, None, ' ')
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
self.server._restart() assert self.server.vote(id(transaction)) is None
self.server.vote(id(transaction))
result = self.server.client.serials[:] result = self.server.client.serials[:]
del self.server.client.serials[:] del self.server.client.serials[:]
return result return result
...@@ -775,8 +776,11 @@ class StorageServerWrapper: ...@@ -775,8 +776,11 @@ class StorageServerWrapper:
def store(self, oid, serial, data, version_ignored, transaction): def store(self, oid, serial, data, version_ignored, transaction):
self.server.storea(oid, serial, data, id(transaction)) self.server.storea(oid, serial, data, id(transaction))
def send_reply(self, *args): # Masquerade as conn
pass
def tpc_finish(self, transaction, func = lambda: None): def tpc_finish(self, transaction, func = lambda: None):
self.server.tpc_finish(id(transaction)) self.server.tpc_finish(id(transaction)).set_sender(0, self)
def multiple_storages_invalidation_queue_is_not_insane(): def multiple_storages_invalidation_queue_is_not_insane():
...@@ -849,6 +853,7 @@ Now we'll open a storage server on the data, simulating a restart: ...@@ -849,6 +853,7 @@ Now we'll open a storage server on the data, simulating a restart:
>>> fs = FileStorage('t.fs') >>> fs = FileStorage('t.fs')
>>> sv = StorageServer(('', get_port()), dict(fs=fs)) >>> sv = StorageServer(('', get_port()), dict(fs=fs))
>>> s = ZEOStorage(sv, sv.read_only) >>> s = ZEOStorage(sv, sv.read_only)
>>> s.notifyConnected(FauxConn())
>>> s.register('fs', False) >>> s.register('fs', False)
If we ask for the last transaction, we should get the last transaction If we ask for the last transaction, we should get the last transaction
...@@ -941,7 +946,7 @@ def tpc_finish_error(): ...@@ -941,7 +946,7 @@ def tpc_finish_error():
... def close(self): ... def close(self):
... print 'connection closed' ... print 'connection closed'
... trigger = property(lambda self: self) ... trigger = property(lambda self: self)
... pull_trigger = lambda self, func: func() ... pull_trigger = lambda self, func, *args: func(*args)
>>> class ConnectionManager: >>> class ConnectionManager:
... def __init__(self, addr, client, tmin, tmax): ... def __init__(self, addr, client, tmin, tmax):
...@@ -1251,6 +1256,8 @@ Invalidations could cause errors when closing client storages, ...@@ -1251,6 +1256,8 @@ Invalidations could cause errors when closing client storages,
>>> thread.join(1) >>> thread.join(1)
""" """
if sys.version_info >= (2, 6): if sys.version_info >= (2, 6):
import multiprocessing import multiprocessing
...@@ -1259,28 +1266,32 @@ if sys.version_info >= (2, 6): ...@@ -1259,28 +1266,32 @@ if sys.version_info >= (2, 6):
q.put((name, conn.root.x)) q.put((name, conn.root.x))
conn.close() conn.close()
def work_with_multiprocessing(): class MultiprocessingTests(unittest.TestCase):
"""Client storage should work with multi-processing.
def test_work_with_multiprocessing(self):
>>> import StringIO "Client storage should work with multi-processing."
>>> sys.stdin = StringIO.StringIO()
>>> addr, _ = start_server() self.globs = {}
>>> conn = ZEO.connection(addr) forker.setUp(self)
>>> conn.root.x = 1 addr, adminaddr = self.globs['start_server']()
>>> transaction.commit() conn = ZEO.connection(addr)
>>> q = multiprocessing.Queue() conn.root.x = 1
>>> processes = [multiprocessing.Process( transaction.commit()
... target=work_with_multiprocessing_process, q = multiprocessing.Queue()
... args=(i, addr, q)) processes = [multiprocessing.Process(
... for i in range(3)] target=work_with_multiprocessing_process,
>>> _ = [p.start() for p in processes] args=(i, addr, q))
>>> sorted(q.get(timeout=60) for p in processes) for i in range(3)]
[(0, 1), (1, 1), (2, 1)] _ = [p.start() for p in processes]
self.assertEqual(sorted(q.get(timeout=300) for p in processes),
>>> _ = [p.join(30) for p in processes] [(0, 1), (1, 1), (2, 1)])
>>> conn.close()
""" _ = [p.join(30) for p in processes]
conn.close()
zope.testing.setupstack.tearDown(self)
else:
class MultiprocessingTests(unittest.TestCase):
pass
slow_test_classes = [ slow_test_classes = [
BlobAdaptedFileStorageTests, BlobWritableCacheTests, BlobAdaptedFileStorageTests, BlobWritableCacheTests,
...@@ -1353,6 +1364,7 @@ def test_suite(): ...@@ -1353,6 +1364,7 @@ def test_suite():
# unit test layer # unit test layer
zeo = unittest.TestSuite() zeo = unittest.TestSuite()
zeo.addTest(unittest.makeSuite(ZODB.tests.util.AAAA_Test_Runner_Hack)) zeo.addTest(unittest.makeSuite(ZODB.tests.util.AAAA_Test_Runner_Hack))
zeo.addTest(unittest.makeSuite(MultiprocessingTests))
zeo.addTest(doctest.DocTestSuite( zeo.addTest(doctest.DocTestSuite(
setUp=forker.setUp, tearDown=zope.testing.setupstack.tearDown)) setUp=forker.setUp, tearDown=zope.testing.setupstack.tearDown))
zeo.addTest(doctest.DocTestSuite(ZEO.tests.IterationTests, zeo.addTest(doctest.DocTestSuite(ZEO.tests.IterationTests,
......
...@@ -93,9 +93,9 @@ client will be restarted. It will get a conflict error, that is ...@@ -93,9 +93,9 @@ client will be restarted. It will get a conflict error, that is
handled correctly: handled correctly:
>>> zs1.tpc_abort('0') # doctest: +ELLIPSIS >>> zs1.tpc_abort('0') # doctest: +ELLIPSIS
(511/test-addr) ('1') unlock: transactions waiting: 0
2 callAsync serialnos ... 2 callAsync serialnos ...
reply 1 None reply 1 None
(511/test-addr) Blocked transaction restarted.
>>> fs.tpc_transaction() is not None >>> fs.tpc_transaction() is not None
True True
......
...@@ -55,6 +55,16 @@ class Delay: ...@@ -55,6 +55,16 @@ class Delay:
log("Error raised in delayed method", logging.ERROR, exc_info=True) log("Error raised in delayed method", logging.ERROR, exc_info=True)
self.conn.return_error(self.msgid, *exc_info[:2]) self.conn.return_error(self.msgid, *exc_info[:2])
class Result(Delay):
def __init__(self, *args):
self.args = args
def set_sender(self, msgid, conn):
reply, callback = self.args
conn.send_reply(msgid, reply, False)
callback()
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
...@@ -218,18 +228,25 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -218,18 +228,25 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# restorea, iterator_start, iterator_next, # restorea, iterator_start, iterator_next,
# iterator_record_start, iterator_record_next, # iterator_record_start, iterator_record_next,
# iterator_gc # iterator_gc
#
# Z310 -- named after the ZODB release 3.10
# New server methods:
# undoa
# Doesn't support undo for older clients.
# Undone oid info returned by vote.
# Protocol variables: # Protocol variables:
# Our preferred protocol. # Our preferred protocol.
current_protocol = "Z309" current_protocol = "Z310"
# If we're a client, an exhaustive list of the server protocols we # If we're a client, an exhaustive list of the server protocols we
# can accept. # can accept.
servers_we_can_talk_to = ["Z308", current_protocol] servers_we_can_talk_to = ["Z308", "Z309", current_protocol]
# If we're a server, an exhaustive list of the client protocols we # If we're a server, an exhaustive list of the client protocols we
# can accept. # can accept.
clients_we_can_talk_to = ["Z200", "Z201", "Z303", "Z308", current_protocol] clients_we_can_talk_to = [
"Z200", "Z201", "Z303", "Z308", "Z309", current_protocol]
# This is pretty excruciating. Details: # This is pretty excruciating. Details:
# #
......
...@@ -666,32 +666,11 @@ class Connection(ExportImport, object): ...@@ -666,32 +666,11 @@ class Connection(ExportImport, object):
self._cache.update_object_size_estimation(oid, len(p)) self._cache.update_object_size_estimation(oid, len(p))
obj._p_estimated_size = len(p) obj._p_estimated_size = len(p)
self._handle_serial(s, oid) self._handle_serial(oid, s)
def _handle_serial(self, store_return, oid=None, change=1): def _handle_serial(self, oid, serial, change=True):
"""Handle the returns from store() and tpc_vote() calls.""" if not serial:
# These calls can return different types depending on whether
# ZEO is used. ZEO uses asynchronous returns that may be
# returned in batches by the ClientStorage. ZEO1 can also
# return an exception object and expect that the Connection
# will raise the exception.
# When conflict resolution occurs, the object state held by
# the connection does not match what is written to the
# database. Invalidate the object here to guarantee that
# the new state is read the next time the object is used.
if not store_return:
return return
if isinstance(store_return, str):
assert oid is not None
self._handle_one_serial(oid, store_return, change)
else:
for oid, serial in store_return:
self._handle_one_serial(oid, serial, change)
def _handle_one_serial(self, oid, serial, change):
if not isinstance(serial, str): if not isinstance(serial, str):
raise serial raise serial
obj = self._cache.get(oid, None) obj = self._cache.get(oid, None)
...@@ -757,7 +736,9 @@ class Connection(ExportImport, object): ...@@ -757,7 +736,9 @@ class Connection(ExportImport, object):
except AttributeError: except AttributeError:
return return
s = vote(transaction) s = vote(transaction)
self._handle_serial(s) if s:
for oid, serial in s:
self._handle_serial(oid, serial)
def tpc_finish(self, transaction): def tpc_finish(self, transaction):
"""Indicate confirmation that the transaction is done.""" """Indicate confirmation that the transaction is done."""
...@@ -1171,7 +1152,7 @@ class Connection(ExportImport, object): ...@@ -1171,7 +1152,7 @@ class Connection(ExportImport, object):
s = self._storage.store(oid, serial, data, s = self._storage.store(oid, serial, data,
'', transaction) '', transaction)
self._handle_serial(s, oid, change=False) self._handle_serial(oid, s, change=False)
src.close() src.close()
def _abort_savepoint(self): def _abort_savepoint(self):
......
...@@ -158,6 +158,7 @@ class ConflictResolvingTransUndoStorage: ...@@ -158,6 +158,7 @@ class ConflictResolvingTransUndoStorage:
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
self._storage.undo(tid, t) self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
def checkUndoUnresolvable(self): def checkUndoUnresolvable(self):
...@@ -177,7 +178,5 @@ class ConflictResolvingTransUndoStorage: ...@@ -177,7 +178,5 @@ class ConflictResolvingTransUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self.assertRaises(UndoError, self._begin_undos_vote, t, tid)
self.assertRaises(UndoError, self._storage.undo,
tid, t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
...@@ -122,7 +122,7 @@ class RevisionStorage: ...@@ -122,7 +122,7 @@ class RevisionStorage:
tid = info[0]["id"] tid = info[0]["id"]
# Always undo the most recent txn, so the value will # Always undo the most recent txn, so the value will
# alternate between 3 and 4. # alternate between 3 and 4.
self._undo(tid, [oid], note="undo %d" % i) self._undo(tid, note="undo %d" % i)
revs.append(self._storage.load(oid, "")) revs.append(self._storage.load(oid, ""))
prev_tid = None prev_tid = None
......
...@@ -209,10 +209,12 @@ class StorageTestBase(ZODB.tests.util.TestCase): ...@@ -209,10 +209,12 @@ class StorageTestBase(ZODB.tests.util.TestCase):
t = transaction.Transaction() t = transaction.Transaction()
t.note(note or "undo") t.note(note or "undo")
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
tid, oids = self._storage.undo(tid, t) undo_result = self._storage.undo(tid, t)
self._storage.tpc_vote(t) vote_result = self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
if expected_oids is not None: if expected_oids is not None:
oids = undo_result and undo_result[1] or []
oids.extend(oid for (oid, _) in vote_result or ())
self.assertEqual(len(oids), len(expected_oids), repr(oids)) self.assertEqual(len(oids), len(expected_oids), repr(oids))
for oid in expected_oids: for oid in expected_oids:
self.assert_(oid in oids) self.assert_(oid in oids)
......
...@@ -101,12 +101,20 @@ class TransactionalUndoStorage: ...@@ -101,12 +101,20 @@ class TransactionalUndoStorage:
for rec in txn: for rec in txn:
pass pass
def _begin_undos_vote(self, t, *tids):
self._storage.tpc_begin(t)
oids = []
for tid in tids:
undo_result = self._storage.undo(tid, t)
if undo_result:
oids.extend(undo_result[1])
oids.extend(oid for (oid, _) in self._storage.tpc_vote(t) or ())
return oids
def undo(self, tid, note): def undo(self, tid, note):
t = Transaction() t = Transaction()
t.note(note) t.note(note)
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
oids = self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
return oids return oids
...@@ -152,9 +160,7 @@ class TransactionalUndoStorage: ...@@ -152,9 +160,7 @@ class TransactionalUndoStorage:
tid = info[0]['id'] tid = info[0]['id']
t = Transaction() t = Transaction()
t.note('undo1') t.note('undo1')
self._storage.tpc_begin(t) self._begin_undos_vote(t, tid)
self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
# Check that calling getTid on an uncreated object raises a KeyError # Check that calling getTid on an uncreated object raises a KeyError
# The current version of FileStorage fails this test # The current version of FileStorage fails this test
...@@ -281,14 +287,10 @@ class TransactionalUndoStorage: ...@@ -281,14 +287,10 @@ class TransactionalUndoStorage:
tid = info[0]['id'] tid = info[0]['id']
tid1 = info[1]['id'] tid1 = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid, tid1)
tid, oids = self._storage.undo(tid, t)
tid, oids1 = self._storage.undo(tid1, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
# We get the finalization stuff called an extra time: # We get the finalization stuff called an extra time:
eq(len(oids), 2) eq(len(oids), 4)
eq(len(oids1), 2)
unless(oid1 in oids) unless(oid1 in oids)
unless(oid2 in oids) unless(oid2 in oids)
data, revid1 = self._storage.load(oid1, '') data, revid1 = self._storage.load(oid1, '')
...@@ -355,9 +357,7 @@ class TransactionalUndoStorage: ...@@ -355,9 +357,7 @@ class TransactionalUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
tid, oids = self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
eq(len(oids), 1) eq(len(oids), 1)
self.failUnless(oid1 in oids) self.failUnless(oid1 in oids)
...@@ -368,7 +368,6 @@ class TransactionalUndoStorage: ...@@ -368,7 +368,6 @@ class TransactionalUndoStorage:
eq(zodb_unpickle(data), MinPO(54)) eq(zodb_unpickle(data), MinPO(54))
self._iterate() self._iterate()
def checkNotUndoable(self): def checkNotUndoable(self):
eq = self.assertEqual eq = self.assertEqual
# Set things up so we've got a transaction that can't be undone # Set things up so we've got a transaction that can't be undone
...@@ -380,10 +379,7 @@ class TransactionalUndoStorage: ...@@ -380,10 +379,7 @@ class TransactionalUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self.assertRaises(POSException.UndoError, self._begin_undos_vote, t, tid)
self.assertRaises(POSException.UndoError,
self._storage.undo,
tid, t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
# Now have more fun: object1 and object2 are in the same transaction, # Now have more fun: object1 and object2 are in the same transaction,
# which we'll try to undo to, but one of them has since modified in # which we'll try to undo to, but one of them has since modified in
...@@ -419,10 +415,7 @@ class TransactionalUndoStorage: ...@@ -419,10 +415,7 @@ class TransactionalUndoStorage:
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self.assertRaises(POSException.UndoError, self._begin_undos_vote, t, tid)
self.assertRaises(POSException.UndoError,
self._storage.undo,
tid, t)
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
self._iterate() self._iterate()
...@@ -439,7 +432,7 @@ class TransactionalUndoStorage: ...@@ -439,7 +432,7 @@ class TransactionalUndoStorage:
# So, basically, this makes sure that undo info doesn't depend # So, basically, this makes sure that undo info doesn't depend
# on file positions. We change the file positions in an undo # on file positions. We change the file positions in an undo
# record by packing. # record by packing.
# Add a few object revisions # Add a few object revisions
oid = '\0'*8 oid = '\0'*8
revid0 = self._dostore(oid, data=MinPO(50)) revid0 = self._dostore(oid, data=MinPO(50))
...@@ -462,9 +455,7 @@ class TransactionalUndoStorage: ...@@ -462,9 +455,7 @@ class TransactionalUndoStorage:
self.assertEqual(len(info2), 2) self.assertEqual(len(info2), 2)
# And now attempt to undo the last transaction # And now attempt to undo the last transaction
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) oids = self._begin_undos_vote(t, tid)
tid, oids = self._storage.undo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
self.assertEqual(len(oids), 1) self.assertEqual(len(oids), 1)
self.assertEqual(oids[0], oid) self.assertEqual(oids[0], oid)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment