Commit b7e88bab authored by Jeremy Hylton

Merge MVCC branch to the HEAD.

parent 3e29b5b6
try:
from Interface import Base
except ImportError:
class Base:
# a dummy interface for use when Zope's is unavailable
pass
class ICache(Base):
"""ZEO client cache.
__init__(storage, size, client, var)
All arguments optional.
storage -- name of storage
size -- max size of cache in bytes
client -- a string; if specified, cache is persistent.
var -- var directory to store cache files in
"""
def open():
"""Returns a sequence of object info tuples.
An object info tuple is a pair containing an object id and a
pair of serialnos, a non-version serialno and a version serialno:
oid, (serial, ver_serial)
This method builds an index of the cache and returns a
sequence used for cache validation.
"""
def close():
"""Closes the cache."""
def verify(func):
"""Call func on every object in cache.
func is called with three arguments
func(oid, serial, ver_serial)
"""
def invalidate(oid, version):
"""Remove object from cache."""
def load(oid, version):
"""Load object from cache.
Return None if object not in cache.
Return data, serialno if object is in cache.
"""
def store(oid, p, s, version, pv, sv):
"""Store a new object in the cache."""
def update(oid, serial, version, data):
"""Update an object already in the cache.
XXX This method is called to update objects that were modified by
a transaction. It's likely that it is already in the cache,
and it may be possible for the implementation to operate more
efficiently.
"""
def modifiedInVersion(oid):
"""Return the version an object is modified in.
'' signifies the trunk.
Returns None if the object is not in the cache.
"""
def checkSize(size):
"""Check if adding size bytes would exceed cache limit.
This method is often called just before store or update. The
size is a hint about the amount of data that is about to be
stored. The cache may want to evict some data to make space.
"""
@@ -13,6 +13,18 @@
##############################################################################
"""RPC stubs for interface exported by StorageServer."""
##
# ZEO storage server.
# <p>
# Remote method calls can be synchronous or asynchronous. If the call
# is synchronous, the client thread blocks until the call returns. A
# single client can only have one synchronous request outstanding. If
# several threads share a single client, threads other than the caller
# will block only if they attempt to make another synchronous call.
# An asynchronous call does not cause the client thread to block. An
# exception raised by an asynchronous method is logged on the server,
# but is not returned to the client.
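A hedged illustration of that calling convention; FakeRPC is invented for the example and only the call/callAsync names match the stubs below:

import threading

class FakeRPC:
    # Illustrative only: models one outstanding synchronous request
    # per connection, plus fire-and-forget asynchronous calls.
    def __init__(self):
        self._lock = threading.Lock()
    def call(self, name, *args):
        # synchronous: the caller blocks until the reply arrives; a
        # second thread calling call() blocks here until the first is done
        self._lock.acquire()
        try:
            return ('reply', name, args)
        finally:
            self._lock.release()
    def callAsync(self, name, *args):
        # asynchronous: returns immediately; an exception raised by the
        # remote method is logged on the server, never raised here
        pass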
class StorageServer:
"""An RPC stub class for the interface exported by ClientStorage.
@@ -43,47 +55,174 @@ class StorageServer:
def extensionMethod(self, name):
return ExtensionMethodWrapper(self.rpc, name).call
##
# Register current connection with a storage and a mode.
# In effect, it is like an open call.
# @param storage_name a string naming the storage. This argument
# is primarily for backwards compatibility with servers
# that supported multiple storages.
# @param read_only boolean
# @exception ValueError unknown storage_name or already registered
# @exception ReadOnlyError storage is read-only and a read-write
# connection was requested
def register(self, storage_name, read_only):
self.rpc.call('register', storage_name, read_only)
##
# Return dictionary of meta-data about the storage.
# @defreturn dict
def get_info(self):
return self.rpc.call('get_info')
##
# Check whether the server requires authentication. Returns
# the name of the protocol.
# @defreturn string
def getAuthProtocol(self):
return self.rpc.call('getAuthProtocol')
##
# Return id of the last committed transaction
# @defreturn string
def lastTransaction(self):
# Not in protocol version 2.0.0; see __init__()
return self.rpc.call('lastTransaction')
##
# Return invalidations for all transactions after tid.
# @param tid transaction id
# @defreturn 2-tuple, (tid, list)
# @return tuple containing the last committed transaction
# and a list of oids that were invalidated. Returns
# None and an empty list if the server does not have
# the list of oids available.
def getInvalidations(self, tid):
# Not in protocol version 2.0.0; see __init__()
return self.rpc.call('getInvalidations', tid)
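A hedged usage sketch for the call above; resync, stub, cache, and cache_tid are hypothetical names for a reconnect helper, a connected stub, a client cache, and the client's last known transaction id:

def resync(stub, cache, cache_tid):
    # Illustrative only: bring a client cache up to date after a
    # reconnect, falling back to full verification when the server
    # cannot supply the invalidation list.
    tid, oids = stub.getInvalidations(cache_tid)
    if tid is None:
        return 0  # caller must run the zeoVerify() protocol instead
    for oid in oids:
        cache.invalidate(oid, '')
    return 1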
##
# Check whether serial numbers s and sv are current for oid.
# If one or both of the serial numbers are not current, the
# server will make an asynchronous invalidateVerify() call.
# @param oid object id
# @param s serial number on non-version data
# @param sv serial number of version data or None
# @defreturn async
def zeoVerify(self, oid, s, sv):
self.rpc.callAsync('zeoVerify', oid, s, sv)
##
# Check whether current serial number is valid for oid and version.
# If the serial number is not current, the server will make an
# asynchronous invalidateVerify() call.
# @param oid object id
# @param version name of version for oid
# @param serial client's current serial number
# @defreturn async
def verify(self, oid, version, serial):
self.rpc.callAsync('verify', oid, version, serial)
##
# Signal to the server that cache verification is done.
# @defreturn async
def endZeoVerify(self):
self.rpc.callAsync('endZeoVerify')
##
# Generate a new set of oids.
# @param n number of new oids to return
# @defreturn list
# @return list of oids
def new_oids(self, n=None):
if n is None:
return self.rpc.call('new_oids')
else:
return self.rpc.call('new_oids', n)
##
# Pack the storage.
# @param t pack time
# @param wait optional, boolean. If true, the call will not
# return until the pack is complete.
def pack(self, t, wait=None):
if wait is None:
self.rpc.call('pack', t)
else:
self.rpc.call('pack', t, wait)
##
# Return current data for oid. Version data is returned if
# present.
# @param oid object id
# @defreturn 5-tuple
# @return 5-tuple, current non-version data, serial number,
# version name, version data, version data serial number
# @exception KeyError if oid is not found
def zeoLoad(self, oid):
return self.rpc.call('zeoLoad', oid)
##
# Return current data for oid along with the tid of the
# transaction that wrote the data.
# @param oid object id
# @param version string, name of version
# @defreturn 4-tuple
# @return data, serial number, transaction id, version,
# where version is the name of the version the data came
# from or "" for non-version data
# @exception KeyError if oid is not found
def loadEx(self, oid, version):
return self.rpc.call("loadEx", oid, version)
##
# Return non-current data along with transaction ids that identify
# the lifetime of the specific revision.
# @param oid object id
# @param tid a transaction id that provides an upper bound on
# the lifetime of the revision. That is, loadBefore
# returns the revision that was current before tid committed.
# @defreturn 4-tuple
# @return data, serial number, start transaction id, end transaction id
def loadBefore(self, oid, tid):
return self.rpc.call("loadBefore", oid, tid)
##
# Store a new revision of oid.
# @param oid object id
# @param serial serial number that this transaction read
# @param data new data record for oid
# @param version name of version or ""
# @param id id of current transaction
# @defreturn async
def storea(self, oid, serial, data, version, id):
self.rpc.callAsync('storea', oid, serial, data, version, id)
##
# Start two-phase commit for a transaction
# @param id id used by client to identify current transaction. The
# only purpose of this argument is to distinguish among multiple
# threads using a single ClientStorage.
# @param user name of user committing transaction (can be "")
# @param description string containing transaction metadata (can be "")
# @param ext dictionary of extended metadata (?)
# @param tid optional explicit tid to pass to underlying storage
# @param status optional status character, e.g. "p" for pack
# @defreturn async
def tpc_begin(self, id, user, descr, ext, tid, status):
return self.rpc.call('tpc_begin', id, user, descr, ext, tid, status)
...
@@ -235,6 +235,14 @@ class ZEOStorage:
def getExtensionMethods(self):
return self._extensions
def loadEx(self, oid, version):
self.stats.loads += 1
return self.storage.loadEx(oid, version)
def loadBefore(self, oid, tid):
self.stats.loads += 1
return self.storage.loadBefore(oid, tid)
def zeoLoad(self, oid):
self.stats.loads += 1
v = self.storage.modifiedInVersion(oid)
@@ -260,12 +268,26 @@ class ZEOStorage:
% (len(invlist), u64(invtid)))
return invtid, invlist
def verify(self, oid, version, tid):
try:
t = self.storage.getTid(oid)
except KeyError:
self.client.invalidateVerify((oid, ""))
else:
if tid != t:
# This will invalidate non-version data when the
# client only has invalid version data. Since this is
# an uncommon case, we avoid the cost of checking
# whether the serial number matches the current
# non-version data.
self.client.invalidateVerify((oid, version))
def zeoVerify(self, oid, s, sv):
if not self.verifying:
self.verifying = 1
self.stats.verifying_clients += 1
try:
-os = self.storage.getSerial(oid)
+os = self.storage.getTid(oid)
except KeyError:
self.client.invalidateVerify((oid, ''))
# XXX It's not clear what we should do now. The KeyError
@@ -344,7 +366,7 @@ class ZEOStorage:
def undoLog(self, first, last):
return run_in_thread(self.storage.undoLog, first, last)
-def tpc_begin(self, id, user, description, ext, tid, status):
+def tpc_begin(self, id, user, description, ext, tid=None, status=" "):
if self.read_only:
raise ReadOnlyError()
if self.transaction is not None:
@@ -521,25 +543,25 @@ class ZEOStorage:
return self.storage.tpc_vote(self.transaction)
def _abortVersion(self, src):
-oids = self.storage.abortVersion(src, self.transaction)
+tid, oids = self.storage.abortVersion(src, self.transaction)
inv = [(oid, src) for oid in oids]
self.invalidated.extend(inv)
-return oids
+return tid, oids
def _commitVersion(self, src, dest):
-oids = self.storage.commitVersion(src, dest, self.transaction)
+tid, oids = self.storage.commitVersion(src, dest, self.transaction)
inv = [(oid, dest) for oid in oids]
self.invalidated.extend(inv)
if dest:
inv = [(oid, src) for oid in oids]
self.invalidated.extend(inv)
-return oids
+return tid, oids
def _transactionalUndo(self, trans_id):
-oids = self.storage.transactionalUndo(trans_id, self.transaction)
+tid, oids = self.storage.transactionalUndo(trans_id, self.transaction)
inv = [(oid, None) for oid in oids]
self.invalidated.extend(inv)
-return oids
+return tid, oids
# When a delayed transaction is restarted, the dance is
# complicated. The restart occurs when one ZEOStorage instance
@@ -854,6 +876,9 @@ class StorageServer:
log("tid to old for invq %s < %s" % (u64(tid), u64(earliest_tid)))
return None, []
# XXX this is wrong! must check against tid or we invalidate
# too much.
oids = {}
for tid, L in self.invq:
for key in L:
...
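A hedged sketch of the fix the XXX comment above asks for, assuming invq is a newest-first list of (tid, keys) pairs as the surrounding code suggests; the helper name is invented:

def invalidations_after(invq, tid):
    # Illustrative only: collect keys invalidated by transactions
    # committed strictly after the client's tid.
    oids = {}
    for itid, keys in invq:
        if itid <= tid:
            break  # older entries are already known to the client
        for key in keys:
            oids[key] = 1
    return oids.keys()

# e.g. invalidations_after([(3, ['b']), (2, ['a'])], 2) -> ['b']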
@@ -128,15 +128,21 @@ def main():
# Read file, gathering statistics, and printing each record if verbose
rt0 = time.time()
# bycode -- map code to count of occurrences
bycode = {}
# records -- number of records
records = 0
# versions -- number of records with versions
versions = 0
t0 = te = None
# datarecords -- number of records with dlen set
datarecords = 0
datasize = 0L
-file0 = file1 = 0
+# oids -- maps oid to number of times it was loaded
oids = {}
# bysize -- maps data size to number of loads
bysize = {}
# bysizew -- maps data size to number of writes
bysizew = {}
total_loads = 0
byinterval = {}
@@ -157,12 +163,12 @@ def main():
if not quiet:
print "Skipping 8 bytes at offset", offset-8
continue
-r = f_read(10)
+r = f_read(18)
if len(r) < 10:
break
offset += 10
records += 1
-oidlen, serial = struct_unpack(">H8s", r)
+oidlen, start_tid, end_tid = struct_unpack(">H8s8s", r)
oid = f_read(oidlen)
if len(oid) != oidlen:
break
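The record header read above grows from 10 to 18 bytes (note the length and offset checks just above still say 10, which looks like a leftover from the old format). A hedged sketch of the new layout implied by the unpack format, with fabricated example bytes:

import struct

# >H8s8s: 2-byte oid length, 8-byte start tid, 8-byte end tid
header = struct.pack(">H8s8s", 8, "A" * 8, "B" * 8)
assert len(header) == 18
oidlen, start_tid, end_tid = struct.unpack(">H8s8s", header)
assert oidlen == 8 and start_tid == "A" * 8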
@@ -187,11 +193,6 @@ def main():
if code & 0x80:
version = 'V'
versions += 1
-current = code & 1
-if current:
-file1 += 1
-else:
-file0 += 1
code = code & 0x7e
bycode[code] = bycode.get(code, 0) + 1
byinterval[code] = byinterval.get(code, 0) + 1
@@ -199,22 +200,23 @@ def main():
if code & 0x70 == 0x20: # All loads
bysize[dlen] = d = bysize.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1
-elif code == 0x3A: # Update
+elif code & 0x70 == 0x50: # All stores
bysizew[dlen] = d = bysizew.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1
if verbose:
-print "%s %d %02x %s %016x %1s %s" % (
+print "%s %d %02x %s %016x %016x %1s %s" % (
time.ctime(ts)[4:-5],
current,
code,
oid_repr(oid),
-U64(serial),
+U64(start_tid),
+U64(end_tid),
version,
dlen and str(dlen) or "")
if code & 0x70 == 0x20:
oids[oid] = oids.get(oid, 0) + 1
total_loads += 1
-if code in (0x00, 0x70):
+if code == 0x00:
if not quiet:
dumpbyinterval(byinterval, h0, he)
byinterval = {}
@@ -222,10 +224,7 @@ def main():
h0 = he = ts
if not quiet:
print time.ctime(ts)[4:-5],
-if code == 0x00:
print '='*20, "Restart", '='*20
-else:
-print '-'*20, "Flip->%d" % current, '-'*20
except KeyboardInterrupt:
print "\nInterrupted. Stats so far:\n"
@@ -248,8 +247,6 @@ def main():
print "First time: %s" % time.ctime(t0)
print "Last time: %s" % time.ctime(te)
print "Duration: %s seconds" % addcommas(te-t0)
-print "File stats: %s in file 0; %s in file 1" % (
-addcommas(file0), addcommas(file1))
print "Data recs: %s (%.1f%%), average size %.1f KB" % (
addcommas(datarecords),
100.0 * datarecords / records,
@@ -314,7 +311,7 @@ def dumpbyinterval(byinterval, h0, he):
if code & 0x70 == 0x20:
n = byinterval[code]
loads += n
-if code in (0x2A, 0x2C, 0x2E):
+if code in (0x22, 0x26):
hits += n
if not loads:
return
@@ -333,7 +330,7 @@ def hitrate(bycode):
if code & 0x70 == 0x20:
n = bycode[code]
loads += n
-if code in (0x2A, 0x2C, 0x2E):
+if code in (0x22, 0x26):
hits += n
if loads:
return 100.0 * hits / loads
@@ -376,31 +373,18 @@ explain = {
0x00: "_setup_trace (initialization)",
0x10: "invalidate (miss)",
-0x1A: "invalidate (hit, version, writing 'n')",
+0x1A: "invalidate (hit, version)",
-0x1C: "invalidate (hit, writing 'i')",
+0x1C: "invalidate (hit, saving non-current)",
0x20: "load (miss)",
-0x22: "load (miss, version, status 'n')",
+0x22: "load (hit)",
-0x24: "load (miss, deleting index entry)",
+0x24: "load (non-current, miss)",
-0x26: "load (miss, no non-version data)",
+0x26: "load (non-current, hit)",
-0x28: "load (miss, version mismatch, no non-version data)",
-0x2A: "load (hit, returning non-version data)",
-0x2C: "load (hit, version mismatch, returning non-version data)",
-0x2E: "load (hit, returning version data)",
-0x3A: "update",
-0x40: "modifiedInVersion (miss)",
-0x4A: "modifiedInVersion (hit, return None, status 'n')",
-0x4C: "modifiedInVersion (hit, return '')",
-0x4E: "modifiedInVersion (hit, return version)",
-0x5A: "store (non-version data present)",
-0x5C: "store (only version data present)",
-0x6A: "_copytocurrent",
+0x50: "store (version)",
+0x52: "store (current, non-version)",
+0x54: "store (non-current)",
-0x70: "checkSize (cache flip)",
}
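A hedged helper showing how these code values decompose; the masks come from the surrounding script (0x80 version flag, 0x70 operation group, 0x7e table key), while the function itself is invented:

def describe(raw_code):
    # Illustrative only: split a raw trace code into its parts.
    version = (raw_code & 0x80) != 0  # version-related event
    group = raw_code & 0x70           # 0x20 = loads, 0x50 = stores
    text = explain.get(raw_code & 0x7e, "unknown")
    return version, group, text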
if __name__ == "__main__":
...
@@ -42,7 +42,7 @@ class TransUndoStorageWithCache:
t.note('undo1')
self._storage.tpc_begin(t)
-oids = self._storage.transactionalUndo(tid, t)
+tid, oids = self._storage.transactionalUndo(tid, t)
# Make sure this doesn't load invalid data into the cache
self._storage.load(oid, '')
...
@@ -71,7 +71,7 @@ class WorkerThread(TestThread):
# self.storage.tpc_vote(self.trans)
rpc = self.storage._server.rpc
-msgid = rpc._deferred_call('vote', self.storage._serial)
+msgid = rpc._deferred_call('vote', id(self.trans))
self.ready.set()
rpc._deferred_wait(msgid)
self.storage._check_serials()
@@ -103,6 +103,51 @@ class CommitLockTests:
self._storage.store(oid, ZERO, zodb_pickle(MinPO(1)), '', txn)
return oid, txn
def _begin_threads(self):
# Start a second transaction on a different connection without
# blocking the test thread. Returns only after each thread has
# set its ready event.
self._storages = []
self._threads = []
for i in range(self.NUM_CLIENTS):
storage = self._duplicate_client()
txn = Transaction()
tid = self._get_timestamp()
t = WorkerThread(self, storage, txn)
self._threads.append(t)
t.start()
t.ready.wait()
# Close one of the connections abnormally to test server response
if i == 0:
storage.close()
else:
self._storages.append((storage, txn))
def _finish_threads(self):
for t in self._threads:
t.cleanup()
def _duplicate_client(self):
"Open another ClientStorage to the same server."
# XXX argh it's hard to find the actual address
# The rpc mgr addr attribute is a list. Each element in the
# list is a socket domain (AF_INET, AF_UNIX, etc.) and an
# address.
addr = self._storage._addr
new = ZEO.ClientStorage.ClientStorage(addr, wait=1)
new.registerDB(DummyDB(), None)
return new
def _get_timestamp(self):
t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,))
return `t`
class CommitLockVoteTests(CommitLockTests):
def checkCommitLockVoteFinish(self):
oid, txn = self._start_txn()
self._storage.tpc_vote(txn)
@@ -141,15 +186,16 @@ class CommitLockTests:
self._finish_threads()
self._cleanup()
class CommitLockUndoTests(CommitLockTests):
def _get_trans_id(self):
self._dostore()
L = self._storage.undoInfo()
return L[0]['id']
-def _begin_undo(self, trans_id):
+def _begin_undo(self, trans_id, txn):
rpc = self._storage._server.rpc
-return rpc._deferred_call('transactionalUndo', trans_id,
-self._storage._serial)
+return rpc._deferred_call('transactionalUndo', trans_id, id(txn))
def _finish_undo(self, msgid):
return self._storage._server.rpc._deferred_wait(msgid)
@@ -157,7 +203,7 @@ class CommitLockTests:
def checkCommitLockUndoFinish(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
-msgid = self._begin_undo(trans_id)
+msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
@@ -174,7 +220,7 @@ class CommitLockTests:
def checkCommitLockUndoAbort(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
-msgid = self._begin_undo(trans_id)
+msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
@@ -190,7 +236,7 @@ class CommitLockTests:
def checkCommitLockUndoClose(self):
trans_id = self._get_trans_id()
oid, txn = self._start_txn()
-msgid = self._begin_undo(trans_id)
+msgid = self._begin_undo(trans_id, txn)
self._begin_threads()
@@ -201,46 +247,3 @@ class CommitLockTests:
self._finish_threads()
self._cleanup()
-def _begin_threads(self):
-# Start a second transaction on a different connection without
-# blocking the test thread. Returns only after each thread has
-# set it's ready event.
-self._storages = []
-self._threads = []
-for i in range(self.NUM_CLIENTS):
-storage = self._duplicate_client()
-txn = Transaction()
-tid = self._get_timestamp()
-t = WorkerThread(self, storage, txn)
-self._threads.append(t)
-t.start()
-t.ready.wait()
-# Close on the connections abnormally to test server response
-if i == 0:
-storage.close()
-else:
-self._storages.append((storage, txn))
-def _finish_threads(self):
-for t in self._threads:
-t.cleanup()
-def _duplicate_client(self):
-"Open another ClientStorage to the same server."
-# XXX argh it's hard to find the actual address
-# The rpc mgr addr attribute is a list. Each element in the
-# list is a socket domain (AF_INET, AF_UNIX, etc.) and an
-# address.
-addr = self._storage._addr
-new = ZEO.ClientStorage.ClientStorage(addr, wait=1)
-new.registerDB(DummyDB(), None)
-return new
-def _get_timestamp(self):
-t = time.time()
-t = TimeStamp(*time.gmtime(t)[:5]+(t%60,))
-return `t`
@@ -109,7 +109,7 @@ class CommonSetupTearDown(StorageTestBase):
os.waitpid(pid, 0)
for c in self.caches:
for i in 0, 1:
-path = "c1-%s-%d.zec" % (c, i)
+path = "%s-%s.zec" % (c, "1")
# On Windows before 2.3, we don't have a way to wait for
# the spawned server(s) to close, and they inherited
# file descriptors for our open files. So long as those
@@ -584,6 +584,9 @@ class InvqTests(CommonSetupTearDown):
revid = self._dostore(oid)
revid = self._dostore(oid, revid)
# sync() is needed to prevent invalidation for oid from arriving
# in the middle of the load() call.
perstorage.sync()
perstorage.load(oid, '')
perstorage.close()
@@ -853,7 +856,7 @@ class TimeoutTests(CommonSetupTearDown):
unless = self.failUnless
self._storage = storage = self.openClientStorage()
# Assert that the zeo cache is empty
-unless(not storage._cache._index)
+unless(not list(storage._cache.contents()))
# Create the object
oid = storage.new_oid()
obj = MinPO(7)
@@ -872,7 +875,7 @@ class TimeoutTests(CommonSetupTearDown):
# We expect finish to fail
raises(ClientDisconnected, storage.tpc_finish, t)
# The cache should still be empty
-unless(not storage._cache._index)
+unless(not list(storage._cache.contents()))
# Load should fail since the object should not be in either the cache
# or the server.
raises(KeyError, storage.load, oid, '')
@@ -883,7 +886,7 @@ class TimeoutTests(CommonSetupTearDown):
unless = self.failUnless
self._storage = storage = self.openClientStorage()
# Assert that the zeo cache is empty
-unless(not storage._cache._index)
+unless(not list(storage._cache.contents()))
# Create the object
oid = storage.new_oid()
obj = MinPO(7)
...
@@ -39,7 +39,23 @@ from ZODB.POSException \
# thought they added (i.e., the keys for which get_transaction().commit()
# did not raise any exception).
-class StressThread(TestThread):
+class FailableThread(TestThread):
# mixin class
# subclass must provide
# - self.stop attribute (an event)
# - self._testrun() method
def testrun(self):
try:
self._testrun()
except:
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
raise
class StressThread(FailableThread):
# Append integers startnum, startnum + step, startnum + 2*step, ...
# to 'tree' until Event stop is set. If sleep is given, sleep
@@ -57,7 +73,7 @@ class StressThread(TestThread):
self.added_keys = []
self.commitdict = commitdict
-def testrun(self):
+def _testrun(self):
cn = self.db.open()
while not self.stop.isSet():
try:
@@ -87,7 +103,7 @@ class StressThread(TestThread):
key += self.step
cn.close()
-class LargeUpdatesThread(TestThread):
+class LargeUpdatesThread(FailableThread):
# A thread that performs a lot of updates. It attempts to modify
# more than 25 objects so that it can test code that runs vote
@@ -106,6 +122,15 @@ class LargeUpdatesThread(TestThread):
self.commitdict = commitdict
def testrun(self):
try:
self._testrun()
except:
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
raise
def _testrun(self):
cn = self.db.open()
while not self.stop.isSet():
try:
@@ -162,7 +187,7 @@ class LargeUpdatesThread(TestThread):
self.added_keys = keys_added.keys()
cn.close()
-class VersionStressThread(TestThread):
+class VersionStressThread(FailableThread):
def __init__(self, testcase, db, stop, threadnum, commitdict, startnum,
step=2, sleep=None):
@@ -177,6 +202,15 @@ class VersionStressThread(TestThread):
self.commitdict = commitdict
def testrun(self):
try:
self._testrun()
except:
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
raise
def _testrun(self):
commit = 0
key = self.startnum
while not self.stop.isSet():
@@ -302,7 +336,10 @@ class InvalidationTests:
delay = self.MINTIME
start = time.time()
while time.time() - start <= self.MAXTIME:
-time.sleep(delay)
+stop.wait(delay)
if stop.isSet():
# Some thread failed. Stop right now.
break
delay = 2.0
if len(commitdict) >= len(threads):
break
@@ -406,6 +443,7 @@ class InvalidationTests:
t1 = VersionStressThread(self, db1, stop, 1, cd, 1, 3)
t2 = VersionStressThread(self, db2, stop, 2, cd, 2, 3, 0.01)
t3 = VersionStressThread(self, db2, stop, 3, cd, 3, 3, 0.01)
## t1 = VersionStressThread(self, db2, stop, 3, cd, 1, 3, 0.01)
self.go(stop, cd, t1, t2, t3)
cn.sync()
...
@@ -28,7 +28,6 @@ from ZEO.StorageServer import StorageServer
from ZEO.tests.ConnectionTests import CommonSetupTearDown
from ZODB.FileStorage import FileStorage
-from ZODB.tests.StorageTestBase import removefs
class AuthTest(CommonSetupTearDown):
__super_getServerConfig = CommonSetupTearDown.getServerConfig
...
@@ -101,20 +101,12 @@ class GenericTests(
StorageTestBase.StorageTestBase,
# ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage,
-VersionStorage.VersionStorage,
-TransactionalUndoStorage.TransactionalUndoStorage,
-TransactionalUndoVersionStorage.TransactionalUndoVersionStorage,
PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
-ConflictResolution.ConflictResolvingStorage,
-ConflictResolution.ConflictResolvingTransUndoStorage,
-RevisionStorage.RevisionStorage,
MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
-Cache.StorageWithCache,
-Cache.TransUndoStorageWithCache,
-CommitLockTests.CommitLockTests,
+CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests
@@ -167,8 +159,22 @@ class GenericTests(
key = '%s:%s' % (self._storage._storage, self._storage._server_addr)
self.assertEqual(self._storage.sortKey(), key)
class FullGenericTests(
GenericTests,
Cache.StorageWithCache,
Cache.TransUndoStorageWithCache,
CommitLockTests.CommitLockUndoTests,
ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage,
PackableStorage.PackableUndoStorage,
RevisionStorage.RevisionStorage,
TransactionalUndoStorage.TransactionalUndoStorage,
TransactionalUndoVersionStorage.TransactionalUndoVersionStorage,
VersionStorage.VersionStorage,
):
"""Extend GenericTests with tests that MappingStorage can't pass."""
-class FileStorageTests(GenericTests):
+class FileStorageTests(FullGenericTests):
"""Test ZEO backed by a FileStorage."""
level = 2 level = 2
@@ -180,7 +186,7 @@ class FileStorageTests(GenericTests):
</filestorage>
""" % filename
-class BDBTests(FileStorageTests):
+class BDBTests(FullGenericTests):
"""ZEO backed by a Berkeley full storage."""
level = 2
@@ -192,67 +198,14 @@ class BDBTests(FileStorageTests):
</fullstorage>
""" % self._envdir
-class MappingStorageTests(FileStorageTests):
+class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage."""
def getConfig(self):
return """<mappingstorage 1/>"""
-# Tests which MappingStorage can't possibly pass, because it doesn't
-# support versions or undo.
+# XXX There are still a bunch of tests that fail. Are there
+# still test classes in GenericTests that shouldn't be there?
-def checkVersions(self): pass
-def checkVersionedStoreAndLoad(self): pass
-def checkVersionedLoadErrors(self): pass
-def checkVersionLock(self): pass
-def checkVersionEmpty(self): pass
-def checkUndoUnresolvable(self): pass
-def checkUndoInvalidation(self): pass
-def checkUndoInVersion(self): pass
-def checkUndoCreationBranch2(self): pass
-def checkUndoCreationBranch1(self): pass
-def checkUndoConflictResolution(self): pass
-def checkUndoCommitVersion(self): pass
-def checkUndoAbortVersion(self): pass
-def checkPackUndoLog(self): pass
-def checkUndoLogMetadata(self): pass
-def checkTwoObjectUndoAtOnce(self): pass
-def checkTwoObjectUndoAgain(self): pass
-def checkTwoObjectUndo(self): pass
-def checkTransactionalUndoAfterPackWithObjectUnlinkFromRoot(self): pass
-def checkTransactionalUndoAfterPack(self): pass
-def checkSimpleTransactionalUndo(self): pass
-def checkReadMethods(self): pass
-def checkPackAfterUndoDeletion(self): pass
-def checkPackAfterUndoManyTimes(self): pass
-def checkPackVersions(self): pass
-def checkPackUnlinkedFromRoot(self): pass
-def checkPackOnlyOneObject(self): pass
-def checkPackJustOldRevisions(self): pass
-def checkPackEmptyStorage(self): pass
-def checkPackAllRevisions(self): pass
-def checkPackVersionsInPast(self): pass
-def checkPackVersionReachable(self): pass
-def checkNotUndoable(self): pass
-def checkNewSerialOnCommitVersionToVersion(self): pass
-def checkModifyAfterAbortVersion(self): pass
-def checkLoadSerial(self): pass
-def checkCreateObjectInVersionWithAbort(self): pass
-def checkCommitVersionSerialno(self): pass
-def checkCommitVersionInvalidation(self): pass
-def checkCommitToOtherVersion(self): pass
-def checkCommitToNonVersion(self): pass
-def checkCommitLockUndoFinish(self): pass
-def checkCommitLockUndoClose(self): pass
-def checkCommitLockUndoAbort(self): pass
-def checkCommitEmptyVersionInvalidation(self): pass
-def checkCreationUndoneGetSerial(self): pass
-def checkAbortVersionSerialno(self): pass
-def checkAbortVersionInvalidation(self): pass
-def checkAbortVersionErrors(self): pass
-def checkAbortVersion(self): pass
-def checkAbortOneVersionCommitTheOther(self): pass
-def checkResolve(self): pass
-def check4ExtStorageThread(self): pass
test_classes = [FileStorageTests, MappingStorageTests]
...
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Basic unit tests for a multi-version client cache."""
import os
import tempfile
import unittest
import ZEO.cache
from ZODB.utils import p64
n1 = p64(1)
n2 = p64(2)
n3 = p64(3)
n4 = p64(4)
n5 = p64(5)
class CacheTests(unittest.TestCase):
def setUp(self):
self.cache = ZEO.cache.ClientCache()
self.cache.open()
def tearDown(self):
if self.cache.path:
os.remove(self.cache.path)
def testLastTid(self):
self.assertEqual(self.cache.getLastTid(), None)
self.cache.setLastTid(n2)
self.assertEqual(self.cache.getLastTid(), n2)
self.cache.invalidate(None, "", n1)
self.assertEqual(self.cache.getLastTid(), n2)
self.cache.invalidate(None, "", n3)
self.assertEqual(self.cache.getLastTid(), n3)
self.assertRaises(ValueError, self.cache.setLastTid, n2)
def testLoad(self):
data1 = "data for n1"
self.assertEqual(self.cache.load(n1, ""), None)
self.assertEqual(self.cache.load(n1, "version"), None)
self.cache.store(n1, "", n3, None, data1)
self.assertEqual(self.cache.load(n1, ""), (data1, n3, ""))
# The cache doesn't know whether version exists, because it
# only has non-version data.
self.assertEqual(self.cache.load(n1, "version"), None)
self.assertEqual(self.cache.modifiedInVersion(n1), None)
def testInvalidate(self):
data1 = "data for n1"
self.cache.store(n1, "", n3, None, data1)
self.cache.invalidate(n1, "", n4)
self.cache.invalidate(n2, "", n2)
self.assertEqual(self.cache.load(n1, ""), None)
self.assertEqual(self.cache.loadBefore(n1, n4),
(data1, n3, n4))
def testVersion(self):
data1 = "data for n1"
data1v = "data for n1 in version"
self.cache.store(n1, "version", n3, None, data1v)
self.assertEqual(self.cache.load(n1, ""), None)
self.assertEqual(self.cache.load(n1, "version"),
(data1v, n3, "version"))
self.assertEqual(self.cache.load(n1, "random"), None)
self.assertEqual(self.cache.modifiedInVersion(n1), "version")
self.cache.invalidate(n1, "version", n4)
self.assertEqual(self.cache.load(n1, "version"), None)
def testNonCurrent(self):
data1 = "data for n1"
data2 = "data for n2"
self.cache.store(n1, "", n4, None, data1)
self.cache.store(n1, "", n2, n3, data2)
# can't say anything about state before n2
self.assertEqual(self.cache.loadBefore(n1, n2), None)
# n3 is the upper bound of non-current record n2
self.assertEqual(self.cache.loadBefore(n1, n3), (data2, n2, n3))
# no data for between n2 and n3
self.assertEqual(self.cache.loadBefore(n1, n4), None)
self.cache.invalidate(n1, "", n5)
self.assertEqual(self.cache.loadBefore(n1, n5), (data1, n4, n5))
self.assertEqual(self.cache.loadBefore(n2, n4), None)
def testException(self):
self.assertRaises(ValueError,
self.cache.store,
n1, "version", n2, n3, "data")
self.cache.store(n1, "", n2, None, "data")
self.assertRaises(ValueError,
self.cache.store,
n1, "", n3, None, "data")
def testEviction(self):
# Manually override the current maxsize
maxsize = self.cache.size = self.cache.fc.maxsize = 3395 # 1245
self.cache.fc = ZEO.cache.FileCache(3395, None, self.cache)
# Trivial test of eviction code. Doesn't test non-current
# eviction.
data = ["z" * i for i in range(100)]
for i in range(50):
n = p64(i)
self.cache.store(n, "", n, None, data[i])
self.assertEquals(len(self.cache), i + 1)
self.assert_(self.cache.fc.currentsize < maxsize)
# The cache now uses 1225 bytes. The next insert
# should delete some objects.
n = p64(50)
self.cache.store(n, "", n, None, data[51])
self.assert_(len(self.cache) < 51)
self.assert_(self.cache.fc.currentsize <= maxsize)
# XXX Need to make sure eviction of non-current data
# and of version data are handled correctly.
def testSerialization(self):
self.cache.store(n1, "", n2, None, "data for n1")
self.cache.store(n2, "version", n2, None, "version data for n2")
self.cache.store(n3, "", n3, n4, "non-current data for n3")
self.cache.store(n3, "", n4, n5, "more non-current data for n3")
path = tempfile.mktemp()
# Copy data from self.cache into path, reaching into the cache
# guts to make the copy.
dst = open(path, "wb+")
src = self.cache.fc.f
src.seek(0)
dst.write(src.read(self.cache.fc.maxsize))
dst.close()
copy = ZEO.cache.ClientCache(path)
copy.open()
# Verify that internals of both objects are the same.
# Could also test that external API produces the same results.
eq = self.assertEqual
eq(copy.tid, self.cache.tid)
eq(len(copy), len(self.cache))
eq(copy.version, self.cache.version)
eq(copy.current, self.cache.current)
eq(copy.noncurrent, self.cache.noncurrent)
def test_suite():
return unittest.makeSuite(CacheTests)
@@ -13,7 +13,7 @@
##############################################################################
"""Handy standard storage machinery
-$Id: BaseStorage.py,v 1.38 2003/12/23 14:37:13 jeremy Exp $
+$Id: BaseStorage.py,v 1.39 2003/12/24 16:02:00 jeremy Exp $
"""
import cPickle
import threading
@@ -32,7 +32,6 @@ from ZODB.utils import z64
class BaseStorage(UndoLogCompatible):
_transaction=None # Transaction that is being committed
-_serial=z64 # Transaction serial number
_tstatus=' ' # Transaction status, used for copying data
_is_read_only = 0
@@ -51,7 +50,7 @@ class BaseStorage(UndoLogCompatible):
t=time.time()
t=self._ts=apply(TimeStamp,(time.gmtime(t)[:5]+(t%60,)))
-self._serial=`t`
+self._tid = `t`
if base is None:
self._oid='\0\0\0\0\0\0\0\0'
else:
@@ -60,16 +59,19 @@ class BaseStorage(UndoLogCompatible):
def abortVersion(self, src, transaction):
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
-return []
+return self._tid, []
def commitVersion(self, src, dest, transaction):
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
-return []
+return self._tid, []
def close(self):
pass
def cleanup(self):
pass
def sortKey(self):
"""Return a string that can be used to sort storage instances.
@@ -85,7 +87,7 @@ class BaseStorage(UndoLogCompatible):
def getSize(self):
return len(self)*300 # WAG!
-def history(self, oid, version, length=1):
+def history(self, oid, version, length=1, filter=None):
pass
def modifiedInVersion(self, oid):
@@ -167,13 +169,13 @@ class BaseStorage(UndoLogCompatible):
now = time.time()
t = TimeStamp(*(time.gmtime(now)[:5] + (now % 60,)))
self._ts = t = t.laterThan(self._ts)
-self._serial = `t`
+self._tid = `t`
else:
self._ts = TimeStamp(tid)
-self._serial = tid
+self._tid = tid
self._tstatus = status
-self._begin(self._serial, user, desc, ext)
+self._begin(self._tid, user, desc, ext)
finally:
self._lock_release()
@@ -203,10 +205,11 @@ class BaseStorage(UndoLogCompatible):
return
try:
if f is not None:
-f()
+f(self._tid)
u, d, e = self._ude
-self._finish(self._serial, u, d, e)
+self._finish(self._tid, u, d, e)
self._clear_temp()
return self._tid
finally:
self._ude = None
self._transaction = None
@@ -250,6 +253,48 @@ class BaseStorage(UndoLogCompatible):
raise POSException.Unsupported, (
"Retrieval of historical revisions is not supported")
def loadBefore(self, oid, tid):
"""Return most recent revision of oid before tid committed."""
# XXX Is it okay for loadBefore() to return current data?
# There doesn't seem to be a good reason to forbid it, even
# though the typical use of this method will never find
# current data. But maybe we should call it loadByTid()?
n = 2
start_time = None
end_time = None
while start_time is None:
# The history() approach is a hack, because the dict
# returned by history() doesn't contain a tid. It
# contains a serialno, which is often the same, but isn't
# required to be. We'll pretend it is for now.
# A second problem is that history() doesn't say anything
# about the transaction status. If it falls before
# the pack time, we can't honor the MVCC request.
# Note: history() returns the most recent record first.
# XXX The filter argument to history() only appears to be
# supported by FileStorage. Perhaps it shouldn't be used.
L = self.history(oid, "", n, lambda d: not d["version"])
if not L:
return
for d in L:
if d["serial"] < tid:
start_time = d["serial"]
break
else:
end_time = d["serial"]
if len(L) < n:
break
n *= 2
if start_time is None:
return None
data = self.loadSerial(oid, start_time)
return data, start_time, end_time
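A hedged usage sketch for the method above; the helper name is invented and storage stands for any BaseStorage-derived instance:

def revision_interval(storage, oid, tid):
    # Illustrative only: report the validity interval loadBefore()
    # assigns to the revision of oid current before tid committed.
    r = storage.loadBefore(oid, tid)
    if r is None:
        return None  # nothing committed before tid (or data packed away)
    data, start_tid, end_tid = r
    # the revision is valid for transactions t with start_tid <= t < end_tid;
    # end_tid is None when the revision was still current
    return start_tid, end_tid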
def getExtensionMethods(self):
"""getExtensionMethods
@@ -314,7 +359,7 @@ class BaseStorage(UndoLogCompatible):
oid=r.oid
if verbose: print oid_repr(oid), r.version, len(r.data)
if restoring:
-self.restore(oid, r.serial, r.data, r.version,
+self.restore(oid, r.tid, r.data, r.version,
r.data_txn, transaction)
else:
pre=preget(oid, None)
...
@@ -13,8 +13,8 @@
##############################################################################
"""Database objects
-$Id: DB.py,v 1.57 2003/11/28 16:44:49 jim Exp $"""
+$Id: DB.py,v 1.58 2003/12/24 16:02:00 jeremy Exp $"""
-__version__='$Revision: 1.57 $'[11:-2]
+__version__='$Revision: 1.58 $'[11:-2]
import cPickle, cStringIO, sys, POSException, UndoLogCompatible
from Connection import Connection
@@ -74,7 +74,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
self._version_cache_size=version_cache_size
self._version_cache_deactivate_after = version_cache_deactivate_after
-self._miv_cache={}
+self._miv_cache = {}
# Setup storage
self._storage=storage
@@ -300,8 +300,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
def importFile(self, file):
raise NotImplementedError
-def invalidate(self, oids, connection=None, version='',
-rc=sys.getrefcount):
+def invalidate(self, tid, oids, connection=None, version=''):
"""Invalidate references to a given oid.
This is used to indicate that one of the connections has committed a
@@ -323,21 +322,21 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
for cc in allocated:
if (cc is not connection and
(not version or cc._version==version)):
-if rc(cc) <= 3:
+if sys.getrefcount(cc) <= 3:
cc.close()
-cc.invalidate(oids)
-temps=self._temps
-if temps:
+cc.invalidate(tid, oids)
+if self._temps:
t=[]
-for cc in temps:
-if rc(cc) > 3:
+for cc in self._temps:
+if sys.getrefcount(cc) > 3:
if (cc is not connection and
-(not version or cc._version==version)):
-cc.invalidate(oids)
+(not version or cc._version == version)):
+cc.invalidate(tid, oids)
t.append(cc)
-else: cc.close()
-self._temps=t
+else:
+cc.close()
+self._temps = t
def modifiedInVersion(self, oid):
h=hash(oid)%131
@@ -353,7 +352,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
return len(self._storage)
def open(self, version='', transaction=None, temporary=0, force=None,
-waitflag=1):
+waitflag=1, mvcc=True):
"""Return a object space (AKA connection) to work in
The optional version argument can be used to specify that a
@@ -371,25 +370,25 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
try:
if transaction is not None:
-connections=transaction._connections
+connections = transaction._connections
if connections:
if connections.has_key(version) and not temporary:
return connections[version]
else:
-transaction._connections=connections={}
-transaction=transaction._connections
+transaction._connections = connections = {}
+transaction = transaction._connections
if temporary:
# This is a temporary connection.
# We won't bother with the pools. This will be
# a one-use connection.
-c=self.klass(
-version=version,
-cache_size=self._version_cache_size)
+c = self.klass(version=version,
+cache_size=self._version_cache_size,
+mvcc=mvcc)
c._setDB(self)
self._temps.append(c)
-if transaction is not None: transaction[id(c)]=c
+if transaction is not None:
+transaction[id(c)] = c
return c
@@ -430,18 +429,18 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
if not pool:
-c=None
+c = None
if version:
if self._version_pool_size > len(allocated) or force:
-c=self.klass(
-version=version,
-cache_size=self._version_cache_size)
+c = self.klass(version=version,
+cache_size=self._version_cache_size,
+mvcc=mvcc)
allocated.append(c)
pool.append(c)
elif self._pool_size > len(allocated) or force:
-c=self.klass(
-version=version,
-cache_size=self._cache_size)
+c = self.klass(version=version,
+cache_size=self._cache_size,
+mvcc=mvcc)
allocated.append(c)
pool.append(c)
@@ -456,7 +455,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
pool_lock.release()
else: return
-elif len(pool)==1:
+elif len(pool) == 1:
# Taking last one, lock the pool
# Note that another thread might grab the lock
# before us, so we might actually block, however,
@@ -470,14 +469,15 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
# but it could be higher due to a race condition.
pool_lock.release()
-c=pool[-1]
+c = pool[-1]
del pool[-1]
c._setDB(self)
for pool, allocated in pooll:
for cc in pool:
cc._incrgc()
-if transaction is not None: transaction[version]=c
+if transaction is not None:
+transaction[version] = c
return c
finally: self._r()
...@@ -588,7 +588,8 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -588,7 +588,8 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
d = {} d = {}
for oid in storage.undo(id): for oid in storage.undo(id):
d[oid] = 1 d[oid] = 1
self.invalidate(d) # XXX I think we need to remove old undo to use mvcc
self.invalidate(None, d)
def versionEmpty(self, version): def versionEmpty(self, version):
return self._storage.versionEmpty(version) return self._storage.versionEmpty(version)
...@@ -616,13 +617,13 @@ class CommitVersion: ...@@ -616,13 +617,13 @@ class CommitVersion:
def commit(self, reallyme, t): def commit(self, reallyme, t):
dest=self._dest dest=self._dest
oids = self._db._storage.commitVersion(self._version, dest, t) tid, oids = self._db._storage.commitVersion(self._version, dest, t)
oids = list2dict(oids) oids = list2dict(oids)
self._db.invalidate(oids, version=dest) self._db.invalidate(tid, oids, version=dest)
if dest: if dest:
# the code above just invalidated the dest version. # the code above just invalidated the dest version.
# now we need to invalidate the source! # now we need to invalidate the source!
self._db.invalidate(oids, version=self._version) self._db.invalidate(tid, oids, version=self._version)
class AbortVersion(CommitVersion): class AbortVersion(CommitVersion):
"""An object that will see to version abortion """An object that will see to version abortion
...@@ -631,9 +632,9 @@ class AbortVersion(CommitVersion): ...@@ -631,9 +632,9 @@ class AbortVersion(CommitVersion):
""" """
def commit(self, reallyme, t): def commit(self, reallyme, t):
version=self._version version = self._version
oids = self._db._storage.abortVersion(version, t) tid, oids = self._db._storage.abortVersion(version, t)
self._db.invalidate(list2dict(oids), version=version) self._db.invalidate(tid, list2dict(oids), version=version)
class TransactionalUndo(CommitVersion): class TransactionalUndo(CommitVersion):
...@@ -647,5 +648,5 @@ class TransactionalUndo(CommitVersion): ...@@ -647,5 +648,5 @@ class TransactionalUndo(CommitVersion):
# similarity of rhythm that I think it's justified. # similarity of rhythm that I think it's justified.
def commit(self, reallyme, t): def commit(self, reallyme, t):
oids = self._db._storage.transactionalUndo(self._version, t) tid, oids = self._db._storage.transactionalUndo(self._version, t)
self._db.invalidate(list2dict(oids)) self._db.invalidate(tid, list2dict(oids))
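Across these hunks the pattern is uniform: commitVersion, abortVersion and transactionalUndo now return a (tid, oids) pair rather than a bare list of oids, and the transaction id becomes the new first argument of DB.invalidate so that MVCC connections can order invalidations. A minimal sketch of the convention (db, storage, txn, src and dest are illustrative stand-ins; only the signatures shown above come from the diff):

# Hedged sketch of the updated invalidation protocol.
tid, oids = storage.commitVersion(src, dest, txn)   # now returns (tid, [oid, ...])
db.invalidate(tid, list2dict(oids), version=dest)   # tid leads the call

# DB.open grows an mvcc flag (default True), so callers can opt out:
conn = db.open(mvcc=False)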
This diff is collapsed.
# this is a package
from ZODB.FileStorage.FileStorage \
import FileStorage, RecordIterator, FileIterator, packed_version
This diff is collapsed.
from ZODB.FileStorage import FileIterator
from ZODB.FileStorage.format \
import TRANS_HDR, TRANS_HDR_LEN, DATA_HDR, DATA_HDR_LEN
from ZODB.TimeStamp import TimeStamp
from ZODB.utils import u64
from ZODB.tests.StorageTestBase import zodb_unpickle
from cPickle import Unpickler
from cStringIO import StringIO
import md5
import struct
import types
def get_pickle_metadata(data):
# ZODB's data records contain two pickles. The first is the class
# of the object, the second is the object.
if data.startswith('(c'):
# Don't actually unpickle a class, because it will attempt to
# load the class. Just break open the pickle and get the
# module and class from it.
modname, classname, rest = data.split('\n', 2)
modname = modname[2:]
return modname, classname
f = StringIO(data)
u = Unpickler(f)
try:
class_info = u.load()
except Exception, err:
print "Error", err
return '', ''
if isinstance(class_info, types.TupleType):
if isinstance(class_info[0], types.TupleType):
modname, classname = class_info[0]
else:
modname, classname = class_info
else:
# XXX not sure what to do here
modname = repr(class_info)
classname = ''
return modname, classname
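# A hedged usage note (not part of the original module): given the raw
# data of a record produced by FileIterator, get_pickle_metadata recovers
# the module and class name without importing the class, e.g.
#
#     modname, classname = get_pickle_metadata(rec.data)
#
# where `rec` is assumed to be a data record yielded by the iterator, as
# in fsdump() below.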
def fsdump(path, file=None, with_offset=1):
i = 0
iter = FileIterator(path)
for trans in iter:
if with_offset:
print >> file, "Trans #%05d tid=%016x time=%s offset=%d" % \
(i, u64(trans.tid), str(TimeStamp(trans.tid)), trans._pos)
else:
print >> file, "Trans #%05d tid=%016x time=%s" % \
(i, u64(trans.tid), str(TimeStamp(trans.tid)))
print >> file, "\tstatus=%s user=%s description=%s" % \
(`trans.status`, trans.user, trans.description)
j = 0
for rec in trans:
if rec.data is None:
fullclass = "undo or abort of object creation"
else:
modname, classname = get_pickle_metadata(rec.data)
dig = md5.new(rec.data).hexdigest()
fullclass = "%s.%s" % (modname, classname)
# special case for testing purposes
if fullclass == "ZODB.tests.MinPO.MinPO":
obj = zodb_unpickle(rec.data)
fullclass = "%s %s" % (fullclass, obj.value)
if rec.version:
version = "version=%s " % rec.version
else:
version = ''
if rec.data_txn:
# XXX It would be nice to print the transaction number
# (i) but it would be too expensive to keep track of.
bp = "bp=%016x" % u64(rec.data_txn)
else:
bp = ""
print >> file, " data #%05d oid=%016x %sclass=%s %s" % \
(j, u64(rec.oid), version, fullclass, bp)
j += 1
print >> file
i += 1
iter.close()
def fmt(p64):
    # Return a nicely formatted string for a packed 64-bit value
return "%016x" % u64(p64)
class Dumper:
"""A very verbose dumper for debuggin FileStorage problems."""
# XXX Should revise this class to use FileStorageFormatter.
def __init__(self, path, dest=None):
self.file = open(path, "rb")
self.dest = dest
def dump(self):
fid = self.file.read(4)
print >> self.dest, "*" * 60
print >> self.dest, "file identifier: %r" % fid
while self.dump_txn():
pass
def dump_txn(self):
pos = self.file.tell()
h = self.file.read(TRANS_HDR_LEN)
if not h:
return False
tid, tlen, status, ul, dl, el = struct.unpack(TRANS_HDR, h)
end = pos + tlen
print >> self.dest, "=" * 60
print >> self.dest, "offset: %d" % pos
print >> self.dest, "end pos: %d" % end
print >> self.dest, "transaction id: %s" % fmt(tid)
print >> self.dest, "trec len: %d" % tlen
print >> self.dest, "status: %r" % status
user = descr = extra = ""
if ul:
user = self.file.read(ul)
if dl:
descr = self.file.read(dl)
if el:
extra = self.file.read(el)
print >> self.dest, "user: %r" % user
print >> self.dest, "description: %r" % descr
print >> self.dest, "len(extra): %d" % el
while self.file.tell() < end:
self.dump_data(pos)
stlen = self.file.read(8)
print >> self.dest, "redundant trec len: %d" % u64(stlen)
return 1
def dump_data(self, tloc):
pos = self.file.tell()
h = self.file.read(DATA_HDR_LEN)
assert len(h) == DATA_HDR_LEN
oid, revid, prev, tloc, vlen, dlen = struct.unpack(DATA_HDR, h)
print >> self.dest, "-" * 60
print >> self.dest, "offset: %d" % pos
print >> self.dest, "oid: %s" % fmt(oid)
print >> self.dest, "revid: %s" % fmt(revid)
print >> self.dest, "previous record offset: %d" % prev
print >> self.dest, "transaction offset: %d" % tloc
if vlen:
pnv = self.file.read(8)
sprevdata = self.file.read(8)
version = self.file.read(vlen)
print >> self.dest, "version: %r" % version
print >> self.dest, "non-version data offset: %d" % u64(pnv)
print >> self.dest, \
"previous version data offset: %d" % u64(sprevdata)
print >> self.dest, "len(data): %d" % dlen
self.file.read(dlen)
if not dlen:
sbp = self.file.read(8)
print >> self.dest, "backpointer: %d" % u64(sbp)
This diff is collapsed.
...@@ -21,7 +21,7 @@ It is meant to illustrate the simplest possible storage.

The Mapping storage uses a single data structure to map object ids to data.
"""

__version__='$Revision: 1.11 $'[11:-2]

from ZODB import utils
from ZODB import BaseStorage
...@@ -58,6 +58,16 @@ class MappingStorage(BaseStorage.BaseStorage):
        finally:
            self._lock_release()
    def loadEx(self, oid, version):
        self._lock_acquire()
        try:
            # Since this storage doesn't support versions, tid and
            # serial will always be the same.
            p = self._index[oid]
            return p[8:], p[:8], ""  # data, serial/tid, version
        finally:
            self._lock_release()
    def store(self, oid, serial, data, version, transaction):
        if transaction is not self._transaction:
            raise POSException.StorageTransactionError(self, transaction)
...@@ -75,11 +85,10 @@ class MappingStorage(BaseStorage.BaseStorage):
                                      serials=(oserial, serial),
                                      data=data)

            self._tindex.append((oid, self._tid + data))
        finally:
            self._lock_release()
        return self._tid

    def _clear_temp(self):
        self._tindex = []
...@@ -87,7 +96,7 @@ class MappingStorage(BaseStorage.BaseStorage):
    def _finish(self, tid, user, desc, ext):
        for oid, p in self._tindex:
            self._index[oid] = p
        self._ltid = self._tid

    def lastTransaction(self):
        return self._ltid
...@@ -95,6 +104,8 @@ class MappingStorage(BaseStorage.BaseStorage):
    def pack(self, t, referencesf):
        self._lock_acquire()
        try:
            if not self._index:
                return
            # Build an index of *only* those objects reachable from the root.
            rootl = ['\0\0\0\0\0\0\0\0']
            pindex = {}
......
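The loadEx method added above is the piece of the MVCC read interface that MappingStorage can support trivially; a hedged sketch of the calling convention, matching the unpacking used by the checkLoadEx test later in this commit (storage and oid are stand-ins):

# Hedged sketch; loadEx(oid, version) -> (data, tid, version).
data, tid, version = storage.loadEx(oid, '')   # '' selects non-version data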
...@@ -85,7 +85,6 @@ class BasicStorage:
        eq(value, MinPO(11))
        eq(revid, newrevid)

    def checkNonVersionStore(self):
        revid = ZERO
        newrevid = self._dostore(revid=None)
......
...@@ -20,7 +20,7 @@ import tempfile
import unittest

import ZODB, ZODB.FileStorage
from StorageTestBase import StorageTestBase

class FileStorageCorruptTests(StorageTestBase):
...@@ -30,7 +30,7 @@ class FileStorageCorruptTests(StorageTestBase):
    def tearDown(self):
        self._storage.close()
        self._storage.cleanup()

    def _do_stores(self):
        oids = []
......
...@@ -36,40 +36,40 @@ class HistoryStorage:
        h = self._storage.history(oid, size=1)
        eq(len(h), 1)
        d = h[0]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        # Try to get 2 historical revisions
        h = self._storage.history(oid, size=2)
        eq(len(h), 2)
        d = h[0]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        d = h[1]
        eq(d['tid'], revid2)
        eq(d['version'], '')
        # Try to get all 3 historical revisions
        h = self._storage.history(oid, size=3)
        eq(len(h), 3)
        d = h[0]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        d = h[1]
        eq(d['tid'], revid2)
        eq(d['version'], '')
        d = h[2]
        eq(d['tid'], revid1)
        eq(d['version'], '')
        # There should be no more than 3 revisions
        h = self._storage.history(oid, size=4)
        eq(len(h), 3)
        d = h[0]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        d = h[1]
        eq(d['tid'], revid2)
        eq(d['version'], '')
        d = h[2]
        eq(d['tid'], revid1)
        eq(d['version'], '')

    def checkVersionHistory(self):
...@@ -94,22 +94,22 @@ class HistoryStorage:
        h = self._storage.history(oid, version, 100)
        eq(len(h), 6)
        d = h[0]
        eq(d['tid'], revid6)
        eq(d['version'], version)
        d = h[1]
        eq(d['tid'], revid5)
        eq(d['version'], version)
        d = h[2]
        eq(d['tid'], revid4)
        eq(d['version'], version)
        d = h[3]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        d = h[4]
        eq(d['tid'], revid2)
        eq(d['version'], '')
        d = h[5]
        eq(d['tid'], revid1)
        eq(d['version'], '')

    def checkHistoryAfterVersionCommit(self):
...@@ -151,25 +151,25 @@ class HistoryStorage:
        h = self._storage.history(oid, version, 100)
        eq(len(h), 7)
        d = h[0]
        eq(d['tid'], revid7)
        eq(d['version'], '')
        d = h[1]
        eq(d['tid'], revid6)
        eq(d['version'], version)
        d = h[2]
        eq(d['tid'], revid5)
        eq(d['version'], version)
        d = h[3]
        eq(d['tid'], revid4)
        eq(d['version'], version)
        d = h[4]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        d = h[5]
        eq(d['tid'], revid2)
        eq(d['version'], '')
        d = h[6]
        eq(d['tid'], revid1)
        eq(d['version'], '')

    def checkHistoryAfterVersionAbort(self):
...@@ -211,23 +211,23 @@ class HistoryStorage:
        h = self._storage.history(oid, version, 100)
        eq(len(h), 7)
        d = h[0]
        eq(d['tid'], revid7)
        eq(d['version'], '')
        d = h[1]
        eq(d['tid'], revid6)
        eq(d['version'], version)
        d = h[2]
        eq(d['tid'], revid5)
        eq(d['version'], version)
        d = h[3]
        eq(d['tid'], revid4)
        eq(d['version'], version)
        d = h[4]
        eq(d['tid'], revid3)
        eq(d['version'], '')
        d = h[5]
        eq(d['tid'], revid2)
        eq(d['version'], '')
        d = h[6]
        eq(d['tid'], revid1)
        eq(d['version'], '')
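Every assertion change in this file is the same rename: history records now expose the transaction id under 'tid' where they previously used 'serial'. A hedged sketch of reading the new records (storage and oid are stand-ins):

for d in storage.history(oid, size=3):
    print d['tid'], d['version']   # 'serial' is no longer a key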
...@@ -33,7 +33,7 @@ class IteratorCompare:
            eq(reciter.tid, revid)
            for rec in reciter:
                eq(rec.oid, oid)
                eq(rec.tid, revid)
                eq(rec.version, '')
                eq(zodb_unpickle(rec.data), MinPO(val))
                val = val + 1
...@@ -147,6 +147,20 @@ class IteratorStorage(IteratorCompare):
        finally:
            self._storage.tpc_finish(t)

    def checkLoadEx(self):
        oid = self._storage.new_oid()
        self._dostore(oid, data=42)
        data, tid, ver = self._storage.loadEx(oid, "")
        self.assertEqual(zodb_unpickle(data), MinPO(42))
        match = False
        for txn in self._storage.iterator():
            for rec in txn:
                if rec.oid == oid and rec.tid == tid:
                    self.assertEqual(txn.tid, tid)
                    match = True
        if not match:
            self.fail("Could not find transaction with matching id")

class ExtendedIteratorStorage(IteratorCompare):
...@@ -202,7 +216,7 @@ class IteratorDeepCompare:
            eq(txn1._extension, txn2._extension)
            for rec1, rec2 in zip(txn1, txn2):
                eq(rec1.oid, rec2.oid)
                eq(rec1.tid, rec2.tid)
                eq(rec1.version, rec2.version)
                eq(rec1.data, rec2.data)
            # Make sure there are no more records left in rec1 and rec2,
......
...@@ -154,9 +154,12 @@ class StorageClientThread(TestThread):

class ExtStorageClientThread(StorageClientThread):

    def runtest(self):
        # pick some other storage ops to execute, depending in part
        # on the features provided by the storage.
        names = ["do_load", "do_modifiedInVersion"]
        if self.storage.supportsUndo():
            names += ["do_loadSerial", "do_undoLog", "do_iterator"]
        ops = [getattr(self, meth) for meth in names]
        assert ops, "Didn't find any storage ops in %s" % self.storage
        # do a store to guarantee there's at least one oid in self.oids
        self.dostore(0)
......
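The supportsUndo check used here is the storage API's general capability-discovery idiom; a hedged sketch of guarding undo-specific calls the same way (the undoLog arguments are illustrative, and storage stands in for a real storage instance):

if storage.supportsUndo():
    log = storage.undoLog(0, 20)   # descriptions of recent undoable transactions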
This diff is collapsed.
...@@ -35,6 +35,10 @@ class DemoStorageTests(StorageTestBase.StorageTestBase,
        # have this limit, so we inhibit this test here.
        pass

    def checkAbortVersionNonCurrent(self):
        # XXX Need to implement a real loadBefore for DemoStorage?
        pass

def test_suite():
    suite = unittest.makeSuite(DemoStorageTests, 'check')
......
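The XXX comment above refers to loadBefore, the read operation MVCC builds on; a hedged sketch of its contract as used elsewhere in this merge (the unpacking names are illustrative, and the exact return shape is stated here as an assumption):

result = storage.loadBefore(oid, tid)
if result is not None:
    data, start_tid, end_tid = result   # end_tid is None for the newest revision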
This diff is collapsed.