Commit c316b563 authored by Jeremy Hylton's avatar Jeremy Hylton

Merge ZODB3-fast-restart-branch to the trunk

parent a45a0786
......@@ -34,8 +34,16 @@ temporary directory as determined by the tempfile module.
The ClientStorage overrides the client name default to the value of
the environment variable ZEO_CLIENT, if it exists.
Each cache file has a 4-byte magic number followed by a sequence of
records of the form:
Each cache file has a 12-byte header followed by a sequence of
records. The header format is as follows:
offset in header: name -- description
0: magic -- 4-byte magic number, identifying this as a ZEO cache file
4: lasttid -- 8-byte last transaction id
Each record has the following form:
offset in record: name -- description
......@@ -111,7 +119,8 @@ from ZODB.utils import U64
import zLOG
from ZEO.ICache import ICache
magic='ZEC0'
magic = 'ZEC1'
headersize = 12
class ClientCache:
......@@ -126,6 +135,8 @@ class ClientCache:
self._storage = storage
self._limit = size / 2
self._client = client
self._ltid = None # For getLastTid()
# Allocate locks:
L = allocate_lock()
......@@ -154,9 +165,9 @@ class ClientCache:
fi = open(p[i],'r+b')
if fi.read(4) == magic: # Minimal sanity
fi.seek(0, 2)
if fi.tell() > 30:
# First serial is at offset 19 + 4 for magic
fi.seek(23)
if fi.tell() > headersize:
# Read serial at offset 19 of first record
fi.seek(headersize + 19)
s[i] = fi.read(8)
# If we found a non-zero serial, then use the file
if s[i] != '\0\0\0\0\0\0\0\0':
......@@ -172,14 +183,14 @@ class ClientCache:
if f[0] is None:
# We started, open the first cache file
f[0] = open(p[0], 'w+b')
f[0].write(magic)
f[0].write(magic + '\0' * (headersize - len(magic)))
current = 0
f[1] = None
else:
self._f = f = [tempfile.TemporaryFile(suffix='.zec'), None]
# self._p file name 'None' signifies an unnamed temp file.
self._p = p = [None, None]
f[0].write(magic)
f[0].write(magic + '\0' * (headersize - len(magic)))
current = 0
self.log("%s: storage=%r, size=%r; file[%r]=%r" %
......@@ -219,6 +230,57 @@ class ClientCache:
except OSError:
pass
def getLastTid(self):
"""Get the last transaction id stored by setLastTid().
If the cache is persistent, it is read from the current
cache file; otherwise it's an instance variable.
"""
if self._client is None:
return self._ltid
else:
self._acquire()
try:
return self._getLastTid()
finally:
self._release()
def _getLastTid(self):
f = self._f[self._current]
f.seek(4)
tid = f.read(8)
if len(tid) < 8 or tid == '\0\0\0\0\0\0\0\0':
return None
else:
return tid
def setLastTid(self, tid):
"""Store the last transaction id.
If the cache is persistent, it is written to the current
cache file; otherwise it's an instance variable.
"""
if self._client is None:
if tid == '\0\0\0\0\0\0\0\0':
tid = None
self._ltid = tid
else:
self._acquire()
try:
self._setLastTid(tid)
finally:
self._release()
def _setLastTid(self, tid):
if tid is None:
tid = '\0\0\0\0\0\0\0\0'
else:
tid = str(tid)
assert len(tid) == 8
f = self._f[self._current]
f.seek(4)
f.write(tid)
def verify(self, verifyFunc):
"""Call the verifyFunc on every object in the cache.
......@@ -477,6 +539,7 @@ class ClientCache:
self._acquire()
try:
if self._pos + size > self._limit:
ltid = self._getLastTid()
current = not self._current
self._current = current
self._trace(0x70)
......@@ -500,8 +563,12 @@ class ClientCache:
else:
# Temporary cache file:
self._f[current] = tempfile.TemporaryFile(suffix='.zec')
self._f[current].write(magic)
self._pos = 4
header = magic
if ltid:
header += ltid
self._f[current].write(header +
'\0' * (headersize - len(header)))
self._pos = headersize
finally:
self._release()
......@@ -593,7 +660,7 @@ class ClientCache:
f = self._f[fileindex]
seek = f.seek
read = f.read
pos = 4
pos = headersize
count = 0
while 1:
......@@ -652,7 +719,6 @@ class ClientCache:
del serial[oid]
del index[oid]
pos = pos + tlen
count += 1
......
......@@ -22,7 +22,6 @@ ClientDisconnected -- exception raised by ClientStorage
"""
# XXX TO DO
# get rid of beginVerify, set up _tfile in verify_cache
# set self._storage = stub later, in endVerify
# if wait is given, wait until verify is complete
......@@ -60,6 +59,9 @@ class UnrecognizedResult(ClientStorageError):
class ClientDisconnected(ClientStorageError, Disconnected):
"""The database storage is disconnected from the storage."""
def tid2time(tid):
return str(TimeStamp(tid))
def get_timestamp(prev_ts=None):
"""Internal helper to return a unique TimeStamp instance.
......@@ -208,6 +210,8 @@ class ClientStorage:
self._connection = None
# _server_addr is used by sortKey()
self._server_addr = None
self._tfile = None
self._pickler = None
self._info = {'length': 0, 'size': 0, 'name': 'ZEO Client',
'supportsUndo':0, 'supportsVersions': 0,
......@@ -337,12 +341,14 @@ class ClientStorage:
This is called by ConnectionManager after it has decided which
connection should be used.
"""
# XXX would like to report whether we get a read-only connection
if self._connection is not None:
log2(INFO, "Reconnected to storage")
reconnect = 1
else:
log2(INFO, "Connected to storage")
reconnect = 0
self.set_server_addr(conn.get_addr())
stub = self.StorageServerStubClass(conn)
stub = self.StorageServerStubClass(conn)
self._oids = []
self._info.update(stub.get_info())
self.verify_cache(stub)
......@@ -353,6 +359,11 @@ class ClientStorage:
self._connection = conn
self._server = stub
if reconnect:
log2(INFO, "Reconnected to storage: %s" % self._server_addr)
else:
log2(INFO, "Connected to storage: %s" % self._server_addr)
def set_server_addr(self, addr):
# Normalize server address and convert to string
if isinstance(addr, types.StringType):
......@@ -381,12 +392,42 @@ class ClientStorage:
return self._server_addr
def verify_cache(self, server):
"""Internal routine called to verify the cache."""
# XXX beginZeoVerify ends up calling back to beginVerify() below.
# That whole exchange is rather unnecessary.
server.beginZeoVerify()
"""Internal routine called to verify the cache.
The return value (indicating which path we took) is used by
the test suite.
"""
last_inval_tid = self._cache.getLastTid()
if last_inval_tid is not None:
ltid = server.lastTransaction()
if ltid == last_inval_tid:
log2(INFO, "No verification necessary "
"(last_inval_tid up-to-date)")
self._cache.open()
return "no verification"
# log some hints about last transaction
log2(INFO, "last inval tid: %r %s"
% (last_inval_tid, tid2time(last_inval_tid)))
log2(INFO, "last transaction: %r %s" %
(ltid, ltid and tid2time(ltid)))
pair = server.getInvalidations(last_inval_tid)
if pair is not None:
log2(INFO, "Recovering %d invalidations" % len(pair[1]))
self._cache.open()
self.invalidateTransaction(*pair)
return "quick verification"
log2(INFO, "Verifying cache")
# setup tempfile to hold zeoVerify results
self._tfile = tempfile.TemporaryFile(suffix=".inv")
self._pickler = cPickle.Pickler(self._tfile, 1)
self._pickler.fast = 1 # Don't use the memo
self._cache.verify(server.zeoVerify)
server.endZeoVerify()
return "full verification"
### Is there a race condition between notifyConnected and
### notifyDisconnected? In Particular, what if we get
......@@ -402,7 +443,8 @@ class ClientStorage:
This is called by ConnectionManager when the connection is
closed or when certain problems with the connection occur.
"""
log2(PROBLEM, "Disconnected from storage")
log2(PROBLEM, "Disconnected from storage: %s"
% repr(self._server_addr))
self._connection = None
self._server = disconnected_stub
......@@ -644,6 +686,7 @@ class ClientStorage:
self._serial = id
self._seriald.clear()
del self._serials[:]
self._tbuf.clear()
def end_transaction(self):
"""Internal helper to end a transaction."""
......@@ -678,12 +721,13 @@ class ClientStorage:
if f is not None:
f()
self._server.tpc_finish(self._serial)
tid = self._server.tpc_finish(self._serial)
r = self._check_serials()
assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
self._update_cache()
self._cache.setLastTid(tid)
finally:
self.end_transaction()
......@@ -779,12 +823,6 @@ class ClientStorage:
"""Server callback to update the info dictionary."""
self._info.update(dict)
def beginVerify(self):
"""Server callback to signal start of cache validation."""
self._tfile = tempfile.TemporaryFile(suffix=".inv")
self._pickler = cPickle.Pickler(self._tfile, 1)
self._pickler.fast = 1 # Don't use the memo
def invalidateVerify(self, args):
"""Server callback to invalidate an (oid, version) pair.
......@@ -802,6 +840,7 @@ class ClientStorage:
if self._pickler is None:
return
self._pickler.dump((0,0))
self._pickler = None
self._tfile.seek(0)
unpick = cPickle.Unpickler(self._tfile)
f = self._tfile
......@@ -815,29 +854,26 @@ class ClientStorage:
self._db.invalidate(oid, version=version)
f.close()
def invalidateTrans(self, args):
"""Server callback to invalidate a list of (oid, version) pairs.
This is called as the result of a transaction.
"""
def invalidateTransaction(self, tid, args):
"""Invalidate objects modified by tid."""
self._cache.setLastTid(tid)
if self._pickler is not None:
self.log("Transactional invalidation during cache verification",
level=zLOG.BLATHER)
for t in args:
self.self._pickler.dump(t)
return
db = self._db
for oid, version in args:
self._cache.invalidate(oid, version=version)
try:
self._db.invalidate(oid, version=version)
except AttributeError, msg:
log2(PROBLEM,
"Invalidate(%s, %s) failed for _db: %s" % (repr(oid),
repr(version),
msg))
# Unfortunately, the ZEO 2 wire protocol uses different names for
# several of the callback methods invoked by the StorageServer.
# We can't change the wire protocol at this point because that
# would require synchronized updates of clients and servers and we
# don't want that. So here we alias the old names to their new
# implementations.
begin = beginVerify
if db is not None:
db.invalidate(oid, version=version)
# The following are for compatibility with protocol version 2.0.0
def invalidateTrans(self, args):
return self.invalidateTransaction(None, args)
invalidate = invalidateVerify
end = endVerify
Invalidate = invalidateTrans
......
......@@ -44,16 +44,16 @@ class ClientStorage:
self.rpc = rpc
def beginVerify(self):
self.rpc.callAsync('begin')
self.rpc.callAsync('beginVerify')
def invalidateVerify(self, args):
self.rpc.callAsync('invalidate', args)
self.rpc.callAsync('invalidateVerify', args)
def endVerify(self):
self.rpc.callAsync('end')
self.rpc.callAsync('endVerify')
def invalidateTrans(self, args):
self.rpc.callAsync('Invalidate', args)
def invalidateTransaction(self, tid, args):
self.rpc.callAsync('invalidateTransaction', tid, args)
def serialnos(self, arg):
self.rpc.callAsync('serialnos', arg)
......
......@@ -32,6 +32,9 @@ class StorageServer:
zrpc.connection.Connection class.
"""
self.rpc = rpc
if self.rpc.peer_protocol_version == 'Z200':
self.lastTransaction = lambda: None
self.getInvalidations = lambda tid: None
def extensionMethod(self, name):
return ExtensionMethodWrapper(self.rpc, name).call
......@@ -51,8 +54,13 @@ class StorageServer:
def get_info(self):
return self.rpc.call('get_info')
def beginZeoVerify(self):
self.rpc.callAsync('beginZeoVerify')
def lastTransaction(self):
# Not in protocol version 2.0.0; see __init__()
return self.rpc.call('lastTransaction')
def getInvalidations(self, tid):
# Not in protocol version 2.0.0; see __init__()
return self.rpc.call('getInvalidations', tid)
def zeoVerify(self, oid, s, sv):
self.rpc.callAsync('zeoVerify', oid, s, sv)
......
......@@ -37,6 +37,7 @@ from ZODB.POSException import StorageError, StorageTransactionError
from ZODB.POSException import TransactionError, ReadOnlyError
from ZODB.referencesf import referencesf
from ZODB.Transaction import Transaction
from ZODB.utils import u64
_label = "ZSS" # Default label used for logging.
......@@ -68,8 +69,8 @@ class StorageServer:
ZEOStorageClass = None # patched up later
ManagedServerConnectionClass = ManagedServerConnection
def __init__(self, addr, storages, read_only=0):
def __init__(self, addr, storages, read_only=0,
invalidation_queue_size=100):
"""StorageServer constructor.
This is typically invoked from the start.py script.
......@@ -102,13 +103,17 @@ class StorageServer:
self.storages = storages
set_label()
msg = ", ".join(
["%s:%s" % (name, storage.isReadOnly() and "RO" or "RW")
["%s:%s:%s" % (name, storage.isReadOnly() and "RO" or "RW",
storage.getName())
for name, storage in storages.items()])
log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg))
for s in storages.values():
s._waiting = []
self.read_only = read_only
# A list of at most invalidation_queue_size invalidations
self.invq = []
self.invq_bound = invalidation_queue_size
self.connections = {}
self.dispatcher = self.DispatcherClass(addr,
factory=self.new_connection,
......@@ -141,7 +146,7 @@ class StorageServer:
l = self.connections[storage_id] = []
l.append(conn)
def invalidate(self, conn, storage_id, invalidated=(), info=None):
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients.
This is called from several ZEOStorage methods.
......@@ -149,7 +154,7 @@ class StorageServer:
This can do three different things:
- If the invalidated argument is non-empty, it broadcasts
invalidateTrans() messages to all clients of the given
invalidateTransaction() messages to all clients of the given
storage except the current client (the conn argument).
- If the invalidated argument is empty and the info argument
......@@ -158,17 +163,47 @@ class StorageServer:
client.
- If both the invalidated argument and the info argument are
non-empty, it broadcasts invalidateTrans() messages to all
non-empty, it broadcasts invalidateTransaction() messages to all
clients except the current, and sends an info() message to
the current client.
"""
if invalidated:
if len(self.invq) >= self.invq_bound:
del self.invq[0]
self.invq.append((tid, invalidated))
for p in self.connections.get(storage_id, ()):
if invalidated and p is not conn:
p.client.invalidateTrans(invalidated)
p.client.invalidateTransaction(tid, invalidated)
elif info is not None:
p.client.info(info)
def get_invalidations(self, tid):
"""Return a tid and list of all objects invalidation since tid.
The tid is the most recent transaction id committed by the server.
Returns None if it is unable to provide a complete list
of invalidations for tid. In this case, client should
do full cache verification.
"""
if not self.invq:
log("invq empty")
return None, []
earliest_tid = self.invq[0][0]
if earliest_tid > tid:
log("tid to old for invq %s < %s" % (u64(tid), u64(earliest_tid)))
return None, []
oids = {}
for tid, L in self.invq:
for key in L:
oids[key] = 1
latest_tid = self.invq[-1][0]
return latest_tid, oids.keys()
def close_server(self):
"""Close the dispatcher so that there are no new connections.
......@@ -212,10 +247,18 @@ class ZEOStorage:
self.storage_id = "uninitialized"
self.transaction = None
self.read_only = read_only
self.log_label = _label
def notifyConnected(self, conn):
self.connection = conn # For restart_other() below
self.client = self.ClientStorageStubClass(conn)
addr = conn.addr
if isinstance(addr, type("")):
label = addr
else:
host, port = addr
label = str(host) + ":" + str(port)
self.log_label = _label + "/" + label
def notifyDisconnected(self):
# When this storage closes, we must ensure that it aborts
......@@ -237,7 +280,7 @@ class ZEOStorage:
return "<%s %X trans=%s s_trans=%s>" % (name, id(self), tid, stid)
def log(self, msg, level=zLOG.INFO, error=None):
zLOG.LOG("%s:%s" % (_label, self.storage_id), level, msg, error=error)
zLOG.LOG(self.log_label, level, msg, error=error)
def setup_delegation(self):
"""Delegate several methods to the storage"""
......@@ -259,6 +302,7 @@ class ZEOStorage:
for name in fn().keys():
if not hasattr(self,name):
setattr(self, name, getattr(self.storage, name))
self.lastTransaction = self.storage.lastTransaction
def check_tid(self, tid, exc=None):
if self.read_only:
......@@ -286,7 +330,7 @@ class ZEOStorage:
This method must be the first one called by the client.
"""
if self.storage is not None:
log("duplicate register() call")
self.log("duplicate register() call")
raise ValueError, "duplicate register() call"
storage = self.server.storages.get(storage_id)
if storage is None:
......@@ -342,8 +386,13 @@ class ZEOStorage:
raise
return p, s, v, pv, sv
def beginZeoVerify(self):
self.client.beginVerify()
def getInvalidations(self, tid):
invtid, invlist = self.server.get_invalidations(tid)
if invtid is None:
return None
self.log("Return %d invalidations up to tid %s"
% (len(invlist), u64(invtid)))
return invtid, invlist
def zeoVerify(self, oid, s, sv):
try:
......@@ -394,7 +443,8 @@ class ZEOStorage:
self.storage.pack(time, referencesf)
self.log("pack(time=%s) complete" % repr(time))
# Broadcast new size statistics
self.server.invalidate(0, self.storage_id, (), self.get_size_info())
self.server.invalidate(0, self.storage_id, None,
(), self.get_size_info())
def new_oids(self, n=100):
"""Return a sequence of n new oids, where n defaults to 100"""
......@@ -409,7 +459,7 @@ class ZEOStorage:
raise ReadOnlyError()
oids = self.storage.undo(transaction_id)
if oids:
self.server.invalidate(self, self.storage_id,
self.server.invalidate(self, self.storage_id, None,
map(lambda oid: (oid, ''), oids))
return oids
return ()
......@@ -450,12 +500,15 @@ class ZEOStorage:
if not self.check_tid(id):
return
invalidated = self.strategy.tpc_finish()
tid = self.storage.lastTransaction()
if invalidated:
self.server.invalidate(self, self.storage_id,
self.server.invalidate(self, self.storage_id, tid,
invalidated, self.get_size_info())
self.transaction = None
self.strategy = None
# Return the tid, for cache invalidation optimization
self.handle_waiting()
return tid
def tpc_abort(self, id):
if not self.check_tid(id):
......@@ -546,7 +599,8 @@ class ZEOStorage:
old_strategy = self.strategy
assert isinstance(old_strategy, DelayedCommitStrategy)
self.strategy = ImmediateCommitStrategy(self.storage,
self.client)
self.client,
self.log)
resp = old_strategy.restart(self.strategy)
if delay is not None:
delay.reply(resp)
......@@ -602,11 +656,12 @@ class ICommitStrategy:
class ImmediateCommitStrategy:
"""The storage is available so do a normal commit."""
def __init__(self, storage, client):
def __init__(self, storage, client, logmethod):
self.storage = storage
self.client = client
self.invalidated = []
self.serials = []
self.log = logmethod
def tpc_begin(self, txn, tid, status):
self.txn = txn
......@@ -628,11 +683,13 @@ class ImmediateCommitStrategy:
try:
newserial = self.storage.store(oid, serial, data, version,
self.txn)
except (SystemExit, KeyboardInterrupt):
raise
except Exception, err:
if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client
exc_info = sys.exc_info()
log("store error: %s, %s" % exc_info[:2],
self.log("store error: %s, %s" % exc_info[:2],
zLOG.ERROR, error=exc_info)
del exc_info
# Try to pickle the exception. If it can't be pickled,
......@@ -643,7 +700,7 @@ class ImmediateCommitStrategy:
pickler.dump(err, 1)
except:
msg = "Couldn't pickle storage exception: %s" % repr(err)
log(msg, zLOG.ERROR)
self.log(msg, zLOG.ERROR)
err = StorageServerError(msg)
# The exception is reported back as newserial for this oid
newserial = err
......@@ -776,6 +833,8 @@ class SlowMethodThread(threading.Thread):
def run(self):
try:
result = self._method(*self._args)
except (SystemExit, KeyboardInterrupt):
raise
except Exception:
self.delay.error(sys.exc_info())
else:
......
......@@ -117,18 +117,12 @@ def main():
# Must be a misaligned record caused by a crash
##print "Skipping 8 bytes at offset", offset-8
continue
oid = f_read(8)
if len(oid) < 8:
r = f_read(16)
if len(r) < 16:
break
if heuristic and oid[:4] != '\0\0\0\0':
f.seek(-8, 1)
continue
offset += 8
serial = f_read(8)
if len(serial) < 8:
break
offset += 8
offset += 16
records += 1
oid, serial = struct_unpack(">8s8s", r)
# Decode the code
dlen, version, code, current = (code & 0x7fffff00,
code & 0x80,
......
......@@ -153,24 +153,14 @@ def main():
if ts == 0:
# Must be a misaligned record caused by a crash
if not quiet:
print "Skipping 8 bytes at offset", offset-8,
print repr(r)
print "Skipping 8 bytes at offset", offset-8
continue
oid = f_read(8)
if len(oid) < 8:
r = f_read(16)
if len(r) < 16:
break
if heuristic and oid[:4] != '\0\0\0\0':
# Heuristic for severe data corruption
print "Seeking back over bad oid at offset", offset,
print repr(r)
f.seek(-8, 1)
continue
offset += 8
serial = f_read(8)
if len(serial) < 8:
break
offset += 8
offset += 16
records += 1
oid, serial = struct_unpack(">8s8s", r)
if t0 is None:
t0 = ts
thisinterval = t0 / interval
......
......@@ -20,7 +20,9 @@ import select
import socket
import asyncore
import tempfile
import thread # XXX do we really need to catch thread.error
import threading
import time
import zLOG
......@@ -36,9 +38,18 @@ from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle
from ZODB.tests.StorageTestBase import handle_all_serials, ZERO
class TestClientStorage(ClientStorage):
def verify_cache(self, stub):
self.end_verify = threading.Event()
self.verify_result = ClientStorage.verify_cache(self, stub)
def endVerify(self):
ClientStorage.endVerify(self)
self.end_verify.set()
class DummyDB:
def invalidate(self, *args, **kws):
def invalidate(self, *args, **kwargs):
pass
......@@ -48,6 +59,7 @@ class CommonSetupTearDown(StorageTestBase):
__super_setUp = StorageTestBase.setUp
__super_tearDown = StorageTestBase.tearDown
keep = 0
invq = None
def setUp(self):
"""Test setup for connection tests.
......@@ -99,17 +111,15 @@ class CommonSetupTearDown(StorageTestBase):
raise NotImplementedError
def openClientStorage(self, cache='', cache_size=200000, wait=1,
read_only=0, read_only_fallback=0,
addr=None):
if addr is None:
addr = self.addr
storage = ClientStorage(addr,
read_only=0, read_only_fallback=0):
base = TestClientStorage(self.addr,
client=cache,
cache_size=cache_size,
wait=wait,
min_disconnect_poll=0.1,
read_only=read_only,
read_only_fallback=read_only_fallback)
storage = base
storage.registerDB(DummyDB(), None)
return storage
......@@ -121,7 +131,7 @@ class CommonSetupTearDown(StorageTestBase):
path = "%s.%d" % (self.file, index)
conf = self.getConfig(path, create, read_only)
zeoport, adminaddr, pid = forker.start_zeo_server(
conf, addr, ro_svr, self.keep)
conf, addr, ro_svr, self.keep, self.invq)
self._pids.append(pid)
self._servers.append(adminaddr)
......@@ -420,9 +430,9 @@ class ConnectionTests(CommonSetupTearDown):
for t in threads:
t.closeclients()
class ReconnectionTests(CommonSetupTearDown):
keep = 1
invq = 2
def checkReadOnlyStorage(self):
# Open a read-only client to a read-only *storage*; stores fail
......@@ -557,6 +567,113 @@ class ReconnectionTests(CommonSetupTearDown):
else:
self.fail("Couldn't store after starting a read-write server")
def checkNoVerificationOnServerRestart(self):
self._storage = self.openClientStorage()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
self._dostore()
self.shutdownServer()
self.pollDown()
self._storage.verify_result = None
self.startServer(create=0)
self.pollUp()
# There were no transactions committed, so no verification
# should be needed.
self.assertEqual(self._storage.verify_result, "no verification")
def checkNoVerificationOnServerRestartWith2Clients(self):
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
self._storage = self.openClientStorage()
oid = self._storage.new_oid()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
# do two storages of the object to make sure an invalidation
# message is generated
revid = self._dostore(oid)
self._dostore(oid, revid)
perstorage.load(oid, '')
self.shutdownServer()
self.pollDown()
self._storage.verify_result = None
perstorage.verify_result = None
self.startServer(create=0)
self.pollUp()
# There were no transactions committed, so no verification
# should be needed.
self.assertEqual(self._storage.verify_result, "no verification")
perstorage.close()
self.assertEqual(perstorage.verify_result, "no verification")
def checkQuickVerificationWith2Clients(self):
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
self._storage = self.openClientStorage()
oid = self._storage.new_oid()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
# do two storages of the object to make sure an invalidation
# message is generated
revid = self._dostore(oid)
revid = self._dostore(oid, revid)
perstorage.load(oid, '')
perstorage.close()
revid = self._dostore(oid, revid)
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "quick verification")
self.assertEqual(perstorage.load(oid, ''),
self._storage.load(oid, ''))
def checkVerificationWith2ClientsInvqOverflow(self):
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
self._storage = self.openClientStorage()
oid = self._storage.new_oid()
# When we create a new storage, it should always do a full
# verification
self.assertEqual(self._storage.verify_result, "full verification")
# do two storages of the object to make sure an invalidation
# message is generated
revid = self._dostore(oid)
revid = self._dostore(oid, revid)
perstorage.load(oid, '')
perstorage.close()
# the test code sets invq bound to 2
for i in range(5):
revid = self._dostore(oid, revid)
perstorage = self.openClientStorage(cache="test")
self.assertEqual(perstorage.verify_result, "full verification")
t = time.time() + 30
while not perstorage.end_verify.isSet():
perstorage.sync()
if time.time() > t:
self.fail("timed out waiting for endVerify")
self.assertEqual(self._storage.load(oid, '')[1], revid)
self.assertEqual(perstorage.load(oid, ''),
self._storage.load(oid, ''))
perstorage.close()
class MSTThread(threading.Thread):
......
......@@ -51,7 +51,7 @@ def get_port():
raise RuntimeError, "Can't find port"
def start_zeo_server(conf, addr=None, ro_svr=0, keep=0):
def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None):
"""Start a ZEO server in a separate process.
Returns the ZEO port, the test server port, and the pid.
......@@ -77,6 +77,8 @@ def start_zeo_server(conf, addr=None, ro_svr=0, keep=0):
args.append('-r')
if keep:
args.append('-k')
if invq:
args += ['-Q', str(invq)]
args.append(str(port))
d = os.environ.copy()
d['PYTHONPATH'] = os.pathsep.join(sys.path)
......
......@@ -261,6 +261,19 @@ class ClientCacheTests(unittest.TestCase):
self.assert_(None is not cache._index.get(oid1) < 0)
self.assert_(None is not cache._index.get(oid2) < 0)
def testLastTid(self):
cache = self.cache
self.failUnless(cache.getLastTid() is None)
ltid = 'pqrstuvw'
cache.setLastTid(ltid)
self.assertEqual(cache.getLastTid(), ltid)
cache.checkSize(10*self.cachesize) # Force a file flip
self.assertEqual(cache.getLastTid(), ltid)
cache.setLastTid(None)
self.failUnless(cache.getLastTid() is None)
cache.checkSize(10*self.cachesize) # Force a file flip
self.failUnless(cache.getLastTid() is None)
class PersistentClientCacheTests(unittest.TestCase):
def setUp(self):
......@@ -348,6 +361,26 @@ class PersistentClientCacheTests(unittest.TestCase):
self.fail("invalidated data resurrected, size %d, was %d" %
(len(loaded[0]), len(data)))
def testPersistentLastTid(self):
cache = self.cache
self.failUnless(cache.getLastTid() is None)
ltid = 'pqrstuvw'
cache.setLastTid(ltid)
self.assertEqual(cache.getLastTid(), ltid)
oid = 'abcdefgh'
data = '1234'
serial = 'ABCDEFGH'
cache.store(oid, data, serial, '', '', '')
self.assertEqual(cache.getLastTid(), ltid)
cache.checkSize(10*self.cachesize) # Force a file flip
self.assertEqual(cache.getLastTid(), ltid)
cache = self.reopenCache()
self.assertEqual(cache.getLastTid(), ltid)
cache.setLastTid(None)
self.failUnless(cache.getLastTid() is None)
cache.checkSize(10*self.cachesize) # Force a file flip
self.failUnless(cache.getLastTid() is None)
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ClientCacheTests))
......
......@@ -116,8 +116,9 @@ def main():
ro_svr = 0
keep = 0
configfile = None
invalidation_queue_size = 100
# Parse the arguments and let getopt.error percolate
opts, args = getopt.getopt(sys.argv[1:], 'rkC:')
opts, args = getopt.getopt(sys.argv[1:], 'rkC:Q:')
for opt, arg in opts:
if opt == '-r':
ro_svr = 1
......@@ -125,6 +126,8 @@ def main():
keep = 1
elif opt == '-C':
configfile = arg
elif opt == '-Q':
invalidation_queue_size = int(arg)
# Open the config file and let ZConfig parse the data there. Then remove
# the config file, otherwise we'll leave turds.
fp = open(configfile, 'r')
......@@ -145,7 +148,9 @@ def main():
sys.exit(2)
addr = ('', zeo_port)
log(label, 'creating the storage server')
serv = ZEO.StorageServer.StorageServer(addr, {'1': storage}, ro_svr)
serv = ZEO.StorageServer.StorageServer(
addr, {'1': storage}, ro_svr,
invalidation_queue_size=invalidation_queue_size)
log(label, 'entering ThreadedAsync loop')
ThreadedAsync.LoopCallback.loop()
......
......@@ -119,7 +119,7 @@ class ConnectionManager:
# XXX need each connection started with async==0 to have a
# callback
log("CM.set_async(%s)" % repr(map))
log("CM.set_async(%s)" % repr(map), level=zLOG.DEBUG)
if not self.closed and self.trigger is None:
log("CM.set_async(): first call")
self.trigger = trigger()
......@@ -294,6 +294,9 @@ class ConnectThread(threading.Thread):
if success > 0:
break
time.sleep(delay)
if self.mgr.is_connected():
log("CT: still trying to replace fallback connection",
level=zLOG.INFO)
delay = min(delay*2, self.tmax)
log("CT: exiting thread: %s" % self.getName())
......
......@@ -21,7 +21,7 @@ import types
import ThreadedAsync
from ZEO.zrpc import smac
from ZEO.zrpc.error import ZRPCError, DisconnectedError
from ZEO.zrpc.log import log, short_repr
from ZEO.zrpc.log import short_repr, log
from ZEO.zrpc.marshal import Marshaller
from ZEO.zrpc.trigger import trigger
import zLOG
......@@ -115,13 +115,32 @@ class Connection(smac.SizedMessageAsyncConnection):
__super_init = smac.SizedMessageAsyncConnection.__init__
__super_close = smac.SizedMessageAsyncConnection.close
protocol_version = "Z200"
# Protocol variables:
#
# oldest_protocol_version -- the oldest protocol version we support
# protocol_version -- the newest protocol version we support; preferred
oldest_protocol_version = "Z200"
protocol_version = "Z201"
# Protocol history:
#
# Z200 -- original ZEO 2.0 protocol
#
# Z201 -- added invalidateTransaction() to client;
# renamed several client methods;
# added lastTransaction() to server
def __init__(self, sock, addr, obj=None):
self.obj = None
self.marshal = Marshaller()
self.closed = 0
self.msgid = 0
self.peer_protocol_version = None # Set in recv_handshake()
if isinstance(addr, types.TupleType):
self.log_label = "zrpc-conn:%s:%d" % addr
else:
self.log_label = "zrpc-conn:%s" % addr
self.__super_init(sock, addr)
# A Connection either uses asyncore directly or relies on an
# asyncore mainloop running in a separate thread. If
......@@ -147,6 +166,9 @@ class Connection(smac.SizedMessageAsyncConnection):
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=zLOG.BLATHER, error=None):
zLOG.LOG(self.log_label, level, message, error=error)
def close(self):
if self.closed:
return
......@@ -156,7 +178,7 @@ class Connection(smac.SizedMessageAsyncConnection):
self.__super_close()
def close_trigger(self):
# overridden by ManagedConnection
# Overridden by ManagedConnection
if self.trigger is not None:
self.trigger.close()
......@@ -164,7 +186,9 @@ class Connection(smac.SizedMessageAsyncConnection):
"""Register obj as the true object to invoke methods on"""
self.obj = obj
def handshake(self):
def handshake(self, proto=None):
# Overridden by ManagedConnection
# When a connection is created the first message sent is a
# 4-byte protocol version. This mechanism should allow the
# protocol to evolve over time, and let servers handle clients
......@@ -174,17 +198,18 @@ class Connection(smac.SizedMessageAsyncConnection):
# first message received.
# The client sends the protocol version it is using.
self._message_input = self.message_input
self.message_input = self.recv_handshake
self.message_output(self.protocol_version)
def recv_handshake(self, message):
if message == self.protocol_version:
self.message_input = self._message_input
self.message_output(proto or self.protocol_version)
def recv_handshake(self, proto):
# Extended by ManagedConnection
del self.message_input
self.peer_protocol_version = proto
if self.oldest_protocol_version <= proto <= self.protocol_version:
self.log("received handshake %r" % proto, level=zLOG.INFO)
else:
log("recv_handshake: bad handshake %s" % short_repr(message),
level=zLOG.ERROR)
# otherwise do something else...
self.log("bad handshake %s" % short_repr(proto), level=zLOG.ERROR)
raise ZRPCError("bad handshake %r" % proto)
def message_input(self, message):
"""Decoding an incoming message and dispatch it"""
......@@ -195,7 +220,7 @@ class Connection(smac.SizedMessageAsyncConnection):
msgid, flags, name, args = self.marshal.decode(message)
if __debug__:
log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
self.log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
short_repr(args)),
level=zLOG.TRACE)
if name == REPLY:
......@@ -205,8 +230,8 @@ class Connection(smac.SizedMessageAsyncConnection):
def handle_reply(self, msgid, flags, args):
if __debug__:
log("recv reply: %s, %s, %s" % (msgid, flags, short_repr(args)),
level=zLOG.DEBUG)
self.log("recv reply: %s, %s, %s"
% (msgid, flags, short_repr(args)), level=zLOG.DEBUG)
self.replies_cond.acquire()
try:
self.replies[msgid] = flags, args
......@@ -219,7 +244,8 @@ class Connection(smac.SizedMessageAsyncConnection):
msg = "Invalid method name: %s on %s" % (name, repr(self.obj))
raise ZRPCError(msg)
if __debug__:
log("calling %s%s" % (name, short_repr(args)), level=zLOG.BLATHER)
self.log("calling %s%s" % (name, short_repr(args)),
level=zLOG.BLATHER)
meth = getattr(self.obj, name)
try:
......@@ -228,7 +254,7 @@ class Connection(smac.SizedMessageAsyncConnection):
raise
except Exception, msg:
error = sys.exc_info()
log("%s() raised exception: %s" % (name, msg), zLOG.INFO,
self.log("%s() raised exception: %s" % (name, msg), zLOG.INFO,
error=error)
error = error[:2]
return self.return_error(msgid, flags, *error)
......@@ -239,7 +265,7 @@ class Connection(smac.SizedMessageAsyncConnection):
(name, short_repr(ret)))
else:
if __debug__:
log("%s returns %s" % (name, short_repr(ret)), zLOG.DEBUG)
self.log("%s returns %s" % (name, short_repr(ret)), zLOG.DEBUG)
if isinstance(ret, Delay):
ret.set_sender(msgid, self.send_reply, self.return_error)
else:
......@@ -252,7 +278,7 @@ class Connection(smac.SizedMessageAsyncConnection):
self.close()
def log_error(self, msg="No error message supplied"):
log(msg, zLOG.ERROR, error=sys.exc_info())
self.log(msg, zLOG.ERROR, error=sys.exc_info())
def check_method(self, name):
# XXX Is this sufficient "security" for now?
......@@ -304,7 +330,7 @@ class Connection(smac.SizedMessageAsyncConnection):
finally:
self.msgid_lock.release()
if __debug__:
log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
zLOG.TRACE)
buf = self.marshal.encode(msgid, flags, method, args)
self.message_output(buf)
......@@ -342,7 +368,7 @@ class Connection(smac.SizedMessageAsyncConnection):
self.thr_async = 1
def is_async(self):
# overridden for ManagedConnection
# Overridden by ManagedConnection
if self.thr_async:
return 1
else:
......@@ -360,7 +386,7 @@ class Connection(smac.SizedMessageAsyncConnection):
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if __debug__:
log("wait(%d), async=%d" % (msgid, self.is_async()),
self.log("wait(%d), async=%d" % (msgid, self.is_async()),
level=zLOG.TRACE)
if self.is_async():
self._pull_trigger()
......@@ -378,8 +404,8 @@ class Connection(smac.SizedMessageAsyncConnection):
if reply is not None:
del self.replies[msgid]
if __debug__:
log("wait(%d): reply=%s" % (msgid, short_repr(reply)),
level=zLOG.DEBUG)
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=zLOG.DEBUG)
return reply
if self.is_async():
self.replies_cond.wait(10.0)
......@@ -388,14 +414,14 @@ class Connection(smac.SizedMessageAsyncConnection):
try:
try:
if __debug__:
log("wait(%d): asyncore.poll(%s)" %
self.log("wait(%d): asyncore.poll(%s)" %
(msgid, delay), level=zLOG.TRACE)
asyncore.poll(delay, self._map)
if delay < 1.0:
delay += delay
except select.error, err:
log("Closing. asyncore.poll() raised %s." % err,
level=zLOG.BLATHER)
self.log("Closing. asyncore.poll() raised %s."
% err, level=zLOG.BLATHER)
self.close()
finally:
self.replies_cond.acquire()
......@@ -405,7 +431,7 @@ class Connection(smac.SizedMessageAsyncConnection):
def poll(self):
"""Invoke asyncore mainloop to get pending message out."""
if __debug__:
log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE)
self.log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE)
if self.is_async():
self._pull_trigger()
else:
......@@ -414,7 +440,7 @@ class Connection(smac.SizedMessageAsyncConnection):
def pending(self):
"""Invoke mainloop until any pending messages are handled."""
if __debug__:
log("pending(), async=%d" % self.is_async(), level=zLOG.TRACE)
self.log("pending(), async=%d" % self.is_async(), level=zLOG.TRACE)
if self.is_async():
return
# Inline the asyncore poll() function to know whether any input
......@@ -465,6 +491,64 @@ class ManagedConnection(Connection):
self.__super_init(sock, addr, obj)
self.check_mgr_async()
# PROTOCOL NEGOTIATION:
#
# The code implementing protocol version 2.0.0 (which is deployed
# in the field and cannot be changed) *only* talks to peers that
# send a handshake indicating protocol version 2.0.0. In that
# version, both the client and the server immediately send out
# their protocol handshake when a connection is established,
# without waiting for their peer, and disconnect when a different
# handshake is receive.
#
# The new protocol uses this to enable new clients to talk to
# 2.0.0 servers: in the new protocol, the client waits until it
# receives the server's protocol handshake before sending its own
# handshake. The client sends the lower of its own protocol
# version and the server protocol version, allowing it to talk to
# servers using later protocol versions (2.0.2 and higher) as
# well: the effective protocol used will be the lower of the
# client and server protocol.
#
# The ZEO modules ClientStorage and ServerStub have backwards
# compatibility code for dealing with the previous version of the
# protocol. The client accept the old version of some messages,
# and will not send new messages when talking to an old server.
#
# As long as the client hasn't sent its handshake, it can't send
# anything else; output messages are queued during this time.
# (Output can happen because the connection testing machinery can
# start sending requests before the handshake is received.)
#
# UPGRADING FROM ZEO 2.0.0 TO NEWER VERSIONS:
#
# Because a new client can talk to an old server, but not vice
# versa, all clients should be upgraded before upgrading any
# servers. Protocol upgrades beyond 2.0.1 will not have this
# restriction, because clients using protocol 2.0.1 or later can
# talk to both older and newer servers.
#
# No compatibility with protocol version 1 is provided.
def handshake(self):
self.message_input = self.recv_handshake
self.message_output = self.queue_output
self.output_queue = []
# The handshake is sent by recv_handshake() below
def queue_output(self, message):
self.output_queue.append(message)
def recv_handshake(self, proto):
del self.message_output
proto = min(proto, self.protocol_version)
Connection.recv_handshake(self, proto) # Raise error if wrong proto
self.message_output(proto)
queue = self.output_queue
del self.output_queue
for message in queue:
self.message_output(message)
# Defer the ThreadedAsync work to the manager.
def close_trigger(self):
......
......@@ -16,6 +16,7 @@ import asyncore
import os
import socket
import thread
import errno
if os.name == 'posix':
......
......@@ -12,11 +12,9 @@
#
##############################################################################
"""Handy standard storage machinery
"""
# Do this portably in the face of checking out with -kv
import string
__version__ = string.split('$Revision: 1.30 $')[-2:][0]
$Id: BaseStorage.py,v 1.31 2003/01/03 22:07:43 jeremy Exp $
"""
import cPickle
import ThreadLock, bpthread
import time, UndoLogCompatible
......@@ -277,8 +275,8 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible):
restoring = 1
else:
restoring = 0
for transaction in other.iterator():
fiter = other.iterator()
for transaction in fiter:
tid=transaction.tid
if _ts is None:
_ts=TimeStamp(tid)
......@@ -313,6 +311,8 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible):
self.tpc_vote(transaction)
self.tpc_finish(transaction)
fiter.close()
class TransactionRecord:
"""Abstract base class for iterator protocol"""
......
......@@ -115,7 +115,7 @@
# may have a back pointer to a version record or to a non-version
# record.
#
__version__='$Revision: 1.123 $'[11:-2]
__version__='$Revision: 1.124 $'[11:-2]
import base64
from cPickle import Pickler, Unpickler, loads
......@@ -124,7 +124,7 @@ import os
import struct
import sys
import time
from types import StringType
from types import StringType, DictType
from struct import pack, unpack
try:
......@@ -137,7 +137,12 @@ from ZODB.POSException import UndoError, POSKeyError, MultipleUndoErrors
from ZODB.TimeStamp import TimeStamp
from ZODB.lock_file import lock_file
from ZODB.utils import p64, u64, cp, z64
from ZODB.fsIndex import fsIndex
try:
from ZODB.fsIndex import fsIndex
except ImportError:
def fsIndex():
return {}
from zLOG import LOG, BLATHER, WARNING, ERROR, PANIC
......@@ -203,6 +208,8 @@ class FileStorage(BaseStorage.BaseStorage,
# default pack time is 0
_packt = z64
_records_before_save = 10000
def __init__(self, file_name, create=0, read_only=0, stop=None,
quota=None):
......@@ -270,7 +277,9 @@ class FileStorage(BaseStorage.BaseStorage,
r = self._restore_index()
if r is not None:
self._used_index = 1 # Marker for testing
index, vindex, start, maxoid, ltid = r
self._initIndex(index, vindex, tindex, tvindex)
self._pos, self._oid, tid = read_index(
self._file, file_name, index, vindex, tindex, stop,
......@@ -278,10 +287,15 @@ class FileStorage(BaseStorage.BaseStorage,
read_only=read_only,
)
else:
self._used_index = 0 # Marker for testing
self._pos, self._oid, tid = read_index(
self._file, file_name, index, vindex, tindex, stop,
read_only=read_only,
)
self._save_index()
self._records_before_save = max(self._records_before_save,
len(self._index))
self._ltid = tid
# self._pos should always point just past the last
......@@ -314,6 +328,7 @@ class FileStorage(BaseStorage.BaseStorage,
# hook to use something other than builtin dict
return fsIndex(), {}, {}, {}
_saved = 0
def _save_index(self):
"""Write the database index to a file to support quick startup."""
......@@ -329,6 +344,7 @@ class FileStorage(BaseStorage.BaseStorage,
p.dump(info)
f.flush()
f.close()
try:
try:
os.remove(index_name)
......@@ -337,6 +353,8 @@ class FileStorage(BaseStorage.BaseStorage,
os.rename(tmp_name, index_name)
except: pass
self._saved += 1
def _clear_index(self):
index_name = self.__name__ + '.index'
if os.path.exists(index_name):
......@@ -354,58 +372,77 @@ class FileStorage(BaseStorage.BaseStorage,
object positions cause zero to be returned.
"""
if pos < 100: return 0
file=self._file
seek=file.seek
read=file.read
if pos < 100:
return 0 # insane
file = self._file
seek = file.seek
read = file.read
seek(0,2)
if file.tell() < pos: return 0
ltid=None
if file.tell() < pos:
return 0 # insane
ltid = None
while 1:
max_checked = 5
checked = 0
while checked < max_checked:
seek(pos-8)
rstl=read(8)
tl=u64(rstl)
pos=pos-tl-8
if pos < 4: return 0
rstl = read(8)
tl = u64(rstl)
pos = pos-tl-8
if pos < 4:
return 0 # insane
seek(pos)
s = read(TRANS_HDR_LEN)
tid, stl, status, ul, dl, el = unpack(TRANS_HDR, s)
if not ltid: ltid=tid
if stl != rstl: return 0 # inconsistent lengths
if status == 'u': continue # undone trans, search back
if status not in ' p': return 0
if tl < (TRANS_HDR_LEN + ul + dl + el): return 0
tend=pos+tl
opos=pos+(TRANS_HDR_LEN + ul + dl + el)
if opos==tend: continue # empty trans
while opos < tend:
if not ltid:
ltid = tid
if stl != rstl:
return 0 # inconsistent lengths
if status == 'u':
continue # undone trans, search back
if status not in ' p':
return 0 # insane
if tl < (TRANS_HDR_LEN + ul + dl + el):
return 0 # insane
tend = pos+tl
opos = pos+(TRANS_HDR_LEN + ul + dl + el)
if opos == tend:
continue # empty trans
while opos < tend and checked < max_checked:
# Read the data records for this transaction
seek(opos)
h=read(DATA_HDR_LEN)
oid,serial,sprev,stloc,vlen,splen = unpack(DATA_HDR, h)
tloc=u64(stloc)
plen=u64(splen)
h = read(DATA_HDR_LEN)
oid, serial, sprev, stloc, vlen, splen = unpack(DATA_HDR, h)
tloc = u64(stloc)
plen = u64(splen)
dlen=DATA_HDR_LEN+(plen or 8)
if vlen: dlen=dlen+(16+vlen)
dlen = DATA_HDR_LEN+(plen or 8)
if vlen:
dlen = dlen+(16+vlen)
if opos+dlen > tend or tloc != pos: return 0
if opos+dlen > tend or tloc != pos:
return 0 # insane
if index.get(oid, 0) != opos: return 0
if index.get(oid, 0) != opos:
return 0 # insane
opos=opos+dlen
checked += 1
opos = opos+dlen
return ltid
def _restore_index(self):
"""Load database index to support quick startup."""
try:
f = open("%s.index" % self.__name__, 'rb')
except:
return None
p = Unpickler(f)
file_name=self.__name__
index_name=file_name+'.index'
try: f=open(index_name,'rb')
except: return None
p=Unpickler(f)
try:
info=p.load()
......@@ -422,6 +459,23 @@ class FileStorage(BaseStorage.BaseStorage,
return None
pos = long(pos)
if isinstance(index, DictType) and not self._is_read_only:
# Convert to fsIndex
newindex = fsIndex()
if type(newindex) is not type(index):
# And we have fsIndex
newindex.update(index)
# Now save the index
f = open(index_name, 'wb')
p = Pickler(f, 1)
info['index'] = newindex
p.dump(info)
f.close()
# Now call this method again to get the new data
return self._restore_index()
tid = self._sane(index, pos)
if not tid:
return None
......@@ -955,6 +1009,9 @@ class FileStorage(BaseStorage.BaseStorage,
finally:
self._lock_release()
# Keep track of the number of records that we've written
_records_written = 0
def _finish(self, tid, u, d, e):
nextpos=self._nextpos
if nextpos:
......@@ -967,10 +1024,20 @@ class FileStorage(BaseStorage.BaseStorage,
if fsync is not None: fsync(file.fileno())
self._pos=nextpos
self._pos = nextpos
self._index.update(self._tindex)
self._vindex.update(self._tvindex)
# Update the number of records that we've written
# +1 for the transaction record
self._records_written += len(self._tindex) + 1
if self._records_written >= self._records_before_save:
self._save_index()
self._records_written = 0
self._records_before_save = max(self._records_before_save,
len(self._index))
self._ltid = tid
def _abort(self):
......@@ -1210,7 +1277,8 @@ class FileStorage(BaseStorage.BaseStorage,
while pos < tend:
self._file.seek(pos)
h = self._file.read(DATA_HDR_LEN)
oid, serial, sprev, stloc, vlen, splen = struct.unpack(DATA_HDR, h)
oid, serial, sprev, stloc, vlen, splen = \
struct.unpack(DATA_HDR, h)
if failed(oid):
del failures[oid] # second chance!
plen = u64(splen)
......@@ -1966,6 +2034,7 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8,
id in the data. The transaction id is the tid of the last
transaction.
"""
read = file.read
seek = file.seek
seek(0, 2)
......@@ -2001,7 +2070,7 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8,
if tid <= ltid:
warn("%s time-stamp reduction at %s", name, pos)
ltid=tid
ltid = tid
tl=u64(stl)
......@@ -2074,7 +2143,12 @@ def read_index(file, name, index, vindex, tindex, stop='\377'*8,
if vlen:
dlen=dlen+(16+vlen)
read(16)
pv=u64(read(8))
version=read(vlen)
# Jim says: "It's just not worth the bother."
#if vndexpos(version, 0) != pv:
# panic("%s incorrect previous version pointer at %s",
# name, pos)
vindex[version]=pos
if pos+dlen > tend or tloc != tpos:
......@@ -2223,8 +2297,11 @@ class FileIterator(Iterator):
self._stop = stop
def __len__(self):
# This is a lie. It's here only for Python 2.1 support for
# list()-ifying these objects.
# Define a bogus __len__() to make the iterator work
# with code like builtin list() and tuple() in Python 2.1.
# There's a lot of C code that expects a sequence to have
# an __len__() but can cope with any sort of mistake in its
# implementation. So just return 0.
return 0
def close(self):
......@@ -2362,7 +2439,6 @@ class FileIterator(Iterator):
class RecordIterator(Iterator, BaseStorage.TransactionRecord):
"""Iterate over the transactions in a FileStorage file."""
def __init__(self, tid, status, user, desc, ext, pos, tend, file, tpos):
self.tid = tid
self.status = status
......
......@@ -38,6 +38,7 @@ class IteratorCompare:
eq(zodb_unpickle(rec.data), MinPO(val))
val = val + 1
eq(val, val0 + len(revids))
txniter.close()
class IteratorStorage(IteratorCompare):
......@@ -191,3 +192,5 @@ class IteratorDeepCompare:
# they were the same length
self.assertRaises(IndexError, iter1.next)
self.assertRaises(IndexError, iter2.next)
iter1.close()
iter2.close()
......@@ -72,6 +72,101 @@ class FileStorageTests(
else:
self.fail("expect long user field to raise error")
def check_use_fsIndex(self):
from ZODB.fsIndex import fsIndex
self.assertEqual(self._storage._index.__class__, fsIndex)
# XXX We could really use some tests for sanity checking
def check_conversion_to_fsIndex_not_if_readonly(self):
self.tearDown()
class OldFileStorage(ZODB.FileStorage.FileStorage):
def _newIndexes(self):
return {}, {}, {}, {}
from ZODB.fsIndex import fsIndex
# Hack FileStorage to create dictionary indexes
self._storage = OldFileStorage('FileStorageTests.fs')
self.assertEqual(type(self._storage._index), type({}))
for i in range(10):
self._dostore()
# Should save the index
self._storage.close()
self._storage = ZODB.FileStorage.FileStorage(
'FileStorageTests.fs', read_only=1)
self.assertEqual(type(self._storage._index), type({}))
def check_conversion_to_fsIndex(self):
self.tearDown()
class OldFileStorage(ZODB.FileStorage.FileStorage):
def _newIndexes(self):
return {}, {}, {}, {}
from ZODB.fsIndex import fsIndex
# Hack FileStorage to create dictionary indexes
self._storage = OldFileStorage('FileStorageTests.fs')
self.assertEqual(type(self._storage._index), type({}))
for i in range(10):
self._dostore()
oldindex = self._storage._index.copy()
# Should save the index
self._storage.close()
self._storage = ZODB.FileStorage.FileStorage('FileStorageTests.fs')
self.assertEqual(self._storage._index.__class__, fsIndex)
self.failUnless(self._storage._used_index)
index = {}
for k, v in self._storage._index.items():
index[k] = v
self.assertEqual(index, oldindex)
def check_save_after_load_with_no_index(self):
for i in range(10):
self._dostore()
self._storage.close()
os.remove('FileStorageTests.fs.index')
self.open()
self.assertEqual(self._storage._saved, 1)
# This would make the unit tests too slow
# check_save_after_load_that_worked_hard(self)
def check_periodic_save_index(self):
# Check the basic algorithm
oldsaved = self._storage._saved
self._storage._records_before_save = 10
for i in range(4):
self._dostore()
self.assertEqual(self._storage._saved, oldsaved)
self._dostore()
self.assertEqual(self._storage._saved, oldsaved+1)
# Now make sure the parameter changes as we get bigger
for i in range(20):
self._dostore()
self.failUnless(self._storage._records_before_save > 20)
class FileStorageRecoveryTest(
StorageTestBase.StorageTestBase,
RecoveryStorage.RecoveryStorage,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment