Commit 4d86e4e0 authored by Jeremy Hylton's avatar Jeremy Hylton

Merge ZODB 3.1 changes to the trunk.

XXX Not sure if berkeley still works.
parent e650766b
...@@ -28,9 +28,11 @@ ClientDisconnected -- exception raised by ClientStorage ...@@ -28,9 +28,11 @@ ClientDisconnected -- exception raised by ClientStorage
import cPickle import cPickle
import os import os
import socket
import tempfile import tempfile
import threading import threading
import time import time
import types
from ZEO import ClientCache, ServerStub from ZEO import ClientCache, ServerStub
from ZEO.TransactionBuffer import TransactionBuffer from ZEO.TransactionBuffer import TransactionBuffer
...@@ -204,6 +206,8 @@ class ClientStorage: ...@@ -204,6 +206,8 @@ class ClientStorage:
self._storage = storage self._storage = storage
self._read_only_fallback = read_only_fallback self._read_only_fallback = read_only_fallback
self._connection = None self._connection = None
# _server_addr is used by sortKey()
self._server_addr = None
self._info = {'length': 0, 'size': 0, 'name': 'ZEO Client', self._info = {'length': 0, 'size': 0, 'name': 'ZEO Client',
'supportsUndo':0, 'supportsVersions': 0, 'supportsUndo':0, 'supportsVersions': 0,
...@@ -339,6 +343,7 @@ class ClientStorage: ...@@ -339,6 +343,7 @@ class ClientStorage:
log2(INFO, "Reconnected to storage") log2(INFO, "Reconnected to storage")
else: else:
log2(INFO, "Connected to storage") log2(INFO, "Connected to storage")
self.set_server_addr(conn.get_addr())
stub = self.StorageServerStubClass(conn) stub = self.StorageServerStubClass(conn)
self._oids = [] self._oids = []
self._info.update(stub.get_info()) self._info.update(stub.get_info())
...@@ -350,6 +355,33 @@ class ClientStorage: ...@@ -350,6 +355,33 @@ class ClientStorage:
self._connection = conn self._connection = conn
self._server = stub self._server = stub
def set_server_addr(self, addr):
# Normalize server address and convert to string
if isinstance(addr, types.StringType):
self._server_addr = addr
else:
assert isinstance(addr, types.TupleType)
# If the server is on a remote host, we need to guarantee
# that all clients used the same name for the server. If
# they don't, the sortKey() may be different for each client.
# The best solution seems to be the official name reported
# by gethostbyaddr().
host = addr[0]
try:
canonical, aliases, addrs = socket.gethostbyaddr(host)
except socket.error, err:
log2(BLATHER, "Error resoving host: %s (%s)" % (host, err))
canonical = host
self._server_addr = str((canonical, addr[1]))
def sortKey(self):
# If the client isn't connected to anything, it can't have a
# valid sortKey(). Raise an error to stop the transaction early.
if self._server_addr is None:
raise ClientDisconnected
else:
return self._server_addr
def verify_cache(self, server): def verify_cache(self, server):
"""Internal routine called to verify the cache.""" """Internal routine called to verify the cache."""
# XXX beginZeoVerify ends up calling back to beginVerify() below. # XXX beginZeoVerify ends up calling back to beginVerify() below.
...@@ -622,11 +654,15 @@ class ClientStorage: ...@@ -622,11 +654,15 @@ class ClientStorage:
"""Internal helper to end a transaction.""" """Internal helper to end a transaction."""
# the right way to set self._transaction to None # the right way to set self._transaction to None
# calls notify() on _tpc_cond in case there are waiting threads # calls notify() on _tpc_cond in case there are waiting threads
self._ltid = self._serial
self._tpc_cond.acquire() self._tpc_cond.acquire()
self._transaction = None self._transaction = None
self._tpc_cond.notify() self._tpc_cond.notify()
self._tpc_cond.release() self._tpc_cond.release()
def lastTransaction(self):
return self._ltid
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Storage API: abort a transaction.""" """Storage API: abort a transaction."""
if transaction is not self._transaction: if transaction is not self._transaction:
......
...@@ -206,17 +206,16 @@ class ZEOStorage: ...@@ -206,17 +206,16 @@ class ZEOStorage:
def __init__(self, server, read_only=0): def __init__(self, server, read_only=0):
self.server = server self.server = server
self.connection = None
self.client = None self.client = None
self.storage = None self.storage = None
self.storage_id = "uninitialized" self.storage_id = "uninitialized"
self.transaction = None self.transaction = None
self.read_only = read_only self.read_only = read_only
self.timeout = TimeoutThread()
self.timeout.start()
def notifyConnected(self, conn): def notifyConnected(self, conn):
self.connection = conn # For restart_other() below
self.client = self.ClientStorageStubClass(conn) self.client = self.ClientStorageStubClass(conn)
self.timeout.notifyConnected(conn)
def notifyDisconnected(self): def notifyDisconnected(self):
# When this storage closes, we must ensure that it aborts # When this storage closes, we must ensure that it aborts
...@@ -226,7 +225,6 @@ class ZEOStorage: ...@@ -226,7 +225,6 @@ class ZEOStorage:
self.abort() self.abort()
else: else:
self.log("disconnected") self.log("disconnected")
self.timeout.notifyDisconnected()
def __repr__(self): def __repr__(self):
tid = self.transaction and repr(self.transaction.id) tid = self.transaction and repr(self.transaction.id)
...@@ -416,11 +414,6 @@ class ZEOStorage: ...@@ -416,11 +414,6 @@ class ZEOStorage:
" requests from one client.") " requests from one client.")
# (This doesn't require a lock because we're using asyncore) # (This doesn't require a lock because we're using asyncore)
if self.storage._transaction is None:
self.strategy = self.ImmediateCommitStrategyClass(self.storage,
self.client)
self.timeout.begin()
else:
self.strategy = self.DelayedCommitStrategyClass(self.storage, self.strategy = self.DelayedCommitStrategyClass(self.storage,
self.wait) self.wait)
...@@ -436,7 +429,6 @@ class ZEOStorage: ...@@ -436,7 +429,6 @@ class ZEOStorage:
def tpc_finish(self, id): def tpc_finish(self, id):
if not self.check_tid(id): if not self.check_tid(id):
return return
self.timeout.end()
invalidated = self.strategy.tpc_finish() invalidated = self.strategy.tpc_finish()
if invalidated: if invalidated:
self.server.invalidate(self, self.storage_id, self.server.invalidate(self, self.storage_id,
...@@ -448,7 +440,6 @@ class ZEOStorage: ...@@ -448,7 +440,6 @@ class ZEOStorage:
def tpc_abort(self, id): def tpc_abort(self, id):
if not self.check_tid(id): if not self.check_tid(id):
return return
self.timeout.end()
strategy = self.strategy strategy = self.strategy
strategy.tpc_abort() strategy.tpc_abort()
self.transaction = None self.transaction = None
...@@ -469,9 +460,7 @@ class ZEOStorage: ...@@ -469,9 +460,7 @@ class ZEOStorage:
def vote(self, id): def vote(self, id):
self.check_tid(id, exc=StorageTransactionError) self.check_tid(id, exc=StorageTransactionError)
r = self.strategy.tpc_vote() return self.strategy.tpc_vote()
self.timeout.begin()
return r
def abortVersion(self, src, id): def abortVersion(self, src, id):
self.check_tid(id, exc=StorageTransactionError) self.check_tid(id, exc=StorageTransactionError)
...@@ -503,8 +492,10 @@ class ZEOStorage: ...@@ -503,8 +492,10 @@ class ZEOStorage:
"Clients waiting: %d." % len(self.storage._waiting)) "Clients waiting: %d." % len(self.storage._waiting))
return d return d
else: else:
self.restart() return self.restart()
return None
def dontwait(self):
return self.restart()
def handle_waiting(self): def handle_waiting(self):
while self.storage._waiting: while self.storage._waiting:
...@@ -526,7 +517,7 @@ class ZEOStorage: ...@@ -526,7 +517,7 @@ class ZEOStorage:
except: except:
self.log("Unexpected error handling waiting transaction", self.log("Unexpected error handling waiting transaction",
level=zLOG.WARNING, error=sys.exc_info()) level=zLOG.WARNING, error=sys.exc_info())
zeo_storage._conn.close() zeo_storage.connection.close()
return 0 return 0
else: else:
return 1 return 1
...@@ -539,6 +530,8 @@ class ZEOStorage: ...@@ -539,6 +530,8 @@ class ZEOStorage:
resp = old_strategy.restart(self.strategy) resp = old_strategy.restart(self.strategy)
if delay is not None: if delay is not None:
delay.reply(resp) delay.reply(resp)
else:
return resp
# A ZEOStorage instance can use different strategies to commit a # A ZEOStorage instance can use different strategies to commit a
# transaction. The current implementation uses different strategies # transaction. The current implementation uses different strategies
...@@ -768,79 +761,6 @@ class SlowMethodThread(threading.Thread): ...@@ -768,79 +761,6 @@ class SlowMethodThread(threading.Thread):
else: else:
self.delay.reply(result) self.delay.reply(result)
class TimeoutThread(threading.Thread):
# A TimeoutThread is associated with a ZEOStorage. It trackes
# how long transactions take to commit. If a transaction takes
# too long, it will close the connection.
TIMEOUT = 30
def __init__(self):
threading.Thread.__init__(self)
self._lock = threading.Lock()
self._timestamp = None
self._conn = None
def begin(self):
self._lock.acquire()
try:
self._timestamp = time.time()
finally:
self._lock.release()
def end(self):
self._lock.acquire()
try:
self._timestamp = None
finally:
self._lock.release()
# There's a race here, but I hope it is harmless.
def notifyConnected(self, conn):
self._conn = conn
def notifyDisconnected(self):
self._conn = None
def run(self):
timeout = self.TIMEOUT
while self._conn is not None:
time.sleep(timeout)
self._lock.acquire()
try:
if self._timestamp is not None:
deadline = self._timestamp + self.TIMEOUT
else:
log("TimeoutThread no current transaction",
zLOG.BLATHER)
timeout = self.TIMEOUT
continue
finally:
self._lock.release()
timeout = deadline - time.time()
if deadline < time.time():
self._abort()
break
else:
elapsed = self.TIMEOUT - timeout
log("TimeoutThread transaction has %0.2f sec to complete"
" (%.2f elapsed)" % (timeout, elapsed), zLOG.BLATHER)
log("TimeoutThread exiting. Connection closed.", zLOG.BLATHER)
def _abort(self):
# It's possible for notifyDisconnected to remove the connection
# just before we use it. I think that's harmless, since it means
# the connection was closed.
log("TimeoutThread aborting transaction", zLOG.WARNING)
try:
self._conn.close()
except AttributeError, msg:
log(msg)
# Patch up class references # Patch up class references
StorageServer.ZEOStorageClass = ZEOStorage StorageServer.ZEOStorageClass = ZEOStorage
ZEOStorage.DelayedCommitStrategyClass = DelayedCommitStrategy ZEOStorage.DelayedCommitStrategyClass = DelayedCommitStrategy
......
...@@ -17,6 +17,8 @@ from __future__ import nested_scopes ...@@ -17,6 +17,8 @@ from __future__ import nested_scopes
import sys, os, getopt import sys, os, getopt
import types import types
import errno
import socket
def directory(p, n=1): def directory(p, n=1):
d = p d = p
......
...@@ -124,9 +124,12 @@ class CommitLockTests: ...@@ -124,9 +124,12 @@ class CommitLockTests:
# started, but before it finishes. The dowork() function # started, but before it finishes. The dowork() function
# executes after the first transaction has completed. # executes after the first transaction has completed.
# Start on transaction normally. # Start on transaction normally and get the lock.
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oid = self._storage.new_oid()
self._storage.store(oid, ZERO, zodb_pickle(MinPO(1)), '', t)
self._storage.tpc_vote(t)
# Start a second transaction on a different connection without # Start a second transaction on a different connection without
# blocking the test thread. # blocking the test thread.
...@@ -141,9 +144,6 @@ class CommitLockTests: ...@@ -141,9 +144,6 @@ class CommitLockTests:
else: else:
self._storages.append((storage2, t2)) self._storages.append((storage2, t2))
oid = self._storage.new_oid()
self._storage.store(oid, ZERO, zodb_pickle(MinPO(1)), '', t)
self._storage.tpc_vote(t)
if method_name == "tpc_finish": if method_name == "tpc_finish":
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
self._storage.load(oid, '') self._storage.load(oid, '')
......
...@@ -348,13 +348,22 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -348,13 +348,22 @@ class Connection(smac.SizedMessageAsyncConnection):
else: else:
return 0 return 0
def _pull_trigger(self, tryagain=10):
try:
self.trigger.pull_trigger()
except OSError, e:
self.trigger.close()
self.trigger = trigger()
if tryagain > 0:
self._pull_trigger(tryagain=tryagain-1)
def wait(self, msgid): def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply.""" """Invoke asyncore mainloop and wait for reply."""
if __debug__: if __debug__:
log("wait(%d), async=%d" % (msgid, self.is_async()), log("wait(%d), async=%d" % (msgid, self.is_async()),
level=zLOG.TRACE) level=zLOG.TRACE)
if self.is_async(): if self.is_async():
self.trigger.pull_trigger() self._pull_trigger()
# Delay used when we call asyncore.poll() directly. # Delay used when we call asyncore.poll() directly.
# Start with a 1 msec delay, double until 1 sec. # Start with a 1 msec delay, double until 1 sec.
...@@ -398,7 +407,7 @@ class Connection(smac.SizedMessageAsyncConnection): ...@@ -398,7 +407,7 @@ class Connection(smac.SizedMessageAsyncConnection):
if __debug__: if __debug__:
log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE) log("poll(), async=%d" % self.is_async(), level=zLOG.TRACE)
if self.is_async(): if self.is_async():
self.trigger.pull_trigger() self._pull_trigger()
else: else:
asyncore.poll(0.0, self._map) asyncore.poll(0.0, self._map)
......
...@@ -77,6 +77,9 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -77,6 +77,9 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
self.__closed = 0 self.__closed = 0
self.__super_init(sock, map) self.__super_init(sock, map)
def get_addr(self):
return self.addr
# XXX avoid expensive getattr calls? Can't remember exactly what # XXX avoid expensive getattr calls? Can't remember exactly what
# this comment was supposed to mean, but it has something to do # this comment was supposed to mean, but it has something to do
# with the way asyncore uses getattr and uses if sock: # with the way asyncore uses getattr and uses if sock:
......
...@@ -16,6 +16,7 @@ import asyncore ...@@ -16,6 +16,7 @@ import asyncore
import os import os
import socket import socket
import thread import thread
import errno
if os.name == 'posix': if os.name == 'posix':
...@@ -71,6 +72,7 @@ if os.name == 'posix': ...@@ -71,6 +72,7 @@ if os.name == 'posix':
self.del_channel() self.del_channel()
for fd in self._fds: for fd in self._fds:
os.close(fd) os.close(fd)
self._fds = []
def __repr__(self): def __repr__(self):
return '<select-trigger (pipe) at %x>' % id(self) return '<select-trigger (pipe) at %x>' % id(self)
...@@ -84,6 +86,9 @@ if os.name == 'posix': ...@@ -84,6 +86,9 @@ if os.name == 'posix':
def handle_connect(self): def handle_connect(self):
pass pass
def handle_close(self):
self.close()
def pull_trigger(self, thunk=None): def pull_trigger(self, thunk=None):
if thunk: if thunk:
self.lock.acquire() self.lock.acquire()
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
""" """
# Do this portably in the face of checking out with -kv # Do this portably in the face of checking out with -kv
import string import string
__version__ = string.split('$Revision: 1.27 $')[-2:][0] __version__ = string.split('$Revision: 1.28 $')[-2:][0]
import cPickle import cPickle
import ThreadLock, bpthread import ThreadLock, bpthread
...@@ -63,6 +63,15 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible): ...@@ -63,6 +63,15 @@ class BaseStorage(UndoLogCompatible.UndoLogCompatible):
def close(self): def close(self):
pass pass
def sortKey(self):
"""Return a string that can be used to sort storage instances.
The key must uniquely identify a storage and must be the same
across multiple instantiations of the same storage.
"""
# name may not be sufficient, e.g. ZEO has a user-definable name.
return self.__name__
def getName(self): def getName(self):
return self.__name__ return self.__name__
......
...@@ -78,7 +78,7 @@ def load_class(class_tuple): ...@@ -78,7 +78,7 @@ def load_class(class_tuple):
except (ImportError, AttributeError): except (ImportError, AttributeError):
zLOG.LOG("Conflict Resolution", zLOG.BLATHER, zLOG.LOG("Conflict Resolution", zLOG.BLATHER,
"Unable to load class", error=sys.exc_info()) "Unable to load class", error=sys.exc_info())
bad_class[class_tuple] = 1 bad_classes[class_tuple] = 1
return None return None
return klass return klass
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
############################################################################## ##############################################################################
"""Database connection support """Database connection support
$Id: Connection.py,v 1.78 2002/10/23 19:18:35 jeremy Exp $""" $Id: Connection.py,v 1.79 2002/11/18 23:17:40 jeremy Exp $"""
from cPickleCache import PickleCache from cPickleCache import PickleCache
from POSException import ConflictError, ReadConflictError from POSException import ConflictError, ReadConflictError
...@@ -184,6 +184,14 @@ class Connection(ExportImport.ExportImport): ...@@ -184,6 +184,14 @@ class Connection(ExportImport.ExportImport):
return obj return obj
return self[oid] return self[oid]
def sortKey(self):
# XXX will raise an exception if the DB hasn't been set
storage_key = self._sortKey()
# If two connections use the same storage, give them a
# consistent order using id(). This is unique for the
# lifetime of a connection, which is good enough.
return "%s:%s" % (storage_key, id(self))
def _setDB(self, odb): def _setDB(self, odb):
"""Begin a new transaction. """Begin a new transaction.
...@@ -191,6 +199,7 @@ class Connection(ExportImport.ExportImport): ...@@ -191,6 +199,7 @@ class Connection(ExportImport.ExportImport):
""" """
self._db=odb self._db=odb
self._storage=s=odb._storage self._storage=s=odb._storage
self._sortKey = odb._storage.sortKey
self.new_oid=s.new_oid self.new_oid=s.new_oid
if self._code_timestamp != global_code_timestamp: if self._code_timestamp != global_code_timestamp:
# New code is in place. Start a new cache. # New code is in place. Start a new cache.
...@@ -261,27 +270,8 @@ class Connection(ExportImport.ExportImport): ...@@ -261,27 +270,8 @@ class Connection(ExportImport.ExportImport):
self.__onCommitActions.append((method_name, args, kw)) self.__onCommitActions.append((method_name, args, kw))
get_transaction().register(self) get_transaction().register(self)
# NB: commit() is responsible for calling tpc_begin() on the storage.
# It uses self._begun to track whether it has been called. When
# self._begun is 0, it has not been called.
# This arrangement allows us to handle the special case of a
# transaction with no modified objects. It is possible for
# registration to be occur unintentionally and for a persistent
# object to compensate by making itself as unchanged. When this
# happens, it's possible to commit a transaction with no modified
# objects.
# Since tpc_begin() may raise a ReadOnlyError, don't call it if there
# are no objects. This avoids spurious (?) errors when working with
# a read-only storage.
def commit(self, object, transaction): def commit(self, object, transaction):
if object is self: if object is self:
if not self._begun:
self._storage.tpc_begin(transaction)
self._begun = 1
# We registered ourself. Execute a commit action, if any. # We registered ourself. Execute a commit action, if any.
if self.__onCommitActions is not None: if self.__onCommitActions is not None:
method_name, args, kw = self.__onCommitActions.pop(0) method_name, args, kw = self.__onCommitActions.pop(0)
...@@ -306,10 +296,6 @@ class Connection(ExportImport.ExportImport): ...@@ -306,10 +296,6 @@ class Connection(ExportImport.ExportImport):
# Nothing to do # Nothing to do
return return
if not self._begun:
self._storage.tpc_begin(transaction)
self._begun = 1
stack = [object] stack = [object]
# Create a special persistent_id that passes T and the subobject # Create a special persistent_id that passes T and the subobject
...@@ -616,8 +602,6 @@ class Connection(ExportImport.ExportImport): ...@@ -616,8 +602,6 @@ class Connection(ExportImport.ExportImport):
def tpc_begin(self, transaction, sub=None): def tpc_begin(self, transaction, sub=None):
self._invalidating = [] self._invalidating = []
self._creating = [] self._creating = []
self._begun = 0
if sub: if sub:
# Sub-transaction! # Sub-transaction!
if self._tmp is None: if self._tmp is None:
...@@ -626,10 +610,7 @@ class Connection(ExportImport.ExportImport): ...@@ -626,10 +610,7 @@ class Connection(ExportImport.ExportImport):
self._storage = _tmp self._storage = _tmp
_tmp.registerDB(self._db, 0) _tmp.registerDB(self._db, 0)
# It's okay to always call tpc_begin() for a sub-transaction
# because this isn't the real storage.
self._storage.tpc_begin(transaction) self._storage.tpc_begin(transaction)
self._begun = 1
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
if self.__onCommitActions is not None: if self.__onCommitActions is not None:
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
############################################################################## ##############################################################################
"""Database objects """Database objects
$Id: DB.py,v 1.44 2002/10/23 19:08:36 jeremy Exp $""" $Id: DB.py,v 1.45 2002/11/18 23:17:40 jeremy Exp $"""
__version__='$Revision: 1.44 $'[11:-2] __version__='$Revision: 1.45 $'[11:-2]
import cPickle, cStringIO, sys, POSException, UndoLogCompatible import cPickle, cStringIO, sys, POSException, UndoLogCompatible
from Connection import Connection from Connection import Connection
...@@ -578,8 +578,12 @@ class CommitVersion: ...@@ -578,8 +578,12 @@ class CommitVersion:
self.tpc_begin=s.tpc_begin self.tpc_begin=s.tpc_begin
self.tpc_vote=s.tpc_vote self.tpc_vote=s.tpc_vote
self.tpc_finish=s.tpc_finish self.tpc_finish=s.tpc_finish
self._sortKey=s.sortKey
get_transaction().register(self) get_transaction().register(self)
def sortKey(self):
return "%s:%s" % (self._sortKey(), id(self))
def abort(self, reallyme, t): pass def abort(self, reallyme, t): pass
def commit(self, reallyme, t): def commit(self, reallyme, t):
......
...@@ -115,7 +115,7 @@ ...@@ -115,7 +115,7 @@
# may have a back pointer to a version record or to a non-version # may have a back pointer to a version record or to a non-version
# record. # record.
# #
__version__='$Revision: 1.115 $'[11:-2] __version__='$Revision: 1.116 $'[11:-2]
import base64 import base64
from cPickle import Pickler, Unpickler, loads from cPickle import Pickler, Unpickler, loads
...@@ -316,9 +316,6 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -316,9 +316,6 @@ class FileStorage(BaseStorage.BaseStorage,
# hook to use something other than builtin dict # hook to use something other than builtin dict
return {}, {}, {}, {} return {}, {}, {}, {}
def abortVersion(self, src, transaction):
return self.commitVersion(src, '', transaction, abort=1)
def _save_index(self): def _save_index(self):
"""Write the database index to a file to support quick startup """Write the database index to a file to support quick startup
""" """
...@@ -446,6 +443,9 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -446,6 +443,9 @@ class FileStorage(BaseStorage.BaseStorage,
# XXX should log the error, though # XXX should log the error, though
pass # We don't care if this fails. pass # We don't care if this fails.
def abortVersion(self, src, transaction):
return self.commitVersion(src, '', transaction, abort=1)
def commitVersion(self, src, dest, transaction, abort=None): def commitVersion(self, src, dest, transaction, abort=None):
# We are going to commit by simply storing back pointers. # We are going to commit by simply storing back pointers.
if self._is_read_only: if self._is_read_only:
...@@ -526,6 +526,9 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -526,6 +526,9 @@ class FileStorage(BaseStorage.BaseStorage,
here += heredelta here += heredelta
current_oids[oid] = 1 current_oids[oid] = 1
# Once we've found the data we are looking for,
# we can stop chasing backpointers.
break
else: else:
# Hm. This is a non-current record. Is there a # Hm. This is a non-current record. Is there a
...@@ -768,9 +771,13 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -768,9 +771,13 @@ class FileStorage(BaseStorage.BaseStorage,
if vl: if vl:
self._file.read(vl + 16) self._file.read(vl + 16)
# Make sure this looks like the right data record # Make sure this looks like the right data record
if dl == 0:
# This is also a backpointer. Gotta trust it.
return pos
if dl != len(data): if dl != len(data):
# XXX what if this data record also has a backpointer? # The expected data doesn't match what's in the
# I don't think that's possible, but I'm not sure. # backpointer. Something is wrong.
error("Mismatch between data and backpointer at %d", pos)
return 0 return 0
_data = self._file.read(dl) _data = self._file.read(dl)
if data != _data: if data != _data:
...@@ -828,20 +835,7 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -828,20 +835,7 @@ class FileStorage(BaseStorage.BaseStorage,
# We need to write some version information if this revision is # We need to write some version information if this revision is
# happening in a version. # happening in a version.
if version: if version:
pnv = None pnv = self._restore_pnv(oid, old, version, prev_pos)
# We need to write the position of the non-version data.
# If the previous revision of the object was in a version,
# then it will contain a pnv record. Otherwise, the
# previous record is the non-version data.
if old:
self._file.seek(old)
h = self._file.read(42)
doid, x, y, z, vlen, w = unpack(DATA_HDR, h)
if doid != oid:
raise CorruptedDataError, h
# XXX assert versions match?
if vlen > 0:
pnv = self._file.read(8)
if pnv: if pnv:
self._tfile.write(pnv) self._tfile.write(pnv)
else: else:
...@@ -853,20 +847,65 @@ class FileStorage(BaseStorage.BaseStorage, ...@@ -853,20 +847,65 @@ class FileStorage(BaseStorage.BaseStorage,
self._tfile.write(p64(pv)) self._tfile.write(p64(pv))
self._tvindex[version] = here self._tvindex[version] = here
self._tfile.write(version) self._tfile.write(version)
# And finally, write the data # And finally, write the data or a backpointer
if data is None: if data is None:
if prev_pos: if prev_pos:
self._tfile.write(p64(prev_pos)) self._tfile.write(p64(prev_pos))
else: else:
# Write a zero backpointer, which indicates an # Write a zero backpointer, which indicates an
# un-creation transaction. # un-creation transaction.
# write a backpointer instead of data
self._tfile.write(z64) self._tfile.write(z64)
else: else:
self._tfile.write(data) self._tfile.write(data)
finally: finally:
self._lock_release() self._lock_release()
def _restore_pnv(self, oid, prev, version, bp):
# Find a valid pnv (previous non-version) pointer for this version.
# If there is no previous record, there can't be a pnv.
if not prev:
return None
pnv = None
# Load the record pointed to be prev
self._file.seek(prev)
h = self._file.read(DATA_HDR_LEN)
doid, x, y, z, vlen, w = unpack(DATA_HDR, h)
if doid != oid:
raise CorruptedDataError, h
# If the previous record is for a version, it must have
# a valid pnv.
if vlen > 0:
pnv = self._file.read(8)
pv = self._file.read(8)
v = self._file.read(vlen)
elif bp:
# XXX Not sure the following is always true:
# The previous record is not for this version, yet we
# have a backpointer to it. The current record must
# be an undo of an abort or commit, so the backpointer
# must be to a version record with a pnv.
self._file.seek(bp)
h2 = self._file.read(DATA_HDR_LEN)
doid2, x, y, z, vlen2, sdl = unpack(DATA_HDR, h2)
dl = U64(sdl)
if oid != doid2:
raise CorruptedDataError, h2
if vlen2 > 0:
pnv = self._file.read(8)
pv = self._file.read(8)
v = self._file.read(8)
else:
warn("restore could not find previous non-version data "
"at %d or %d" % (prev, bp))
return pnv
def supportsUndo(self):
return 1
def supportsVersions(self): def supportsVersions(self):
return 1 return 1
...@@ -2097,7 +2136,8 @@ def _loadBack_impl(file, oid, back): ...@@ -2097,7 +2136,8 @@ def _loadBack_impl(file, oid, back):
doid, serial, prev, tloc, vlen, plen = unpack(DATA_HDR, h) doid, serial, prev, tloc, vlen, plen = unpack(DATA_HDR, h)
if vlen: if vlen:
file.seek(vlen + 16, 1) file.read(16)
version = file.read(vlen)
if plen != z64: if plen != z64:
return file.read(U64(plen)), serial, old, tloc return file.read(U64(plen)), serial, old, tloc
back = file.read(8) # We got a back pointer! back = file.read(8) # We got a back pointer!
...@@ -2120,6 +2160,17 @@ def _loadBackTxn(file, oid, back): ...@@ -2120,6 +2160,17 @@ def _loadBackTxn(file, oid, back):
tid = h[:8] tid = h[:8]
return data, serial, tid return data, serial, tid
def getTxnFromData(file, oid, back):
"""Return transaction id for data at back."""
file.seek(U64(back))
h = file.read(DATA_HDR_LEN)
doid, serial, prev, stloc, vlen, plen = unpack(DATA_HDR, h)
assert oid == doid
tloc = U64(stloc)
file.seek(tloc)
# seek to transaction header, where tid is first 8 bytes
return file.read(8)
def _truncate(file, name, pos): def _truncate(file, name, pos):
seek=file.seek seek=file.seek
seek(0,2) seek(0,2)
...@@ -2336,40 +2387,48 @@ class RecordIterator(Iterator, BaseStorage.TransactionRecord): ...@@ -2336,40 +2387,48 @@ class RecordIterator(Iterator, BaseStorage.TransactionRecord):
self._file.seek(pos) self._file.seek(pos)
h = self._file.read(DATA_HDR_LEN) h = self._file.read(DATA_HDR_LEN)
oid, serial, sprev, stloc, vlen, splen = unpack(DATA_HDR, h) oid, serial, sprev, stloc, vlen, splen = unpack(DATA_HDR, h)
prev = U64(sprev)
tloc = U64(stloc) tloc = U64(stloc)
plen = U64(splen) plen = U64(splen)
dlen = DATA_HDR_LEN + (plen or 8) dlen = DATA_HDR_LEN + (plen or 8)
if vlen: if vlen:
dlen += (16 + vlen) dlen += (16 + vlen)
self._file.read(16) # move to the right location tmp = self._file.read(16)
pv = U64(tmp[8:16])
version = self._file.read(vlen) version = self._file.read(vlen)
else: else:
version = '' version = ''
datapos = pos + DATA_HDR_LEN
if vlen:
datapos += 16 + vlen
assert self._file.tell() == datapos, (self._file.tell(), datapos)
if pos + dlen > self._tend or tloc != self._tpos: if pos + dlen > self._tend or tloc != self._tpos:
warn("%s data record exceeds transaction record at %s", warn("%s data record exceeds transaction record at %s",
file.name, pos) file.name, pos)
break break
self._pos = pos + dlen self._pos = pos + dlen
tid = None prev_txn = None
if plen: if plen:
p = self._file.read(plen) data = self._file.read(plen)
else: else:
p = self._file.read(8) bp = self._file.read(8)
if p == z64: if bp == z64:
# If the backpointer is 0 (encoded as z64), then # If the backpointer is 0 (encoded as z64), then
# this transaction undoes the object creation. It # this transaction undoes the object creation. It
# either aborts the version that created the # either aborts the version that created the
# object or undid the transaction that created it. # object or undid the transaction that created it.
# Return None instead of a pickle to indicate # Return None instead of a pickle to indicate
# this. # this.
p = None data = None
else: else:
p, _s, tid = _loadBackTxn(self._file, oid, p) data, _s, tid = _loadBackTxn(self._file, oid, bp)
prev_txn = getTxnFromData(self._file, oid, bp)
r = Record(oid, serial, version, p, tid) r = Record(oid, serial, version, data, prev_txn)
return r return r
......
This diff is collapsed.
...@@ -76,3 +76,80 @@ def fsdump(path, file=None, with_offset=1): ...@@ -76,3 +76,80 @@ def fsdump(path, file=None, with_offset=1):
print >> file print >> file
i += 1 i += 1
iter.close() iter.close()
import struct
from ZODB.FileStorage import TRANS_HDR, TRANS_HDR_LEN
from ZODB.FileStorage import DATA_HDR, DATA_HDR_LEN
def fmt(p64):
# Return a nicely formatted string for a packaged 64-bit value
return "%016x" % U64(p64)
class Dumper:
"""A very verbose dumper for debuggin FileStorage problems."""
def __init__(self, path, dest=None):
self.file = open(path, "rb")
self.dest = dest
def dump(self):
fid = self.file.read(4)
print >> self.dest, "*" * 60
print >> self.dest, "file identifier: %r" % fid
while self.dump_txn():
pass
def dump_txn(self):
pos = self.file.tell()
h = self.file.read(TRANS_HDR_LEN)
if not h:
return False
tid, stlen, status, ul, dl, el = struct.unpack(TRANS_HDR, h)
end = pos + U64(stlen)
print >> self.dest, "=" * 60
print >> self.dest, "offset: %d" % pos
print >> self.dest, "end pos: %d" % end
print >> self.dest, "transaction id: %s" % fmt(tid)
print >> self.dest, "trec len: %d" % U64(stlen)
print >> self.dest, "status: %r" % status
user = descr = extra = ""
if ul:
user = self.file.read(ul)
if dl:
descr = self.file.read(dl)
if el:
extra = self.file.read(el)
print >> self.dest, "user: %r" % user
print >> self.dest, "description: %r" % descr
print >> self.dest, "len(extra): %d" % el
while self.file.tell() < end:
self.dump_data(pos)
stlen2 = self.file.read(8)
print >> self.dest, "redundant trec len: %d" % U64(stlen2)
return True
def dump_data(self, tloc):
pos = self.file.tell()
h = self.file.read(DATA_HDR_LEN)
assert len(h) == DATA_HDR_LEN
oid, revid, sprev, stloc, vlen, sdlen = struct.unpack(DATA_HDR, h)
dlen = U64(sdlen)
print >> self.dest, "-" * 60
print >> self.dest, "offset: %d" % pos
print >> self.dest, "oid: %s" % fmt(oid)
print >> self.dest, "revid: %s" % fmt(revid)
print >> self.dest, "previous record offset: %d" % U64(sprev)
print >> self.dest, "transaction offset: %d" % U64(stloc)
if vlen:
pnv = self.file.read(8)
sprevdata = self.file.read(8)
version = self.file.read(vlen)
print >> self.dest, "version: %r" % version
print >> self.dest, "non-version data offset: %d" % U64(pnv)
print >> self.dest, \
"previous version data offset: %d" % U64(sprevdata)
print >> self.dest, "len(data): %d" % dlen
self.file.read(dlen)
if not dlen:
sbp = self.file.read(8)
print >> self.dest, "backpointer: %d" % U64(sbp)
...@@ -180,7 +180,6 @@ class StorageTestBase(unittest.TestCase): ...@@ -180,7 +180,6 @@ class StorageTestBase(unittest.TestCase):
def _dostoreNP(self, oid=None, revid=None, data=None, version=None, def _dostoreNP(self, oid=None, revid=None, data=None, version=None,
user=None, description=None): user=None, description=None):
return self._dostore(oid, revid, data, version, already_pickled=1) return self._dostore(oid, revid, data, version, already_pickled=1)
# The following methods depend on optional storage features. # The following methods depend on optional storage features.
def _undo(self, tid, oid): def _undo(self, tid, oid):
......
...@@ -29,9 +29,25 @@ class TransactionalUndoVersionStorage: ...@@ -29,9 +29,25 @@ class TransactionalUndoVersionStorage:
pass # not expected pass # not expected
return self._dostore(*args, **kwargs) return self._dostore(*args, **kwargs)
def _undo(self, tid, oid):
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
self.assertEqual(len(oids), 1)
self.assertEqual(oids[0], oid)
def checkUndoInVersion(self): def checkUndoInVersion(self):
eq = self.assertEqual eq = self.assertEqual
unless = self.failUnless unless = self.failUnless
def check_objects(nonversiondata, versiondata):
data, revid = self._storage.load(oid, version)
self.assertEqual(zodb_unpickle(data), MinPO(versiondata))
data, revid = self._storage.load(oid, '')
self.assertEqual(zodb_unpickle(data), MinPO(nonversiondata))
oid = self._storage.new_oid() oid = self._storage.new_oid()
version = 'one' version = 'one'
revid_a = self._dostore(oid, data=MinPO(91)) revid_a = self._dostore(oid, data=MinPO(91))
...@@ -39,21 +55,17 @@ class TransactionalUndoVersionStorage: ...@@ -39,21 +55,17 @@ class TransactionalUndoVersionStorage:
version=version) version=version)
revid_c = self._dostore(oid, revid=revid_b, data=MinPO(93), revid_c = self._dostore(oid, revid=revid_b, data=MinPO(93),
version=version) version=version)
info=self._storage.undoInfo()
tid=info[0]['id'] info = self._storage.undoInfo()
t = Transaction() self._undo(info[0]['id'], oid)
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(revid, revid_a) eq(revid, revid_a)
eq(zodb_unpickle(data), MinPO(91)) eq(zodb_unpickle(data), MinPO(91))
data, revid = self._storage.load(oid, version) data, revid = self._storage.load(oid, version)
unless(revid > revid_b and revid > revid_c) unless(revid > revid_b and revid > revid_c)
eq(zodb_unpickle(data), MinPO(92)) eq(zodb_unpickle(data), MinPO(92))
# Now commit the version... # Now commit the version...
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
...@@ -63,61 +75,25 @@ class TransactionalUndoVersionStorage: ...@@ -63,61 +75,25 @@ class TransactionalUndoVersionStorage:
eq(len(oids), 1) eq(len(oids), 1)
eq(oids[0], oid) eq(oids[0], oid)
#JF# No, because we fall back to non-version data. check_objects(92, 92)
#JF# self.assertRaises(POSException.VersionError,
#JF# self._storage.load,
#JF# oid, version)
data, revid = self._storage.load(oid, version)
eq(zodb_unpickle(data), MinPO(92))
data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(92))
# ...and undo the commit # ...and undo the commit
info=self._storage.undoInfo() info = self._storage.undoInfo()
tid=info[0]['id'] self._undo(info[0]['id'], oid)
t = Transaction()
self._storage.tpc_begin(t) check_objects(91, 92)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t) oids = self._abortVersion(version)
self._storage.tpc_finish(t) assert len(oids) == 1
eq(len(oids), 1) assert oids[0] == oid
eq(oids[0], oid)
data, revid = self._storage.load(oid, version) check_objects(91, 91)
eq(zodb_unpickle(data), MinPO(92))
data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(91))
# Now abort the version
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.abortVersion(version, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
# The object should not exist in the version now, but it should exist
# in the non-version
#JF# No, because we fall back
#JF# self.assertRaises(POSException.VersionError,
#JF# self._storage.load,
#JF# oid, version)
data, revid = self._storage.load(oid, version)
eq(zodb_unpickle(data), MinPO(91))
data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(91))
# Now undo the abort # Now undo the abort
info=self._storage.undoInfo() info=self._storage.undoInfo()
tid=info[0]['id'] self._undo(info[0]['id'], oid)
t = Transaction()
self._storage.tpc_begin(t) check_objects(91, 92)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
# And the object should be back in versions 'one' and ''
data, revid = self._storage.load(oid, version)
eq(zodb_unpickle(data), MinPO(92))
data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(91))
def checkUndoCommitVersion(self): def checkUndoCommitVersion(self):
def load_value(oid, version=''): def load_value(oid, version=''):
......
...@@ -14,22 +14,6 @@ from ZODB.tests.StorageTestBase import zodb_unpickle ...@@ -14,22 +14,6 @@ from ZODB.tests.StorageTestBase import zodb_unpickle
class VersionStorage: class VersionStorage:
def _commitVersion(self, src, dst):
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.commitVersion(src, dst, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
return oids
def _abortVersion(self, ver):
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.abortVersion(ver, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
return oids
def checkCommitVersionSerialno(self): def checkCommitVersionSerialno(self):
oid = self._storage.new_oid() oid = self._storage.new_oid()
revid1 = self._dostore(oid, data=MinPO(12)) revid1 = self._dostore(oid, data=MinPO(12))
......
...@@ -12,6 +12,7 @@ from ZODB.tests import StorageTestBase, BasicStorage, \ ...@@ -12,6 +12,7 @@ from ZODB.tests import StorageTestBase, BasicStorage, \
Synchronization, ConflictResolution, HistoryStorage, \ Synchronization, ConflictResolution, HistoryStorage, \
IteratorStorage, Corruption, RevisionStorage, PersistentStorage, \ IteratorStorage, Corruption, RevisionStorage, PersistentStorage, \
MTStorage, ReadOnlyStorage, RecoveryStorage MTStorage, ReadOnlyStorage, RecoveryStorage
from ZODB.tests.StorageTestBase import MinPO, zodb_unpickle
class FileStorageTests( class FileStorageTests(
StorageTestBase.StorageTestBase, StorageTestBase.StorageTestBase,
...@@ -63,6 +64,8 @@ class FileStorageRecoveryTest( ...@@ -63,6 +64,8 @@ class FileStorageRecoveryTest(
): ):
def setUp(self): def setUp(self):
StorageTestBase.removefs("Source.fs")
StorageTestBase.removefs("Dest.fs")
self._storage = ZODB.FileStorage.FileStorage('Source.fs') self._storage = ZODB.FileStorage.FileStorage('Source.fs')
self._dst = ZODB.FileStorage.FileStorage('Dest.fs') self._dst = ZODB.FileStorage.FileStorage('Dest.fs')
...@@ -76,6 +79,63 @@ class FileStorageRecoveryTest( ...@@ -76,6 +79,63 @@ class FileStorageRecoveryTest(
StorageTestBase.removefs('Dest.fs') StorageTestBase.removefs('Dest.fs')
return ZODB.FileStorage.FileStorage('Dest.fs') return ZODB.FileStorage.FileStorage('Dest.fs')
def checkRecoverUndoInVersion(self):
oid = self._storage.new_oid()
version = "aVersion"
revid_a = self._dostore(oid, data=MinPO(91))
revid_b = self._dostore(oid, revid=revid_a, version=version,
data=MinPO(92))
revid_c = self._dostore(oid, revid=revid_b, version=version,
data=MinPO(93))
self._undo(self._storage.undoInfo()[0]['id'], oid)
self._commitVersion(version, '')
self._undo(self._storage.undoInfo()[0]['id'], oid)
# now copy the records to a new storage
self._dst.copyTransactionsFrom(self._storage)
self.compare(self._storage, self._dst)
# The last two transactions were applied directly rather than
# copied. So we can't use compare() to verify that they new
# transactions are applied correctly. (The new transactions
# will have different timestamps for each storage.)
self._abortVersion(version)
self.assert_(self._storage.versionEmpty(version))
self._undo(self._storage.undoInfo()[0]['id'], oid)
self.assert_(not self._storage.versionEmpty(version))
# check the data is what we expect it to be
data, revid = self._storage.load(oid, version)
self.assertEqual(zodb_unpickle(data), MinPO(92))
data, revid = self._storage.load(oid, '')
self.assertEqual(zodb_unpickle(data), MinPO(91))
# and swap the storages
tmp = self._storage
self._storage = self._dst
self._abortVersion(version)
self.assert_(self._storage.versionEmpty(version))
self._undo(self._storage.undoInfo()[0]['id'], oid)
self.assert_(not self._storage.versionEmpty(version))
# check the data is what we expect it to be
data, revid = self._storage.load(oid, version)
self.assertEqual(zodb_unpickle(data), MinPO(92))
data, revid = self._storage.load(oid, '')
self.assertEqual(zodb_unpickle(data), MinPO(91))
# swap them back
self._storage = tmp
# Now remove _dst and copy all the transactions a second time.
# This time we will be able to confirm via compare().
self._dst.close()
StorageTestBase.removefs("Dest.fs")
self._dst = ZODB.FileStorage.FileStorage('Dest.fs')
self._dst.copyTransactionsFrom(self._storage)
self.compare(self._storage, self._dst)
def test_suite(): def test_suite():
suite = unittest.makeSuite(FileStorageTests, 'check') suite = unittest.makeSuite(FileStorageTests, 'check')
......
...@@ -94,28 +94,6 @@ class ZODBTests(unittest.TestCase, ExportImportTests): ...@@ -94,28 +94,6 @@ class ZODBTests(unittest.TestCase, ExportImportTests):
self._storage.close() self._storage.close()
removefs("ZODBTests.fs") removefs("ZODBTests.fs")
def checkUnmodifiedObject(self):
# Test that a transaction with only unmodified objects works
# correctly. The specific sequence of events is:
# - an object is modified
# - it is registered with the transaction
# - the object is explicitly "unmodified"
# - the transaction commits, but now has no modified objects
# We'd like to avoid doing anything with the storage.
ltid = self._storage.lastTransaction()
_objects = get_transaction()._objects
self.assertEqual(len(_objects), 0)
r = self._db.open().root()
obj = r["test"][0]
obj[1] = 1
self.assertEqual(obj._p_changed, 1)
self.assertEqual(len(_objects), 1)
del obj._p_changed
self.assertEqual(obj._p_changed, None)
self.assertEqual(len(_objects), 1)
get_transaction().commit()
self.assertEqual(ltid, self._storage.lastTransaction())
def checkVersionOnly(self): def checkVersionOnly(self):
# Make sure the changes to make empty transactions a no-op # Make sure the changes to make empty transactions a no-op
# still allow things like abortVersion(). This should work # still allow things like abortVersion(). This should work
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment