Commit c316b563 authored by Jeremy Hylton's avatar Jeremy Hylton

Merge ZODB3-fast-restart-branch to the trunk

parent a45a0786
...@@ -34,8 +34,16 @@ temporary directory as determined by the tempfile module. ...@@ -34,8 +34,16 @@ temporary directory as determined by the tempfile module.
The ClientStorage overrides the client name default to the value of The ClientStorage overrides the client name default to the value of
the environment variable ZEO_CLIENT, if it exists. the environment variable ZEO_CLIENT, if it exists.
Each cache file has a 4-byte magic number followed by a sequence of Each cache file has a 12-byte header followed by a sequence of
records of the form: records. The header format is as follows:
offset in header: name -- description
0: magic -- 4-byte magic number, identifying this as a ZEO cache file
4: lasttid -- 8-byte last transaction id
Each record has the following form:
offset in record: name -- description offset in record: name -- description
...@@ -111,7 +119,8 @@ from ZODB.utils import U64 ...@@ -111,7 +119,8 @@ from ZODB.utils import U64
import zLOG import zLOG
from ZEO.ICache import ICache from ZEO.ICache import ICache
magic='ZEC0' magic = 'ZEC1'
headersize = 12
class ClientCache: class ClientCache:
...@@ -126,6 +135,8 @@ class ClientCache: ...@@ -126,6 +135,8 @@ class ClientCache:
self._storage = storage self._storage = storage
self._limit = size / 2 self._limit = size / 2
self._client = client
self._ltid = None # For getLastTid()
# Allocate locks: # Allocate locks:
L = allocate_lock() L = allocate_lock()
...@@ -154,9 +165,9 @@ class ClientCache: ...@@ -154,9 +165,9 @@ class ClientCache:
fi = open(p[i],'r+b') fi = open(p[i],'r+b')
if fi.read(4) == magic: # Minimal sanity if fi.read(4) == magic: # Minimal sanity
fi.seek(0, 2) fi.seek(0, 2)
if fi.tell() > 30: if fi.tell() > headersize:
# First serial is at offset 19 + 4 for magic # Read serial at offset 19 of first record
fi.seek(23) fi.seek(headersize + 19)
s[i] = fi.read(8) s[i] = fi.read(8)
# If we found a non-zero serial, then use the file # If we found a non-zero serial, then use the file
if s[i] != '\0\0\0\0\0\0\0\0': if s[i] != '\0\0\0\0\0\0\0\0':
...@@ -172,14 +183,14 @@ class ClientCache: ...@@ -172,14 +183,14 @@ class ClientCache:
if f[0] is None: if f[0] is None:
# We started, open the first cache file # We started, open the first cache file
f[0] = open(p[0], 'w+b') f[0] = open(p[0], 'w+b')
f[0].write(magic) f[0].write(magic + '\0' * (headersize - len(magic)))
current = 0 current = 0
f[1] = None f[1] = None
else: else:
self._f = f = [tempfile.TemporaryFile(suffix='.zec'), None] self._f = f = [tempfile.TemporaryFile(suffix='.zec'), None]
# self._p file name 'None' signifies an unnamed temp file. # self._p file name 'None' signifies an unnamed temp file.
self._p = p = [None, None] self._p = p = [None, None]
f[0].write(magic) f[0].write(magic + '\0' * (headersize - len(magic)))
current = 0 current = 0
self.log("%s: storage=%r, size=%r; file[%r]=%r" % self.log("%s: storage=%r, size=%r; file[%r]=%r" %
...@@ -219,6 +230,57 @@ class ClientCache: ...@@ -219,6 +230,57 @@ class ClientCache:
except OSError: except OSError:
pass pass
def getLastTid(self):
"""Get the last transaction id stored by setLastTid().
If the cache is persistent, it is read from the current
cache file; otherwise it's an instance variable.
"""
if self._client is None:
return self._ltid
else:
self._acquire()
try:
return self._getLastTid()
finally:
self._release()
def _getLastTid(self):
f = self._f[self._current]
f.seek(4)
tid = f.read(8)
if len(tid) < 8 or tid == '\0\0\0\0\0\0\0\0':
return None
else:
return tid
def setLastTid(self, tid):
"""Store the last transaction id.
If the cache is persistent, it is written to the current
cache file; otherwise it's an instance variable.
"""
if self._client is None:
if tid == '\0\0\0\0\0\0\0\0':
tid = None
self._ltid = tid
else:
self._acquire()
try:
self._setLastTid(tid)
finally:
self._release()
def _setLastTid(self, tid):
if tid is None:
tid = '\0\0\0\0\0\0\0\0'
else:
tid = str(tid)
assert len(tid) == 8
f = self._f[self._current]
f.seek(4)
f.write(tid)
def verify(self, verifyFunc): def verify(self, verifyFunc):
"""Call the verifyFunc on every object in the cache. """Call the verifyFunc on every object in the cache.
...@@ -477,6 +539,7 @@ class ClientCache: ...@@ -477,6 +539,7 @@ class ClientCache:
self._acquire() self._acquire()
try: try:
if self._pos + size > self._limit: if self._pos + size > self._limit:
ltid = self._getLastTid()
current = not self._current current = not self._current
self._current = current self._current = current
self._trace(0x70) self._trace(0x70)
...@@ -500,8 +563,12 @@ class ClientCache: ...@@ -500,8 +563,12 @@ class ClientCache:
else: else:
# Temporary cache file: # Temporary cache file:
self._f[current] = tempfile.TemporaryFile(suffix='.zec') self._f[current] = tempfile.TemporaryFile(suffix='.zec')
self._f[current].write(magic) header = magic
self._pos = 4 if ltid:
header += ltid
self._f[current].write(header +
'\0' * (headersize - len(header)))
self._pos = headersize
finally: finally:
self._release() self._release()
...@@ -593,7 +660,7 @@ class ClientCache: ...@@ -593,7 +660,7 @@ class ClientCache:
f = self._f[fileindex] f = self._f[fileindex]
seek = f.seek seek = f.seek
read = f.read read = f.read
pos = 4 pos = headersize
count = 0 count = 0
while 1: while 1:
...@@ -652,7 +719,6 @@ class ClientCache: ...@@ -652,7 +719,6 @@ class ClientCache:
del serial[oid] del serial[oid]
del index[oid] del index[oid]
pos = pos + tlen pos = pos + tlen
count += 1 count += 1
......
...@@ -22,7 +22,6 @@ ClientDisconnected -- exception raised by ClientStorage ...@@ -22,7 +22,6 @@ ClientDisconnected -- exception raised by ClientStorage
""" """
# XXX TO DO # XXX TO DO
# get rid of beginVerify, set up _tfile in verify_cache
# set self._storage = stub later, in endVerify # set self._storage = stub later, in endVerify
# if wait is given, wait until verify is complete # if wait is given, wait until verify is complete
...@@ -60,6 +59,9 @@ class UnrecognizedResult(ClientStorageError): ...@@ -60,6 +59,9 @@ class UnrecognizedResult(ClientStorageError):
class ClientDisconnected(ClientStorageError, Disconnected): class ClientDisconnected(ClientStorageError, Disconnected):
"""The database storage is disconnected from the storage.""" """The database storage is disconnected from the storage."""
def tid2time(tid):
return str(TimeStamp(tid))
def get_timestamp(prev_ts=None): def get_timestamp(prev_ts=None):
"""Internal helper to return a unique TimeStamp instance. """Internal helper to return a unique TimeStamp instance.
...@@ -208,6 +210,8 @@ class ClientStorage: ...@@ -208,6 +210,8 @@ class ClientStorage:
self._connection = None self._connection = None
# _server_addr is used by sortKey() # _server_addr is used by sortKey()
self._server_addr = None self._server_addr = None
self._tfile = None
self._pickler = None
self._info = {'length': 0, 'size': 0, 'name': 'ZEO Client', self._info = {'length': 0, 'size': 0, 'name': 'ZEO Client',
'supportsUndo':0, 'supportsVersions': 0, 'supportsUndo':0, 'supportsVersions': 0,
...@@ -337,12 +341,14 @@ class ClientStorage: ...@@ -337,12 +341,14 @@ class ClientStorage:
This is called by ConnectionManager after it has decided which This is called by ConnectionManager after it has decided which
connection should be used. connection should be used.
""" """
# XXX would like to report whether we get a read-only connection
if self._connection is not None: if self._connection is not None:
log2(INFO, "Reconnected to storage") reconnect = 1
else: else:
log2(INFO, "Connected to storage") reconnect = 0
self.set_server_addr(conn.get_addr()) self.set_server_addr(conn.get_addr())
stub = self.StorageServerStubClass(conn) stub = self.StorageServerStubClass(conn)
stub = self.StorageServerStubClass(conn)
self._oids = [] self._oids = []
self._info.update(stub.get_info()) self._info.update(stub.get_info())
self.verify_cache(stub) self.verify_cache(stub)
...@@ -353,6 +359,11 @@ class ClientStorage: ...@@ -353,6 +359,11 @@ class ClientStorage:
self._connection = conn self._connection = conn
self._server = stub self._server = stub
if reconnect:
log2(INFO, "Reconnected to storage: %s" % self._server_addr)
else:
log2(INFO, "Connected to storage: %s" % self._server_addr)
def set_server_addr(self, addr): def set_server_addr(self, addr):
# Normalize server address and convert to string # Normalize server address and convert to string
if isinstance(addr, types.StringType): if isinstance(addr, types.StringType):
...@@ -381,12 +392,42 @@ class ClientStorage: ...@@ -381,12 +392,42 @@ class ClientStorage:
return self._server_addr return self._server_addr
def verify_cache(self, server): def verify_cache(self, server):
"""Internal routine called to verify the cache.""" """Internal routine called to verify the cache.
# XXX beginZeoVerify ends up calling back to beginVerify() below.
# That whole exchange is rather unnecessary. The return value (indicating which path we took) is used by
server.beginZeoVerify() the test suite.
"""
last_inval_tid = self._cache.getLastTid()
if last_inval_tid is not None:
ltid = server.lastTransaction()
if ltid == last_inval_tid:
log2(INFO, "No verification necessary "
"(last_inval_tid up-to-date)")
self._cache.open()
return "no verification"
# log some hints about last transaction
log2(INFO, "last inval tid: %r %s"
% (last_inval_tid, tid2time(last_inval_tid)))
log2(INFO, "last transaction: %r %s" %
(ltid, ltid and tid2time(ltid)))
pair = server.getInvalidations(last_inval_tid)
if pair is not None:
log2(INFO, "Recovering %d invalidations" % len(pair[1]))
self._cache.open()
self.invalidateTransaction(*pair)
return "quick verification"
log2(INFO, "Verifying cache")
# setup tempfile to hold zeoVerify results
self._tfile = tempfile.TemporaryFile(suffix=".inv")
self._pickler = cPickle.Pickler(self._tfile, 1)
self._pickler.fast = 1 # Don't use the memo
self._cache.verify(server.zeoVerify) self._cache.verify(server.zeoVerify)
server.endZeoVerify() server.endZeoVerify()
return "full verification"
### Is there a race condition between notifyConnected and ### Is there a race condition between notifyConnected and
### notifyDisconnected? In Particular, what if we get ### notifyDisconnected? In Particular, what if we get
...@@ -402,7 +443,8 @@ class ClientStorage: ...@@ -402,7 +443,8 @@ class ClientStorage:
This is called by ConnectionManager when the connection is This is called by ConnectionManager when the connection is
closed or when certain problems with the connection occur. closed or when certain problems with the connection occur.
""" """
log2(PROBLEM, "Disconnected from storage") log2(PROBLEM, "Disconnected from storage: %s"
% repr(self._server_addr))
self._connection = None self._connection = None
self._server = disconnected_stub self._server = disconnected_stub
...@@ -644,6 +686,7 @@ class ClientStorage: ...@@ -644,6 +686,7 @@ class ClientStorage:
self._serial = id self._serial = id
self._seriald.clear() self._seriald.clear()
del self._serials[:] del self._serials[:]
self._tbuf.clear()
def end_transaction(self): def end_transaction(self):
"""Internal helper to end a transaction.""" """Internal helper to end a transaction."""
...@@ -678,12 +721,13 @@ class ClientStorage: ...@@ -678,12 +721,13 @@ class ClientStorage:
if f is not None: if f is not None:
f() f()
self._server.tpc_finish(self._serial) tid = self._server.tpc_finish(self._serial)
r = self._check_serials() r = self._check_serials()
assert r is None or len(r) == 0, "unhandled serialnos: %s" % r assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
self._update_cache() self._update_cache()
self._cache.setLastTid(tid)
finally: finally:
self.end_transaction() self.end_transaction()
...@@ -779,12 +823,6 @@ class ClientStorage: ...@@ -779,12 +823,6 @@ class ClientStorage:
"""Server callback to update the info dictionary.""" """Server callback to update the info dictionary."""
self._info.update(dict) self._info.update(dict)
def beginVerify(self):
"""Server callback to signal start of cache validation."""
self._tfile = tempfile.TemporaryFile(suffix=".inv")
self._pickler = cPickle.Pickler(self._tfile, 1)
self._pickler.fast = 1 # Don't use the memo
def invalidateVerify(self, args): def invalidateVerify(self, args):
"""Server callback to invalidate an (oid, version) pair. """Server callback to invalidate an (oid, version) pair.
...@@ -802,6 +840,7 @@ class ClientStorage: ...@@ -802,6 +840,7 @@ class ClientStorage:
if self._pickler is None: if self._pickler is None:
return return
self._pickler.dump((0,0)) self._pickler.dump((0,0))
self._pickler = None
self._tfile.seek(0) self._tfile.seek(0)
unpick = cPickle.Unpickler(self._tfile) unpick = cPickle.Unpickler(self._tfile)
f = self._tfile f = self._tfile
...@@ -815,29 +854,26 @@ class ClientStorage: ...@@ -815,29 +854,26 @@ class ClientStorage:
self._db.invalidate(oid, version=version) self._db.invalidate(oid, version=version)
f.close() f.close()
def invalidateTrans(self, args): def invalidateTransaction(self, tid, args):
"""Server callback to invalidate a list of (oid, version) pairs. """Invalidate objects modified by tid."""
self._cache.setLastTid(tid)
This is called as the result of a transaction. if self._pickler is not None:
""" self.log("Transactional invalidation during cache verification",
level=zLOG.BLATHER)
for t in args:
self.self._pickler.dump(t)
return
db = self._db
for oid, version in args: for oid, version in args:
self._cache.invalidate(oid, version=version) self._cache.invalidate(oid, version=version)
try: if db is not None:
self._db.invalidate(oid, version=version) db.invalidate(oid, version=version)
except AttributeError, msg:
log2(PROBLEM, # The following are for compatibility with protocol version 2.0.0
"Invalidate(%s, %s) failed for _db: %s" % (repr(oid),
repr(version), def invalidateTrans(self, args):
msg)) return self.invalidateTransaction(None, args)
# Unfortunately, the ZEO 2 wire protocol uses different names for
# several of the callback methods invoked by the StorageServer.
# We can't change the wire protocol at this point because that
# would require synchronized updates of clients and servers and we
# don't want that. So here we alias the old names to their new
# implementations.
begin = beginVerify
invalidate = invalidateVerify invalidate = invalidateVerify
end = endVerify end = endVerify
Invalidate = invalidateTrans Invalidate = invalidateTrans
......
...@@ -44,16 +44,16 @@ class ClientStorage: ...@@ -44,16 +44,16 @@ class ClientStorage:
self.rpc = rpc self.rpc = rpc
def beginVerify(self): def beginVerify(self):
self.rpc.callAsync('begin') self.rpc.callAsync('beginVerify')
def invalidateVerify(self, args): def invalidateVerify(self, args):
self.rpc.callAsync('invalidate', args) self.rpc.callAsync('invalidateVerify', args)
def endVerify(self): def endVerify(self):
self.rpc.callAsync('end') self.rpc.callAsync('endVerify')
def invalidateTrans(self, args): def invalidateTransaction(self, tid, args):
self.rpc.callAsync('Invalidate', args) self.rpc.callAsync('invalidateTransaction', tid, args)
def serialnos(self, arg): def serialnos(self, arg):
self.rpc.callAsync('serialnos', arg) self.rpc.callAsync('serialnos', arg)
......
...@@ -32,6 +32,9 @@ class StorageServer: ...@@ -32,6 +32,9 @@ class StorageServer:
zrpc.connection.Connection class. zrpc.connection.Connection class.
""" """
self.rpc = rpc self.rpc = rpc
if self.rpc.peer_protocol_version == 'Z200':
self.lastTransaction = lambda: None
self.getInvalidations = lambda tid: None
def extensionMethod(self, name): def extensionMethod(self, name):
return ExtensionMethodWrapper(self.rpc, name).call return ExtensionMethodWrapper(self.rpc, name).call
...@@ -51,8 +54,13 @@ class StorageServer: ...@@ -51,8 +54,13 @@ class StorageServer:
def get_info(self): def get_info(self):
return self.rpc.call('get_info') return self.rpc.call('get_info')
def beginZeoVerify(self): def lastTransaction(self):
self.rpc.callAsync('beginZeoVerify') # Not in protocol version 2.0.0; see __init__()
return self.rpc.call('lastTransaction')
def getInvalidations(self, tid):
# Not in protocol version 2.0.0; see __init__()
return self.rpc.call('getInvalidations', tid)
def zeoVerify(self, oid, s, sv): def zeoVerify(self, oid, s, sv):
self.rpc.callAsync('zeoVerify', oid, s, sv) self.rpc.callAsync('zeoVerify', oid, s, sv)
......
...@@ -37,6 +37,7 @@ from ZODB.POSException import StorageError, StorageTransactionError ...@@ -37,6 +37,7 @@ from ZODB.POSException import StorageError, StorageTransactionError
from ZODB.POSException import TransactionError, ReadOnlyError from ZODB.POSException import TransactionError, ReadOnlyError
from ZODB.referencesf import referencesf from ZODB.referencesf import referencesf
from ZODB.Transaction import Transaction from ZODB.Transaction import Transaction
from ZODB.utils import u64
_label = "ZSS" # Default label used for logging. _label = "ZSS" # Default label used for logging.
...@@ -68,8 +69,8 @@ class StorageServer: ...@@ -68,8 +69,8 @@ class StorageServer:
ZEOStorageClass = None # patched up later ZEOStorageClass = None # patched up later
ManagedServerConnectionClass = ManagedServerConnection ManagedServerConnectionClass = ManagedServerConnection
def __init__(self, addr, storages, read_only=0): def __init__(self, addr, storages, read_only=0,
invalidation_queue_size=100):
"""StorageServer constructor. """StorageServer constructor.
This is typically invoked from the start.py script. This is typically invoked from the start.py script.
...@@ -102,13 +103,17 @@ class StorageServer: ...@@ -102,13 +103,17 @@ class StorageServer:
self.storages = storages self.storages = storages
set_label() set_label()
msg = ", ".join( msg = ", ".join(
["%s:%s" % (name, storage.isReadOnly() and "RO" or "RW") ["%s:%s:%s" % (name, storage.isReadOnly() and "RO" or "RW",
storage.getName())
for name, storage in storages.items()]) for name, storage in storages.items()])
log("%s created %s with storages: %s" % log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg)) (self.__class__.__name__, read_only and "RO" or "RW", msg))
for s in storages.values(): for s in storages.values():
s._waiting = [] s._waiting = []
self.read_only = read_only self.read_only = read_only
# A list of at most invalidation_queue_size invalidations
self.invq = []
self.invq_bound = invalidation_queue_size
self.connections = {} self.connections = {}
self.dispatcher = self.DispatcherClass(addr, self.dispatcher = self.DispatcherClass(addr,
factory=self.new_connection, factory=self.new_connection,
...@@ -141,7 +146,7 @@ class StorageServer: ...@@ -141,7 +146,7 @@ class StorageServer:
l = self.connections[storage_id] = [] l = self.connections[storage_id] = []
l.append(conn) l.append(conn)
def invalidate(self, conn, storage_id, invalidated=(), info=None): def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients. """Internal: broadcast info and invalidations to clients.
This is called from several ZEOStorage methods. This is called from several ZEOStorage methods.
...@@ -149,7 +154,7 @@ class StorageServer: ...@@ -149,7 +154,7 @@ class StorageServer:
This can do three different things: This can do three different things:
- If the invalidated argument is non-empty, it broadcasts - If the invalidated argument is non-empty, it broadcasts
invalidateTrans() messages to all clients of the given invalidateTransaction() messages to all clients of the given
storage except the current client (the conn argument). storage except the current client (the conn argument).
- If the invalidated argument is empty and the info argument - If the invalidated argument is empty and the info argument
...@@ -158,17 +163,47 @@ class StorageServer: ...@@ -158,17 +163,47 @@ class StorageServer:
client. client.
- If both the invalidated argument and the info argument are - If both the invalidated argument and the info argument are
non-empty, it broadcasts invalidateTrans() messages to all non-empty, it broadcasts invalidateTransaction() messages to all
clients except the current, and sends an info() message to clients except the current, and sends an info() message to
the current client. the current client.
""" """
if invalidated:
if len(self.invq) >= self.invq_bound:
del self.invq[0]
self.invq.append((tid, invalidated))
for p in self.connections.get(storage_id, ()): for p in self.connections.get(storage_id, ()):
if invalidated and p is not conn: if invalidated and p is not conn:
p.client.invalidateTrans(invalidated) p.client.invalidateTransaction(tid, invalidated)
elif info is not None: elif info is not None:
p.client.info(info) p.client.info(info)
def get_invalidations(self, tid):
"""Return a tid and list of all objects invalidation since tid.
The tid is the most recent transaction id committed by the server.
Returns None if it is unable to provide a complete list
of invalidations for tid. In this case, client should
do full cache verification.
"""
if not self.invq:
log("invq empty")
return None, []
earliest_tid = self.invq[0][0]
if earliest_tid > tid:
log("tid to old for invq %s < %s" % (u64(tid), u64(earliest_tid)))
return None, []
oids = {}
for tid, L in self.invq:
for key in L:
oids[key] = 1
latest_tid = self.invq[-1][0]
return latest_tid, oids.keys()
def close_server(self): def close_server(self):
"""Close the dispatcher so that there are no new connections. """Close the dispatcher so that there are no new connections.
...@@ -212,10 +247,18 @@ class ZEOStorage: ...@@ -212,10 +247,18 @@ class ZEOStorage:
self.storage_id = "uninitialized" self.storage_id = "uninitialized"
self.transaction = None self.transaction = None
self.read_only = read_only self.read_only = read_only
self.log_label = _label
def notifyConnected(self, conn): def notifyConnected(self, conn):
self.connection = conn # For restart_other() below self.connection = conn # For restart_other() below
self.client = self.ClientStorageStubClass(conn) self.client = self.ClientStorageStubClass(conn)
addr = conn.addr
if isinstance(addr, type("")):
label = addr
else:
host, port = addr
label = str(host) + ":" + str(port)
self.log_label = _label + "/" + label
def notifyDisconnected(self): def notifyDisconnected(self):
# When this storage closes, we must ensure that it aborts # When this storage closes, we must ensure that it aborts
...@@ -237,7 +280,7 @@ class ZEOStorage: ...@@ -237,7 +280,7 @@ class ZEOStorage:
return "<%s %X trans=%s s_trans=%s>" % (name, id(self), tid, stid) return "<%s %X trans=%s s_trans=%s>" % (name, id(self), tid, stid)
def log(self, msg, level=zLOG.INFO, error=None): def log(self, msg, level=zLOG.INFO, error=None):
zLOG.LOG("%s:%s" % (_label, self.storage_id), level, msg, error=error) zLOG.LOG(self.log_label, level, msg, error=error)
def setup_delegation(self): def setup_delegation(self):
"""Delegate several methods to the storage""" """Delegate several methods to the storage"""
...@@ -259,6 +302,7 @@ class ZEOStorage: ...@@ -259,6 +302,7 @@ class ZEOStorage:
for name in fn().keys(): for name in fn().keys():
if not hasattr(self,name): if not hasattr(self,name):
setattr(self, name, getattr(self.storage, name)) setattr(self, name, getattr(self.storage, name))
self.lastTransaction = self.storage.lastTransaction
def check_tid(self, tid, exc=None): def check_tid(self, tid, exc=None):
if self.read_only: if self.read_only:
...@@ -286,7 +330,7 @@ class ZEOStorage: ...@@ -286,7 +330,7 @@ class ZEOStorage:
This method must be the first one called by the client. This method must be the first one called by the client.
""" """
if self.storage is not None: if self.storage is not None:
log("duplicate register() call") self.log("duplicate register() call")
raise ValueError, "duplicate register() call" raise ValueError, "duplicate register() call"
storage = self.server.storages.get(storage_id) storage = self.server.storages.get(storage_id)
if storage is None: if storage is None:
...@@ -342,8 +386,13 @@ class ZEOStorage: ...@@ -342,8 +386,13 @@ class ZEOStorage:
raise raise
return p, s, v, pv, sv return p, s, v, pv, sv
def beginZeoVerify(self): def getInvalidations(self, tid):
self.client.beginVerify() invtid, invlist = self.server.get_invalidations(tid)
if invtid is None:
return None
self.log("Return %d invalidations up to tid %s"
% (len(invlist), u64(invtid)))
return invtid, invlist
def zeoVerify(self, oid, s, sv): def zeoVerify(self, oid, s, sv):
try: try:
...@@ -394,7 +443,8 @@ class ZEOStorage: ...@@ -394,7 +443,8 @@ class ZEOStorage:
self.storage.pack(time, referencesf) self.storage.pack(time, referencesf)
self.log("pack(time=%s) complete" % repr(time)) self.log("pack(time=%s) complete" % repr(time))
# Broadcast new size statistics # Broadcast new size statistics
self.server.invalidate(0, self.storage_id, (), self.get_size_info()) self.server.invalidate(0, self.storage_id, None,
(), self.get_size_info())
def new_oids(self, n=100): def new_oids(self, n=100):
"""Return a sequence of n new oids, where n defaults to 100""" """Return a sequence of n new oids, where n defaults to 100"""
...@@ -409,7 +459,7 @@ class ZEOStorage: ...@@ -409,7 +459,7 @@ class ZEOStorage:
raise ReadOnlyError() raise ReadOnlyError()
oids = self.storage.undo(transaction_id) oids = self.storage.undo(transaction_id)
if oids: if oids:
self.server.invalidate(self, self.storage_id, self.server.invalidate(self, self.storage_id, None,
map(lambda oid: (oid, ''), oids)) map(lambda oid: (oid, ''), oids))
return oids return oids
return () return ()
...@@ -450,12 +500,15 @@ class ZEOStorage: ...@@ -450,12 +500,15 @@ class ZEOStorage:
if not self.check_tid(id): if not self.check_tid(id):
return return
invalidated = self.strategy.tpc_finish() invalidated = self.strategy.tpc_finish()
tid = self.storage.lastTransaction()
if invalidated: if invalidated:
self.server.invalidate(self, self.storage_id, self.server.invalidate(self, self.storage_id, tid,
invalidated, self.get_size_info()) invalidated, self.get_size_info())
self.transaction = None self.transaction = None
self.strategy = None self.strategy = None
# Return the tid, for cache invalidation optimization
self.handle_waiting() self.handle_waiting()
return tid
def tpc_abort(self, id): def tpc_abort(self, id):
if not self.check_tid(id): if not self.check_tid(id):
...@@ -546,7 +599,8 @@ class ZEOStorage: ...@@ -546,7 +599,8 @@ class ZEOStorage:
old_strategy = self.strategy old_strategy = self.strategy
assert isinstance(old_strategy, DelayedCommitStrategy) assert isinstance(old_strategy, DelayedCommitStrategy)
self.strategy = ImmediateCommitStrategy(self.storage, self.strategy = ImmediateCommitStrategy(self.storage,
self.client) self.client,
self.log)
resp = old_strategy.restart(self.strategy) resp = old_strategy.restart(self.strategy)
if delay is not None: if delay is not None:
delay.reply(resp) delay.reply(resp)
...@@ -602,11 +656,12 @@ class ICommitStrategy: ...@@ -602,11 +656,12 @@ class ICommitStrategy:
class ImmediateCommitStrategy: class ImmediateCommitStrategy:
"""The storage is available so do a normal commit.""" """The storage is available so do a normal commit."""
def __init__(self, storage, client): def __init__(self, storage, client, logmethod):
self.storage = storage self.storage = storage
self.client = client self.client = client
self.invalidated = [] self.invalidated = []
self.serials = [] self.serials = []
self.log = logmethod
def tpc_begin(self, txn, tid, status): def tpc_begin(self, txn, tid, status):
self.txn = txn self.txn = txn
...@@ -628,12 +683,14 @@ class ImmediateCommitStrategy: ...@@ -628,12 +683,14 @@ class ImmediateCommitStrategy:
try: try:
newserial = self.storage.store(oid, serial, data, version, newserial = self.storage.store(oid, serial, data, version,
self.txn) self.txn)
except (SystemExit, KeyboardInterrupt):
raise
except Exception, err: except Exception, err:
if not isinstance(err, TransactionError): if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client # Unexpected errors are logged and passed to the client
exc_info = sys.exc_info() exc_info = sys.exc_info()
log("store error: %s, %s" % exc_info[:2], self.log("store error: %s, %s" % exc_info[:2],
zLOG.ERROR, error=exc_info) zLOG.ERROR, error=exc_info)
del exc_info del exc_info
# Try to pickle the exception. If it can't be pickled, # Try to pickle the exception. If it can't be pickled,
# the RPC response would fail, so use something else. # the RPC response would fail, so use something else.
...@@ -643,7 +700,7 @@ class ImmediateCommitStrategy: ...@@ -643,7 +700,7 @@ class ImmediateCommitStrategy:
pickler.dump(err, 1) pickler.dump(err, 1)
except: except:
msg = "Couldn't pickle storage exception: %s" % repr(err) msg = "Couldn't pickle storage exception: %s" % repr(err)
log(msg, zLOG.ERROR) self.log(msg, zLOG.ERROR)
err = StorageServerError(msg) err = StorageServerError(msg)
# The exception is reported back as newserial for this oid # The exception is reported back as newserial for this oid
newserial = err newserial = err
...@@ -776,6 +833,8 @@ class SlowMethodThread(threading.Thread): ...@@ -776,6 +833,8 @@ class SlowMethodThread(threading.Thread):
def run(self): def run(self):
try: try:
result = self._method(*self._args) result = self._method(*self._args)
except (SystemExit, KeyboardInterrupt):
raise
except Exception: except Exception:
self.delay.error(sys.exc_info()) self.delay.error(sys.exc_info())
else: else:
......
...@@ -117,18 +117,12 @@ def main(): ...@@ -117,18 +117,12 @@ def main():
# Must be a misaligned record caused by a crash # Must be a misaligned record caused by a crash
##print "Skipping 8 bytes at offset", offset-8 ##print "Skipping 8 bytes at offset", offset-8
continue continue
oid = f_read(8) r = f_read(16)
if len(oid) < 8: if len(r) < 16:
break break
if heuristic and oid[:4] != '\0\0\0\0': offset += 16
f.seek(-8, 1)
continue
offset += 8
serial = f_read(8)
if len(serial) < 8:
break
offset += 8
records += 1 records += 1
oid, serial = struct_unpack(">8s8s", r)
# Decode the code # Decode the code
dlen, version, code, current = (code & 0x7fffff00, dlen, version, code, current = (code & 0x7fffff00,
code & 0x80, code & 0x80,
......
...@@ -153,24 +153,14 @@ def main(): ...@@ -153,24 +153,14 @@ def main():
if ts == 0: if ts == 0:
# Must be a misaligned record caused by a crash # Must be a misaligned record caused by a crash
if not quiet: if not quiet:
print "Skipping 8 bytes at offset", offset-8, print "Skipping 8 bytes at offset", offset-8
print repr(r)
continue continue
oid = f_read(8) r = f_read(16)
if len(oid) < 8: if len(r) < 16:
break break
if heuristic and oid[:4] != '\0\0\0\0': offset += 16
# Heuristic for severe data corruption
print "Seeking back over bad oid at offset", offset,
print repr(r)
f.seek(-8, 1)
continue
offset += 8
serial = f_read(8)
if len(serial) < 8:
break
offset += 8
records += 1 records += 1
oid, serial = struct_unpack(">8s8s", r)
if t0 is None: if t0 is None:
t0 = ts t0 = ts
thisinterval = t0 / interval thisinterval = t0 / interval
......
...@@ -20,7 +20,9 @@ import select ...@@ -20,7 +20,9 @@ import select
import socket import socket
import asyncore import asyncore
import tempfile import tempfile
import thread # XXX do we really need to catch thread.error
import threading import threading
import time
import zLOG import zLOG
...@@ -36,9 +38,18 @@ from ZODB.tests.MinPO import MinPO ...@@ -36,9 +38,18 @@ from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle
from ZODB.tests.StorageTestBase import handle_all_serials, ZERO from ZODB.tests.StorageTestBase import handle_all_serials, ZERO
class TestClientStorage(ClientStorage):
def verify_cache(self, stub):
self.end_verify = threading.Event()
self.verify_result = ClientStorage.verify_cache(self, stub)
def endVerify(self):
ClientStorage.endVerify(self)
self.end_verify.set()
class DummyDB: class DummyDB:
def invalidate(self, *args, **kws): def invalidate(self, *args, **kwargs):
pass pass
...@@ -48,6 +59,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -48,6 +59,7 @@ class CommonSetupTearDown(StorageTestBase):
__super_setUp = StorageTestBase.setUp __super_setUp = StorageTestBase.setUp
__super_tearDown = StorageTestBase.tearDown __super_tearDown = StorageTestBase.tearDown
keep = 0 keep = 0
invq = None
def setUp(self): def setUp(self):
"""Test setup for connection tests. """Test setup for connection tests.
...@@ -99,17 +111,15 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -99,17 +111,15 @@ class CommonSetupTearDown(StorageTestBase):
raise NotImplementedError raise NotImplementedError
def openClientStorage(self, cache='', cache_size=200000, wait=1, def openClientStorage(self, cache='', cache_size=200000, wait=1,
read_only=0, read_only_fallback=0, read_only=0, read_only_fallback=0):
addr=None): base = TestClientStorage(self.addr,
if addr is None: client=cache,
addr = self.addr cache_size=cache_size,
storage = ClientStorage(addr, wait=wait,
client=cache, min_disconnect_poll=0.1,
cache_size=cache_size, read_only=read_only,
wait=wait, read_only_fallback=read_only_fallback)
min_disconnect_poll=0.1, storage = base
read_only=read_only,
read_only_fallback=read_only_fallback)
storage.registerDB(DummyDB(), None) storage.registerDB(DummyDB(), None)
return storage return storage
...@@ -121,7 +131,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -121,7 +131,7 @@ class CommonSetupTearDown(StorageTestBase):
path = "%s.%d" % (self.file, index) path = "%s.%d" % (self.file, index)
conf = self.getConfig(path, create, read_only) conf = self.getConfig(path, create, read_only)
zeoport, adminaddr, pid = forker.start_zeo_server( zeoport, adminaddr, pid = forker.start_zeo_server(
conf, addr, ro_svr, self.keep) conf, addr, ro_svr, self.keep, self.invq)
self._pids.append(pid) self._pids.append(pid)
self._servers.append(adminaddr) self._servers.append(adminaddr)
...@@ -420,9 +430,9 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -420,9 +430,9 @@ class ConnectionTests(CommonSetupTearDown):
for t in threads: for t in threads:
t.closeclients() t.closeclients()
class ReconnectionTests(CommonSetupTearDown): class ReconnectionTests(CommonSetupTearDown):
keep = 1 keep = 1
invq = 2
def checkReadOnlyStorage(self): def checkReadOnlyStorage(self):
# Open a read-only client to a read-only *storage*; stores fail # Open a read-only client to a read-only *storage*; stores fail
...@@ -557,6 +567,113 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -557,6 +567,113 @@ class ReconnectionTests(CommonSetupTearDown):
else: else:
self.fail("Couldn't store after starting a read-write server") self.fail("Couldn't store after starting a read-write server")
def checkNoVerificationOnServerRestart(self):
self._storage = self.openClientStorage()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
self._dostore()
self.shutdownServer()
self.pollDown()
self._storage.verify_result = None
self.startServer(create=0)
self.pollUp()
# There were no transactions committed, so no verification
# should be needed.
self.assertEqual(self._storage.verify_result, "no verification")
def checkNoVerificationOnServerRestartWith2Clients(self):
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
self._storage = self.openClientStorage()
oid = self._storage.new_oid()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
# do two storages of the object to make sure an invalidation
# message is generated
revid = self._dostore(oid)
self._dostore(oid, revid)
perstorage.load(oid, '')
self.shutdownServer()
self.pollDown()
self._storage.verify_result = None
perstorage.verify_result = None
self.startServer(create=0)
self.pollUp()
# There were no transactions committed, so no verification
# should be needed.
self.assertEqual(self._storage.verify_result, "no verification")
perstorage.close()
self.assertEqual(perstorage.verify_result, "no verification")
def checkQuickVerificationWith2Clients(self):
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
self._storage = self.openClientStorage()
oid = self._storage.new_oid()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
# do two storages of the object to make sure an invalidation
# message is generated
revid = self._dostore(oid)
revid = self._dostore(oid, revid)
perstorage.load(oid, '')
perstorage.close()
revid = self._dostore(oid, revid)
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "quick verification")
self.assertEqual(perstorage.load(oid, ''),
self._storage.load(oid, ''))
def checkVerificationWith2ClientsInvqOverflow(self):
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
self._storage = self.openClientStorage()
oid = self._storage.new_oid()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
# do two storages of the object to make sure an invalidation
# message is generated
revid = self._dostore(oid)
revid = self._dostore(oid, revid)
perstorage.load(oid, '')
perstorage.close()
# the test code sets invq bound to 2
for i in range(5):
revid = self._dostore(oid, revid)
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
t = time.time() + 30
while not perstorage.end_verify.isSet():
perstorage.sync()
if time.time() > t:
self.fail("timed out waiting for endVerify")
self.assertEqual(self._storage.load(oid, '')[1], revid)
self.assertEqual(perstorage.load(oid, ''),
self._storage.load(oid, ''))
perstorage.close()
class MSTThread(threading.Thread): class MSTThread(threading.Thread):
......
...@@ -51,7 +51,7 @@ def get_port(): ...@@ -51,7 +51,7 @@ def get_port():
raise RuntimeError, "Can't find port" raise RuntimeError, "Can't find port"
def start_zeo_server(conf, addr=None, ro_svr=0, keep=0): def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None):
"""Start a ZEO server in a separate process. """Start a ZEO server in a separate process.
Returns the ZEO port, the test server port, and the pid. Returns the ZEO port, the test server port, and the pid.
...@@ -77,6 +77,8 @@ def start_zeo_server(conf, addr=None, ro_svr=0, keep=0): ...@@ -77,6 +77,8 @@ def start_zeo_server(conf, addr=None, ro_svr=0, keep=0):
args.append('-r') args.append('-r')
if keep: if keep:
args.append('-k') args.append('-k')
if invq:
args += ['-Q', str(invq)]
args.append(str(port)) args.append(str(port))
d = os.environ.copy() d = os.environ.copy()
d['PYTHONPATH'] = os.pathsep.join(sys.path) d['PYTHONPATH'] = os.pathsep.join(sys.path)
......
...@@ -261,6 +261,19 @@ class ClientCacheTests(unittest.TestCase): ...@@ -261,6 +261,19 @@ class ClientCacheTests(unittest.TestCase):
self.assert_(None is not cache._index.get(oid1) < 0) self.assert_(None is not cache._index.get(oid1) < 0)
self.assert_(None is not cache._index.get(oid2) < 0) self.assert_(None is not cache._index.get(oid2) < 0)
def testLastTid(self):
cache = self.cache
self.failUnless(cache.getLastTid() is None)
ltid = 'pqrstuvw'
cache.setLastTid(ltid)
self.assertEqual(cache.getLastTid(), ltid)
cache.checkSize(10*self.cachesize) # Force a file flip
self.assertEqual(cache.getLastTid(), ltid)
cache.setLastTid(None)
self.failUnless(cache.getLastTid() is None)
cache.checkSize(10*self.cachesize) # Force a file flip
self.failUnless(cache.getLastTid() is None)
class PersistentClientCacheTests(unittest.TestCase): class PersistentClientCacheTests(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -348,6 +361,26 @@ class PersistentClientCacheTests(unittest.TestCase): ...@@ -348,6 +361,26 @@ class PersistentClientCacheTests(unittest.TestCase):
self.fail("invalidated data resurrected, size %d, was %d" % self.fail("invalidated data resurrected, size %d, was %d" %
(len(loaded[0]), len(data))) (len(loaded[0]), len(data)))
def testPersistentLastTid(self):
cache = self.cache
self.failUnless(cache.getLastTid() is None)
ltid = 'pqrstuvw'
cache.setLastTid(ltid)
self.assertEqual(cache.getLastTid(), ltid)
oid = 'abcdefgh'
data = '1234'
serial = 'ABCDEFGH'
cache.store(oid, data, serial, '', '', '')
self.assertEqual(cache.getLastTid(), ltid)
cache.checkSize(10*self.cachesize) # Force a file flip
self.assertEqual(cache.getLastTid(), ltid)
cache = self.reopenCache()
self.assertEqual(cache.getLastTid(), ltid)
cache.setLastTid(None)
self.failUnless(cache.getLastTid() is None)
cache.checkSize(10*self.cachesize) # Force a file flip
self.failUnless(cache.getLastTid() is None)
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ClientCacheTests)) suite.addTest(unittest.makeSuite(ClientCacheTests))
......
...@@ -116,8 +116,9 @@ def main(): ...@@ -116,8 +116,9 @@ def main():
ro_svr = 0 ro_svr = 0
keep = 0 keep = 0
configfile = None configfile = None
invalidation_queue_size = 100
# Parse the arguments and let getopt.error percolate # Parse the arguments and let getopt.error percolate
opts, args = getopt.getopt(sys.argv[1:], 'rkC:') opts, args = getopt.getopt(sys.argv[1:], 'rkC:Q:')
for opt, arg in opts: for opt, arg in opts:
if opt == '-r': if opt == '-r':
ro_svr = 1 ro_svr = 1
...@@ -125,6 +126,8 @@ def main(): ...@@ -125,6 +126,8 @@ def main():
keep = 1 keep = 1
elif opt == '-C': elif opt == '-C':
configfile = arg configfile = arg
elif opt == '-Q':
invalidation_queue_size = int(arg)
# Open the config file and let ZConfig parse the data there. Then remove # Open the config file and let ZConfig parse the data there. Then remove
# the config file, otherwise we'll leave turds. # the config file, otherwise we'll leave turds.
fp = open(configfile, 'r') fp = open(configfile, 'r')
...@@ -145,7 +148,9 @@ def main(): ...@@ -145,7 +148,9 @@ def main():
sys.exit(2) sys.exit(2)
addr = ('', zeo_port) addr = ('', zeo_port)
log(label, 'creating the storage server') log(label, 'creating the storage server')
serv = ZEO.StorageServer.StorageServer(addr, {'1': storage}, ro_svr) serv = ZEO.StorageServer.StorageServer(
addr, {'1': storage}, ro_svr,
invalidation_queue_size=invalidation_queue_size)
log(label, 'entering ThreadedAsync loop') log(label, 'entering ThreadedAsync loop')
ThreadedAsync.LoopCallback.loop() ThreadedAsync.LoopCallback.loop()
......
...@@ -119,7 +119,7 @@ class ConnectionManager: ...@@ -119,7 +119,7 @@ class ConnectionManager:
# XXX need each connection started with async==0 to have a # XXX need each connection started with async==0 to have a
# callback # callback
log("CM.set_async(%s)" % repr(map)) log("CM.set_async(%s)" % repr(map), level=zLOG.DEBUG)
if not self.closed and self.trigger is None: if not self.closed and self.trigger is None:
log("CM.set_async(): first call") log("CM.set_async(): first call")
self.trigger = trigger() self.trigger = trigger()
...@@ -294,6 +294,9 @@ class ConnectThread(threading.Thread): ...@@ -294,6 +294,9 @@ class ConnectThread(threading.Thread):
if success > 0: if success > 0:
break break
time.sleep(delay) time.sleep(delay)
if self.mgr.is_connected():
log("CT: still trying to replace fallback connection",
level=zLOG.INFO)
delay = min(delay*2, self.tmax) delay = min(delay*2, self.tmax)
log("CT: exiting thread: %s" % self.getName()) log("CT: exiting thread: %s" % self.getName())
......
...@@ -21,7 +21,7 @@ import types ...@@ -21,7 +21,7 @@ import types
import ThreadedAsync import ThreadedAsync
from ZEO.zrpc import smac from ZEO.zrpc import smac
from ZEO.zrpc.error import ZRPCError, DisconnectedError from ZEO.zrpc.error import ZRPCError, DisconnectedError
from ZEO.zrpc.log import log, short_repr from ZEO.zrpc.log import short_repr, log
from ZEO.zrpc.marshal import Marshaller from ZEO.zrpc.marshal import Marshaller
from ZEO.zrpc.trigger import trigger from ZEO.zrpc.trigger import trigger
import zLOG import zLOG
...@@ -115,13 +115,32 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -115,13 +115,32 @@ class Connection(smac.SizedMessageAsyncConnection):
__super_init = smac.SizedMessageAsyncConnection.__init__ __super_init = smac.SizedMessageAsyncConnection.__init__
__super_close = smac.SizedMessageAsyncConnection.close __super_close = smac.SizedMessageAsyncConnection.close
protocol_version = "Z200" # Protocol variables:
#
# oldest_protocol_version -- the oldest protocol version we support
# protocol_version -- the newest protocol version we support; preferred
oldest_protocol_version = "Z200"
protocol_version = "Z201"
# Protocol history:
#
# Z200 -- original ZEO 2.0 protocol
#
# Z201 -- added invalidateTransaction() to client;
# renamed several client methods;
# added lastTransaction() to server
def __init__(self, sock, addr, obj=None): def __init__(self, sock, addr, obj=None):
self.obj = None self.obj = None
self.marshal = Marshaller() self.marshal = Marshaller()
self.closed = 0 self.closed = 0
self.msgid = 0 self.msgid = 0
self.peer_protocol_version = None # Set in recv_handshake()
if isinstance(addr, types.TupleType):
self.log_label = "zrpc-conn:%s:%d" % addr
else:
self.log_label = "zrpc-conn:%s" % addr
self.__super_init(sock, addr) self.__super_init(sock, addr)
# A Connection either uses asyncore directly or relies on an # A Connection either uses asyncore directly or relies on an
# asyncore mainloop running in a separate thread. If # asyncore mainloop running in a separate thread. If
...@@ -147,6 +166,9 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -147,6 +166,9 @@ class Connection(smac.SizedMessageAsyncConnection):
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__ __str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=zLOG.BLATHER, error=None):
zLOG.LOG(self.log_label, level, message, error=error)
def close(self): def close(self):
if self.closed: if self.closed:
return return
...@@ -156,7 +178,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -156,7 +178,7 @@ class Connection(smac.SizedMessageAsyncConnection):
self.__super_close() self.__super_close()
def close_trigger(self): def close_trigger(self):
# overridden by ManagedConnection # Overridden by ManagedConnection
if self.trigger is not None: if self.trigger is not None:
self.trigger.close() self.trigger.close()
...@@ -164,7 +186,9 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -164,7 +186,9 @@ class Connection(smac.SizedMessageAsyncConnection):
"""Register obj as the true object to invoke methods on""" """Register obj as the true object to invoke methods on"""
self.obj = obj self.obj = obj
def handshake(self): def handshake(self, proto=None):
# Overridden by ManagedConnection
# When a connection is created the first message sent is a # When a connection is created the first message sent is a
# 4-byte protocol version. This mechanism should allow the # 4-byte protocol version. This mechanism should allow the
# protocol to evolve over time, and let servers handle clients # protocol to evolve over time, and let servers handle clients
...@@ -174,17 +198,18 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -174,17 +198,18 @@ class Connection(smac.SizedMessageAsyncConnection):
# first message received. # first message received.
# The client sends the protocol version it is using. # The client sends the protocol version it is using.
self._message_input = self.message_input
self.message_input = self.recv_handshake self.message_input = self.recv_handshake
self.message_output(self.protocol_version) self.message_output(proto or self.protocol_version)
def recv_handshake(self, message): def recv_handshake(self, proto):
if message == self.protocol_version: # Extended by ManagedConnection
self.message_input = self._message_input del self.message_input
self.peer_protocol_version = proto
if self.oldest_protocol_version <= proto <= self.protocol_version:
self.log("received handshake %r" % proto, level=zLOG.INFO)
else: else:
log("recv_handshake: bad handshake %s" % short_repr(message), self.log("bad handshake %s" % short_repr(proto), level=zLOG.ERROR)
level=zLOG.ERROR) raise ZRPCError("bad handshake %r" % proto)
# otherwise do something else...
def message_input(self, message): def message_input(self, message):
"""Decoding an incoming message and dispatch it""" """Decoding an incoming message and dispatch it"""
...@@ -195,9 +220,9 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -195,9 +220,9 @@ class Connection(smac.SizedMessageAsyncConnection):
msgid, flags, name, args = self.marshal.decode(message) msgid, flags, name, args = self.marshal.decode(message)
if __debug__: if __debug__:
log("recv msg: %s, %s, %s, %s" % (msgid, flags, name, self.log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
short_repr(args)), short_repr(args)),
level=zLOG.TRACE) level=zLOG.TRACE)
if name == REPLY: if name == REPLY:
self.handle_reply(msgid, flags, args) self.handle_reply(msgid, flags, args)
else: else:
...@@ -205,8 +230,8 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -205,8 +230,8 @@ class Connection(smac.SizedMessageAsyncConnection):
def handle_reply(self, msgid, flags, args): def handle_reply(self, msgid, flags, args):
if __debug__: if __debug__:
log("recv reply: %s, %s, %s" % (msgid, flags, short_repr(args)), self.log("recv reply: %s, %s, %s"
level=zLOG.DEBUG) % (msgid, flags, short_repr(args)), level=zLOG.DEBUG)
self.replies_cond.acquire() self.replies_cond.acquire()
try: try:
self.replies[msgid] = flags, args self.replies[msgid] = flags, args
...@@ -219,7 +244,8 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -219,7 +244,8 @@ class Connection(smac.SizedMessageAsyncConnection):
msg = "Invalid method name: %s on %s" % (name, repr(self.obj)) msg = "Invalid method name: %s on %s" % (name, repr(self.obj))
raise ZRPCError(msg) raise ZRPCError(msg)
if __debug__: if __debug__:
log("calling %s%s" % (name, short_repr(args)), level=zLOG.BLATHER) self.log("calling %s%s" % (name, short_repr(args)),
level=zLOG.BLATHER)
meth = getattr(self.obj, name) meth = getattr(self.obj, name)
try: try:
...@@ -228,8 +254,8 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -228,8 +254,8 @@ class Connection(smac.SizedMessageAsyncConnection):
raise raise
except Exception, msg: except Exception, msg:
error = sys.exc_info() error = sys.exc_info()
log("%s() raised exception: %s" % (name, msg), zLOG.INFO, self.log("%s() raised exception: %s" % (name, msg), zLOG.INFO,
error=error) error=error)
error = error[:2] error = error[:2]
return self.return_error(msgid, flags, *error) return self.return_error(msgid, flags, *error)
...@@ -239,7 +265,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -239,7 +265,7 @@ class Connection(smac.SizedMessageAsyncConnection):
(name, short_repr(ret))) (name, short_repr(ret)))
else: else:
if __debug__: if __debug__:
log("%s returns %s" % (name, short_repr(ret)), zLOG.DEBUG) self.log("%s returns %s" % (name, short_repr(ret)), zLOG.DEBUG)
if isinstance(ret, Delay): if isinstance(ret, Delay):
ret.set_sender(msgid, self.send_reply, self.return_error) ret.set_sender(msgid, self.send_reply, self.return_error)
else: else:
...@@ -252,7 +278,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -252,7 +278,7 @@ class Connection(smac.SizedMessageAsyncConnection):
self.close() self.close()
def log_error(self, msg="No error message supplied"): def log_error(self, msg="No error message supplied"):
log(msg, zLOG.ERROR, error=sys.exc_info()) self.log(msg, zLOG.ERROR, error=sys.exc_info())
def check_method(self, name): def check_method(self, name):
# XXX Is this sufficient "security" for now? # XXX Is this sufficient "security" for now?
...@@ -304,8 +330,8 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -304,8 +330,8 @@ class Connection(smac.SizedMessageAsyncConnection):
finally: finally:
self.msgid_lock.release() self.msgid_lock.release()
if __debug__: if __debug__:
log("send msg: %d, %d, %s, ..." % (msgid, flags, method), self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
zLOG.TRACE) zLOG.TRACE)
buf = self.marshal.encode(msgid, flags, method, args) buf = self.marshal.encode(msgid, flags, method, args)
self.message_output(buf) self.message_output(buf)
return msgid return msgid
...@@ -342,7 +368,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -342,7 +368,7 @@ class Connection(smac.SizedMessageAsyncConnection):
self.thr_async = 1 self.thr_async = 1
def is_async(self): def is_async(self):
# overridden for ManagedConnection # Overridden by ManagedConnection
if self.thr_async: if self.thr_async:
return 1 return 1
else: else:
...@@ -360,8 +386,8 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -360,8 +386,8 @@ class Connection(smac.SizedMessageAsyncConnection):
def wait(self, msgid): def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply.""" """Invoke asyncore mainloop and wait for reply."""
if __debug__: if __debug__:
log("wait(%d), async=%d" % (msgid, self.is_async()), self.log("wait(%d), async=%d" % (msgid, self.is_async()),
level=zLOG.TRACE) level=zLOG.TRACE)
if self.is_async(): if self.is_async():
self._pull_trigger() self._pull_trigger()
...@@ -378,8 +404,8 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -378,8 +404,8 @@ class Connection(smac.SizedMessageAsyncConnection):
if reply is not None: if reply is not None:
del self.replies[msgid] del self.replies[msgid]
if __debug__: if __debug__:
log("wait(%d): reply=%s" % (msgid, short_repr(reply)), self.log("wait(%d): reply=%s" %
level=zLOG.DEBUG) (msgid, short_repr(reply)), level=zLOG.DEBUG)
return reply return reply
if self.is_async(): if self.is_async():
self.replies_cond.wait(10.0) self.replies_cond.wait(10.0)
...@@ -388,14 +414,14 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -388,14 +414,14 @@ class Connection(smac.SizedMessageAsyncConnection):
try: try:
try: try:
if __debug__: if __debug__:
log("wait(%d): asyncore.poll(%s)" % self.log("wait(%d): asyncore.poll(%s)" %
(msgid, delay), level=zLOG.TRACE) (msgid, delay), level=zLOG.TRACE)
asyncore.poll(delay, self._map) asyncore.poll(delay, self._map)
if delay < 1.0: if delay < 1.0:
delay += delay delay += delay
except select.error, err: except select.error, err:
log("Closing. asyncore.poll() raised %s." % err, self.log("Closing. asyncore.poll() raised %s."
level=zLOG.BLATHER) % err, level=zLOG.BLATHER)
self.close() self.close()
finally: finally:
self.replies_cond.acquire() self.replies_cond.acquire()
...@@ -405,7 +431,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -405,7 +431,7 @@ class Connection(smac.SizedMessageAsyncConnection):
def poll(self): def poll(self):
"""Invoke asyncore mainloop to get pending message out.""" """Invoke asyncore mainloop to get pending message out."""
if __debug__: if __debug__:
log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE) self.log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE)
if self.is_async(): if self.is_async():
self._pull_trigger() self._pull_trigger()
else: else:
...@@ -414,7 +440,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -414,7 +440,7 @@ class Connection(smac.SizedMessageAsyncConnection):
def pending(self): def pending(self):
"""Invoke mainloop until any pending messages are handled.""" """Invoke mainloop until any pending messages are handled."""
if __debug__: if __debug__:
log("pending(), async=%d" % self.is_async(), level=zLOG.TRACE) self.log("pending(), async=%d" % self.is_async(), level=zLOG.TRACE)
if self.is_async(): if self.is_async():
return return
# Inline the asyncore poll() function to know whether any input # Inline the asyncore poll() function to know whether any input
...@@ -465,6 +491,64 @@ class ManagedConnection(Connection): ...@@ -465,6 +491,64 @@ class ManagedConnection(Connection):
self.__super_init(sock, addr, obj) self.__super_init(sock, addr, obj)
self.check_mgr_async() self.check_mgr_async()
# PROTOCOL NEGOTIATION:
#
# The code implementing protocol version 2.0.0 (which is deployed
# in the field and cannot be changed) *only* talks to peers that
# send a handshake indicating protocol version 2.0.0. In that
# version, both the client and the server immediately send out
# their protocol handshake when a connection is established,
# without waiting for their peer, and disconnect when a different
# handshake is receive.
#
# The new protocol uses this to enable new clients to talk to
# 2.0.0 servers: in the new protocol, the client waits until it
# receives the server's protocol handshake before sending its own
# handshake. The client sends the lower of its own protocol
# version and the server protocol version, allowing it to talk to
# servers using later protocol versions (2.0.2 and higher) as
# well: the effective protocol used will be the lower of the
# client and server protocol.
#
# The ZEO modules ClientStorage and ServerStub have backwards
# compatibility code for dealing with the previous version of the
# protocol. The client accept the old version of some messages,
# and will not send new messages when talking to an old server.
#
# As long as the client hasn't sent its handshake, it can't send
# anything else; output messages are queued during this time.
# (Output can happen because the connection testing machinery can
# start sending requests before the handshake is received.)
#
# UPGRADING FROM ZEO 2.0.0 TO NEWER VERSIONS:
#
# Because a new client can talk to an old server, but not vice
# versa, all clients should be upgraded before upgrading any
# servers. Protocol upgrades beyond 2.0.1 will not have this
# restriction, because clients using protocol 2.0.1 or later can
# talk to both older and newer servers.
#
# No compatibility with protocol version 1 is provided.
def handshake(self):
self.message_input = self.recv_handshake
self.message_output = self.queue_output
self.output_queue = []
# The handshake is sent by recv_handshake() below
def queue_output(self, message):
self.output_queue.append(message)
def recv_handshake(self, proto):
del self.message_output
proto = min(proto, self.protocol_version)
Connection.recv_handshake(self, proto) # Raise error if wrong proto
self.message_output(proto)
queue = self.output_queue
del self.output_queue
for message in queue:
self.message_output(message)
# Defer the ThreadedAsync work to the manager. # Defer the ThreadedAsync work to the manager.
def close_trigger(self): def close_trigger(self):
......
...@@ -16,6 +16,7 @@ import asyncore ...@@ -16,6 +16,7 @@ import asyncore
import os import os
import socket import socket
import thread import thread
import errno
if os.name == 'posix': if os.name == 'posix':
......
...@@ -12,11 +12,9 @@ ...@@ -12,11 +12,9 @@
# #
############################################################################## ##############################################################################
"""Handy standard storage machinery """Handy standard storage machinery
"""
# Do this portably in the face of checking out with -kv
import string
__version__ = string.split('$Revision: 1.30 $')[-2:][0]
$Id: BaseStorage.py,v 1.31 2003/01/03 22:07:43 jeremy Exp $
"""
import cPickle import cPickle
import ThreadLock, bpthread import ThreadLock, bpthread
import time, UndoLogCompatible import time, UndoLogCompatible
...@@ -277,8 +275,8 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible): ...@@ -277,8 +275,8 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible):
restoring = 1 restoring = 1
else: else:
restoring = 0 restoring = 0
for transaction in other.iterator(): fiter = other.iterator()
for transaction in fiter:
tid=transaction.tid tid=transaction.tid
if _ts is None: if _ts is None:
_ts=TimeStamp(tid) _ts=TimeStamp(tid)
...@@ -313,6 +311,8 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible): ...@@ -313,6 +311,8 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible):
self.tpc_vote(transaction) self.tpc_vote(transaction)
self.tpc_finish(transaction) self.tpc_finish(transaction)
fiter.close()
class TransactionRecord: class TransactionRecord:
"""Abstract base class for iterator protocol""" """Abstract base class for iterator protocol"""
......
...@@ -115,7 +115,7 @@ ...@@ -115,7 +115,7 @@
# may have a back pointer to a version record or to a non-version # may have a back pointer to a version record or to a non-version
# record. # record.
# #
__version__='$Revision: 1.123 $'[11:-2] __version__='$Revision: 1.124 $'[11:-2]
import base64 import base64
from cPickle import Pickler, Unpickler, loads from cPickle import Pickler, Unpickler, loads
...@@ -124,7 +124,7 @@ import os ...@@ -124,7 +124,7 @@ import os
import struct import struct
import sys import sys
import time import time
from types import StringType from types import StringType, DictType
from struct import pack, unpack from struct import pack, unpack
try: try:
...@@ -137,7 +137,12 @@ from ZODB.POSException import UndoError, POSKeyError, MultipleUndoErrors ...@@ -137,7 +137,12 @@ from ZODB.POSException import UndoError, POSKeyError, MultipleUndoErrors
from ZODB.TimeStamp import TimeStamp from ZODB.TimeStamp import TimeStamp
from ZODB.lock_file import lock_file from ZODB.lock_file import lock_file
from ZODB.utils import p64, u64, cp, z64 from ZODB.utils import p64, u64, cp, z64
from ZODB.fsIndex import fsIndex
try:
from ZODB.fsIndex import fsIndex
except ImportError:
def fsIndex():
return {}
from zLOG import LOG, BLATHER, WARNING, ERROR, PANIC from zLOG import LOG, BLATHER, WARNING, ERROR, PANIC
...@@ -203,6 +208,8 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -203,6 +208,8 @@ class FileStorage(BaseStorage.BaseStorage,
# default pack time is 0 # default pack time is 0
_packt = z64 _packt = z64
_records_before_save = 10000
def __init__(self, file_name, create=0, read_only=0, stop=None, def __init__(self, file_name, create=0, read_only=0, stop=None,
quota=None): quota=None):
...@@ -270,7 +277,9 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -270,7 +277,9 @@ class FileStorage(BaseStorage.BaseStorage,
r = self._restore_index() r = self._restore_index()
if r is not None: if r is not None:
self._used_index = 1 # Marker for testing
index, vindex, start, maxoid, ltid = r index, vindex, start, maxoid, ltid = r
self._initIndex(index, vindex, tindex, tvindex) self._initIndex(index, vindex, tindex, tvindex)
self._pos, self._oid, tid = read_index( self._pos, self._oid, tid = read_index(
self._file, file_name, index, vindex, tindex, stop, self._file, file_name, index, vindex, tindex, stop,
...@@ -278,10 +287,15 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -278,10 +287,15 @@ class FileStorage(BaseStorage.BaseStorage,
read_only=read_only, read_only=read_only,
) )
else: else:
self._used_index = 0 # Marker for testing
self._pos, self._oid, tid = read_index( self._pos, self._oid, tid = read_index(
self._file, file_name, index, vindex, tindex, stop, self._file, file_name, index, vindex, tindex, stop,
read_only=read_only, read_only=read_only,
) )
self._save_index()
self._records_before_save = max(self._records_before_save,
len(self._index))
self._ltid = tid self._ltid = tid
# self._pos should always point just past the last # self._pos should always point just past the last
...@@ -314,6 +328,7 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -314,6 +328,7 @@ class FileStorage(BaseStorage.BaseStorage,
# hook to use something other than builtin dict # hook to use something other than builtin dict
return fsIndex(), {}, {}, {} return fsIndex(), {}, {}, {}
_saved = 0
def _save_index(self): def _save_index(self):
"""Write the database index to a file to support quick startup.""" """Write the database index to a file to support quick startup."""
...@@ -329,6 +344,7 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -329,6 +344,7 @@ class FileStorage(BaseStorage.BaseStorage,
p.dump(info) p.dump(info)
f.flush() f.flush()
f.close() f.close()
try: try:
try: try:
os.remove(index_name) os.remove(index_name)
...@@ -337,6 +353,8 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -337,6 +353,8 @@ class FileStorage(BaseStorage.BaseStorage,
os.rename(tmp_name, index_name) os.rename(tmp_name, index_name)
except: pass except: pass
self._saved += 1
def _clear_index(self): def _clear_index(self):
index_name = self.__name__ + '.index' index_name = self.__name__ + '.index'
if os.path.exists(index_name): if os.path.exists(index_name):
...@@ -354,58 +372,77 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -354,58 +372,77 @@ class FileStorage(BaseStorage.BaseStorage,
object positions cause zero to be returned. object positions cause zero to be returned.
""" """
if pos < 100: return 0 if pos < 100:
file=self._file return 0 # insane
seek=file.seek file = self._file
read=file.read seek = file.seek
read = file.read
seek(0,2) seek(0,2)
if file.tell() < pos: return 0 if file.tell() < pos:
ltid=None return 0 # insane
ltid = None
while 1: max_checked = 5
checked = 0
while checked < max_checked:
seek(pos-8) seek(pos-8)
rstl=read(8) rstl = read(8)
tl=u64(rstl) tl = u64(rstl)
pos=pos-tl-8 pos = pos-tl-8
if pos < 4: return 0 if pos < 4:
return 0 # insane
seek(pos) seek(pos)
s = read(TRANS_HDR_LEN) s = read(TRANS_HDR_LEN)
tid, stl, status, ul, dl, el = unpack(TRANS_HDR, s) tid, stl, status, ul, dl, el = unpack(TRANS_HDR, s)
if not ltid: ltid=tid if not ltid:
if stl != rstl: return 0 # inconsistent lengths ltid = tid
if status == 'u': continue # undone trans, search back if stl != rstl:
if status not in ' p': return 0 return 0 # inconsistent lengths
if tl < (TRANS_HDR_LEN + ul + dl + el): return 0 if status == 'u':
tend=pos+tl continue # undone trans, search back
opos=pos+(TRANS_HDR_LEN + ul + dl + el) if status not in ' p':
if opos==tend: continue # empty trans return 0 # insane
if tl < (TRANS_HDR_LEN + ul + dl + el):
while opos < tend: return 0 # insane
tend = pos+tl
opos = pos+(TRANS_HDR_LEN + ul + dl + el)
if opos == tend:
continue # empty trans
while opos < tend and checked < max_checked:
# Read the data records for this transaction # Read the data records for this transaction
seek(opos) seek(opos)
h=read(DATA_HDR_LEN) h = read(DATA_HDR_LEN)
oid,serial,sprev,stloc,vlen,splen = unpack(DATA_HDR, h) oid, serial, sprev, stloc, vlen, splen = unpack(DATA_HDR, h)
tloc=u64(stloc) tloc = u64(stloc)
plen=u64(splen) plen = u64(splen)
dlen=DATA_HDR_LEN+(plen or 8) dlen = DATA_HDR_LEN+(plen or 8)
if vlen: dlen=dlen+(16+vlen) if vlen:
dlen = dlen+(16+vlen)
if opos+dlen > tend or tloc != pos:
return 0 # insane
if opos+dlen > tend or tloc != pos: return 0 if index.get(oid, 0) != opos:
return 0 # insane
if index.get(oid, 0) != opos: return 0 checked += 1
opos=opos+dlen opos = opos+dlen
return ltid return ltid
def _restore_index(self): def _restore_index(self):
"""Load database index to support quick startup.""" """Load database index to support quick startup."""
try: file_name=self.__name__
f = open("%s.index" % self.__name__, 'rb') index_name=file_name+'.index'
except:
return None try: f=open(index_name,'rb')
p = Unpickler(f) except: return None
p=Unpickler(f)
try: try:
info=p.load() info=p.load()
...@@ -422,6 +459,23 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -422,6 +459,23 @@ class FileStorage(BaseStorage.BaseStorage,
return None return None
pos = long(pos) pos = long(pos)
if isinstance(index, DictType) and not self._is_read_only:
# Convert to fsIndex
newindex = fsIndex()
if type(newindex) is not type(index):
# And we have fsIndex
newindex.update(index)
# Now save the index
f = open(index_name, 'wb')
p = Pickler(f, 1)
info['index'] = newindex
p.dump(info)
f.close()
# Now call this method again to get the new data
return self._restore_index()
tid = self._sane(index, pos) tid = self._sane(index, pos)
if not tid: if not tid:
return None return None
...@@ -955,6 +1009,9 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -955,6 +1009,9 @@ class FileStorage(BaseStorage.BaseStorage,
finally: finally:
self._lock_release() self._lock_release()
# Keep track of the number of records that we've written
_records_written = 0
def _finish(self, tid, u, d, e): def _finish(self, tid, u, d, e):
nextpos=self._nextpos nextpos=self._nextpos
if nextpos: if nextpos:
...@@ -967,10 +1024,20 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -967,10 +1024,20 @@ class FileStorage(BaseStorage.BaseStorage,
if fsync is not None: fsync(file.fileno()) if fsync is not None: fsync(file.fileno())
self._pos=nextpos self._pos = nextpos
self._index.update(self._tindex) self._index.update(self._tindex)
self._vindex.update(self._tvindex) self._vindex.update(self._tvindex)
# Update the number of records that we've written
# +1 for the transaction record
self._records_written += len(self._tindex) + 1
if self._records_written >= self._records_before_save:
self._save_index()
self._records_written = 0
self._records_before_save = max(self._records_before_save,
len(self._index))
self._ltid = tid self._ltid = tid
def _abort(self): def _abort(self):
...@@ -1210,7 +1277,8 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -1210,7 +1277,8 @@ class FileStorage(BaseStorage.BaseStorage,
while pos < tend: while pos < tend:
self._file.seek(pos) self._file.seek(pos)
h = self._file.read(DATA_HDR_LEN) h = self._file.read(DATA_HDR_LEN)
oid, serial, sprev, stloc, vlen, splen = struct.unpack(DATA_HDR, h) oid, serial, sprev, stloc, vlen, splen = \
struct.unpack(DATA_HDR, h)
if failed(oid): if failed(oid):
del failures[oid] # second chance! del failures[oid] # second chance!
plen = u64(splen) plen = u64(splen)
...@@ -1966,6 +2034,7 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8, ...@@ -1966,6 +2034,7 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8,
id in the data. The transaction id is the tid of the last id in the data. The transaction id is the tid of the last
transaction. transaction.
""" """
read = file.read read = file.read
seek = file.seek seek = file.seek
seek(0, 2) seek(0, 2)
...@@ -2001,7 +2070,7 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8, ...@@ -2001,7 +2070,7 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8,
if tid <= ltid: if tid <= ltid:
warn("%s time-stamp reduction at %s", name, pos) warn("%s time-stamp reduction at %s", name, pos)
ltid=tid ltid = tid
tl=u64(stl) tl=u64(stl)
...@@ -2074,7 +2143,12 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8, ...@@ -2074,7 +2143,12 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8,
if vlen: if vlen:
dlen=dlen+(16+vlen) dlen=dlen+(16+vlen)
read(16) read(16)
pv=u64(read(8))
version=read(vlen) version=read(vlen)
# Jim says: "It's just not worth the bother."
#if vndexpos(version, 0) != pv:
# panic("%s incorrect previous version pointer at %s",
# name, pos)
vindex[version]=pos vindex[version]=pos
if pos+dlen > tend or tloc != tpos: if pos+dlen > tend or tloc != tpos:
...@@ -2223,8 +2297,11 @@ class FileIterator(Iterator): ...@@ -2223,8 +2297,11 @@ class FileIterator(Iterator):
self._stop = stop self._stop = stop
def __len__(self): def __len__(self):
# This is a lie. It's here only for Python 2.1 support for # Define a bogus __len__() to make the iterator work
# list()-ifying these objects. # with code like builtin list() and tuple() in Python 2.1.
# There's a lot of C code that expects a sequence to have
# an __len__() but can cope with any sort of mistake in its
# implementation. So just return 0.
return 0 return 0
def close(self): def close(self):
...@@ -2362,7 +2439,6 @@ class FileIterator(Iterator): ...@@ -2362,7 +2439,6 @@ class FileIterator(Iterator):
class RecordIterator(Iterator, BaseStorage.TransactionRecord): class RecordIterator(Iterator, BaseStorage.TransactionRecord):
"""Iterate over the transactions in a FileStorage file.""" """Iterate over the transactions in a FileStorage file."""
def __init__(self, tid, status, user, desc, ext, pos, tend, file, tpos): def __init__(self, tid, status, user, desc, ext, pos, tend, file, tpos):
self.tid = tid self.tid = tid
self.status = status self.status = status
......
...@@ -38,6 +38,7 @@ class IteratorCompare: ...@@ -38,6 +38,7 @@ class IteratorCompare:
eq(zodb_unpickle(rec.data), MinPO(val)) eq(zodb_unpickle(rec.data), MinPO(val))
val = val + 1 val = val + 1
eq(val, val0 + len(revids)) eq(val, val0 + len(revids))
txniter.close()
class IteratorStorage(IteratorCompare): class IteratorStorage(IteratorCompare):
...@@ -191,3 +192,5 @@ class IteratorDeepCompare: ...@@ -191,3 +192,5 @@ class IteratorDeepCompare:
# they were the same length # they were the same length
self.assertRaises(IndexError, iter1.next) self.assertRaises(IndexError, iter1.next)
self.assertRaises(IndexError, iter2.next) self.assertRaises(IndexError, iter2.next)
iter1.close()
iter2.close()
...@@ -72,6 +72,101 @@ class FileStorageTests( ...@@ -72,6 +72,101 @@ class FileStorageTests(
else: else:
self.fail("expect long user field to raise error") self.fail("expect long user field to raise error")
def check_use_fsIndex(self):
from ZODB.fsIndex import fsIndex
self.assertEqual(self._storage._index.__class__, fsIndex)
# XXX We could really use some tests for sanity checking
def check_conversion_to_fsIndex_not_if_readonly(self):
self.tearDown()
class OldFileStorage(ZODB.FileStorage.FileStorage):
def _newIndexes(self):
return {}, {}, {}, {}
from ZODB.fsIndex import fsIndex
# Hack FileStorage to create dictionary indexes
self._storage = OldFileStorage('FileStorageTests.fs')
self.assertEqual(type(self._storage._index), type({}))
for i in range(10):
self._dostore()
# Should save the index
self._storage.close()
self._storage = ZODB.FileStorage.FileStorage(
'FileStorageTests.fs', read_only=1)
self.assertEqual(type(self._storage._index), type({}))
def check_conversion_to_fsIndex(self):
self.tearDown()
class OldFileStorage(ZODB.FileStorage.FileStorage):
def _newIndexes(self):
return {}, {}, {}, {}
from ZODB.fsIndex import fsIndex
# Hack FileStorage to create dictionary indexes
self._storage = OldFileStorage('FileStorageTests.fs')
self.assertEqual(type(self._storage._index), type({}))
for i in range(10):
self._dostore()
oldindex = self._storage._index.copy()
# Should save the index
self._storage.close()
self._storage = ZODB.FileStorage.FileStorage('FileStorageTests.fs')
self.assertEqual(self._storage._index.__class__, fsIndex)
self.failUnless(self._storage._used_index)
index = {}
for k, v in self._storage._index.items():
index[k] = v
self.assertEqual(index, oldindex)
def check_save_after_load_with_no_index(self):
for i in range(10):
self._dostore()
self._storage.close()
os.remove('FileStorageTests.fs.index')
self.open()
self.assertEqual(self._storage._saved, 1)
# This would make the unit tests too slow
# check_save_after_load_that_worked_hard(self)
def check_periodic_save_index(self):
# Check the basic algorithm
oldsaved = self._storage._saved
self._storage._records_before_save = 10
for i in range(4):
self._dostore()
self.assertEqual(self._storage._saved, oldsaved)
self._dostore()
self.assertEqual(self._storage._saved, oldsaved+1)
# Now make sure the parameter changes as we get bigger
for i in range(20):
self._dostore()
self.failUnless(self._storage._records_before_save > 20)
class FileStorageRecoveryTest( class FileStorageRecoveryTest(
StorageTestBase.StorageTestBase, StorageTestBase.StorageTestBase,
RecoveryStorage.RecoveryStorage, RecoveryStorage.RecoveryStorage,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment