Commit b7e88bab authored by Jeremy Hylton's avatar Jeremy Hylton

Merge MVCC branch to the HEAD.

parent 3e29b5b6
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
# XXX TO DO
# use two indices rather than the sign bit of the index??????
# add a shared routine to read + verify a record???
# redesign header to include vdlen???
# rewrite the cache using a different algorithm???
"""Implement a client cache
The cache is managed as two files.
The cache can be persistent (meaning it is survives a process restart)
or temporary. It is persistent if the client argument is not None.
Persistent cache files live in the var directory and are named
'c<storage>-<client>-<digit>.zec' where <storage> is the storage
argument (default '1'), <client> is the client argument, and <digit> is
0 or 1. Temporary cache files are unnamed files in the standard
temporary directory as determined by the tempfile module.
Each cache file has a 12-byte header followed by a sequence of
records. The header format is as follows:
offset in header: name -- description
0: magic -- 4-byte magic number, identifying this as a ZEO cache file
4: lasttid -- 8-byte last transaction id
Each record has the following form:
offset in record: name -- description
0: oidlen -- 2-byte unsigned object id length
2: reserved (6 bytes)
8: status -- 1-byte status 'v': valid, 'n': non-version valid, 'i': invalid
('n' means only the non-version data in the record is valid)
9: tlen -- 4-byte (unsigned) record length
13: vlen -- 2-byte (unsigned) version length
15: dlen -- 4-byte length of non-version data
19: serial -- 8-byte non-version serial (timestamp)
27: oid -- object id
27+oidlen: data -- non-version data
27+oidlen+dlen: version -- Version string (if vlen > 0)
27+oidlen+dlen+vlen: vdlen -- 4-byte length of version data (if vlen > 0)
31+oidlen+dlen+vlen: vdata -- version data (if vlen > 0)
31+oidlen+dlen+vlen+vdlen: vserial -- 8-byte version serial (timestamp)
(if vlen > 0)
27+oidlen+dlen (if vlen == 0) **or**
39+oidlen+dlen+vlen+vdlen: tlen -- 4-byte (unsigned) record length (for
redundancy and backward traversal)
31+oidlen+dlen (if vlen == 0) **or**
43+oidlen+dlen+vlen+vdlen: -- total record length (equal to tlen)
There is a cache size limit.
The cache is managed as follows:
- Data are written to file 0 until file 0 exceeds limit/2 in size.
- Data are written to file 1 until file 1 exceeds limit/2 in size.
- File 0 is truncated to size 0 (or deleted and recreated).
- Data are written to file 0 until file 0 exceeds limit/2 in size.
- File 1 is truncated to size 0 (or deleted and recreated).
- Data are written to file 1 until file 1 exceeds limit/2 in size.
and so on.
On startup, index information is read from file 0 and file 1.
Current serial numbers are sent to the server for verification.
If any serial numbers are not valid, then the server will send back
invalidation messages and the cache entries will be invalidated.
When a cache record is invalidated, the data length is overwritten
with '\0\0\0\0'.
If var is not writable, then temporary files are used for
file 0 and file 1.
"""
import os
import time
import tempfile
from struct import pack, unpack
from thread import allocate_lock
from ZODB.utils import oid_repr, u64, z64
import zLOG
from ZEO.ICache import ICache
magic = 'ZEC2'
headersize = 12
MB = 1024**2
class ClientCache:
__implements__ = ICache
def __init__(self, storage='1', size=20*MB, client=None, var=None):
# Arguments:
# storage -- storage name (used in filenames and log messages)
# size -- size limit in bytes of both files together
# client -- if not None, use a persistent cache file and use this name
# var -- directory where to create persistent cache files; default cwd
self._storage = storage
self._limit = size / 2
self._client = client
self._ltid = None # For getLastTid()
# Allocate locks:
L = allocate_lock()
self._acquire = L.acquire
self._release = L.release
if client is not None:
# Create a persistent cache
if var is None:
var = os.getcwd()
fmt = os.path.join(var, "c%s-%s-%%s.zec" % (storage, client))
# Initialize pairs of filenames, file objects, and serialnos.
self._p = p = [fmt % 0, fmt % 1]
self._f = f = [None, None]
self._current = 0
s = [z64, z64]
for i in 0, 1:
if os.path.exists(p[i]):
fi = open(p[i],'r+b')
if fi.read(4) == magic: # Minimal sanity
# Read the ltid for this file. If it never
# saw a transaction commit, it will get tossed,
# even if it has valid data.
s[i] = fi.read(8)
# If we found a non-zero serial, then use the file
if s[i] != z64:
f[i] = fi
# Whoever has the larger serial is the current
if s[1] > s[0]:
current = 1
elif s[0] > s[1]:
current = 0
else:
if f[0] is None:
# We started, open the first cache file
f[0] = open(p[0], 'w+b')
f[0].write(magic + '\0' * (headersize - len(magic)))
current = 0
f[1] = None
else:
self._f = f = [tempfile.TemporaryFile(suffix='.zec'), None]
# self._p file name 'None' signifies an unnamed temp file.
self._p = p = [None, None]
f[0].write(magic + '\0' * (headersize - len(magic)))
current = 0
self._current = current
if self._ltid:
ts = "; last txn=%x" % u64(self._ltid)
else:
ts = ""
self.log("%s: storage=%r, size=%r; file[%r]=%r%s" %
(self.__class__.__name__, storage, size, current, p[current],
ts))
self._setup_trace()
def open(self):
# Two tasks:
# - Set self._index, self._get, and self._pos.
# - Read and validate both cache files, returning a list of
# serials to be used by verify().
# This may be called more than once (by the cache verification code).
self._acquire()
try:
self._index = index = {}
self._get = index.get
serial = {}
f = self._f
current = self._current
if f[not current] is not None:
self.read_index(serial, not current)
self._pos = self.read_index(serial, current)
return serial.items()
finally:
self._release()
def close(self):
for f in self._f:
if f is not None:
# In 2.1 on Windows, the TemporaryFileWrapper doesn't allow
# closing a file more than once.
try:
f.close()
except OSError:
pass
def getLastTid(self):
"""Get the last transaction id stored by setLastTid().
If the cache is persistent, it is read from the current
cache file; otherwise it's an instance variable.
"""
if self._client is None:
return self._ltid
else:
self._acquire()
try:
return self._getLastTid()
finally:
self._release()
def _getLastTid(self):
f = self._f[self._current]
f.seek(4)
tid = f.read(8)
if len(tid) < 8 or tid == z64:
return None
else:
return tid
def setLastTid(self, tid):
"""Store the last transaction id.
If the cache is persistent, it is written to the current
cache file; otherwise it's an instance variable.
"""
if self._client is None:
if tid == z64:
tid = None
self._ltid = tid
else:
self._acquire()
try:
self._setLastTid(tid)
finally:
self._release()
def _setLastTid(self, tid):
if tid is None:
tid = z64
else:
tid = str(tid)
assert len(tid) == 8
f = self._f[self._current]
f.seek(4)
f.write(tid)
def verify(self, verifyFunc):
"""Call the verifyFunc on every object in the cache.
verifyFunc(oid, serialno, version)
"""
for oid, (s, vs) in self.open():
verifyFunc(oid, s, vs)
def invalidate(self, oid, version):
self._acquire()
try:
p = self._get(oid, None)
if p is None:
self._trace(0x10, oid, version)
return None
f = self._f[p < 0]
ap = abs(p)
f.seek(ap)
h = f.read(27)
if len(h) != 27:
self.log("invalidate: short record for oid %s "
"at position %d in cache file %d"
% (oid_repr(oid), ap, p < 0))
del self._index[oid]
return None
oidlen = unpack(">H", h[:2])[0]
rec_oid = f.read(oidlen)
if rec_oid != oid:
self.log("invalidate: oid mismatch: expected %s read %s "
"at position %d in cache file %d"
% (oid_repr(oid), oid_repr(rec_oid), ap, p < 0))
del self._index[oid]
return None
f.seek(ap+8) # Switch from reading to writing
if version and h[15:19] != '\0\0\0\0':
self._trace(0x1A, oid, version)
# There's still relevant non-version data in the cache record
f.write('n')
else:
self._trace(0x1C, oid, version)
del self._index[oid]
f.write('i')
finally:
self._release()
def load(self, oid, version):
self._acquire()
try:
p = self._get(oid, None)
if p is None:
self._trace(0x20, oid, version)
return None
f = self._f[p < 0]
ap = abs(p)
seek = f.seek
read = f.read
seek(ap)
h = read(27)
oidlen = unpack(">H", h[:2])[0]
rec_oid = read(oidlen)
if len(h)==27 and h[8] in 'nv' and rec_oid == oid:
tlen, vlen, dlen = unpack(">iHi", h[9:19])
else:
tlen = -1
if tlen <= 0 or vlen < 0 or dlen < 0 or vlen+dlen > tlen:
self.log("load: bad record for oid %s "
"at position %d in cache file %d"
% (oid_repr(oid), ap, p < 0))
del self._index[oid]
return None
if h[8]=='n':
if version:
self._trace(0x22, oid, version)
return None
if not dlen:
# XXX This shouldn't actually happen
self._trace(0x24, oid, version)
del self._index[oid]
return None
if not vlen or not version:
if dlen:
data = read(dlen)
self._trace(0x2A, oid, version, h[19:], dlen)
if (p < 0) != self._current:
# If the cache read we are copying has version info,
# we need to pass the header to copytocurrent().
if vlen:
vheader = read(vlen + 4)
else:
vheader = None
self._copytocurrent(ap, oidlen, tlen, dlen, vlen, h,
oid, data, vheader)
return data, h[19:]
else:
self._trace(0x26, oid, version)
return None
if dlen:
seek(dlen, 1)
vheader = read(vlen+4)
v = vheader[:-4]
if version != v:
if dlen:
seek(ap+27+oidlen)
data = read(dlen)
self._trace(0x2C, oid, version, h[19:], dlen)
if (p < 0) != self._current:
self._copytocurrent(ap, oidlen, tlen, dlen, vlen, h,
oid, data, vheader)
return data, h[19:]
else:
self._trace(0x28, oid, version)
return None
vdlen = unpack(">i", vheader[-4:])[0]
vdata = read(vdlen)
vserial = read(8)
self._trace(0x2E, oid, version, vserial, vdlen)
if (p < 0) != self._current:
self._copytocurrent(ap, oidlen, tlen, dlen, vlen, h,
oid, None, vheader, vdata, vserial)
return vdata, vserial
finally:
self._release()
def _copytocurrent(self, pos, oidlen, tlen, dlen, vlen, header, oid,
data=None, vheader=None, vdata=None, vserial=None):
"""Copy a cache hit from the non-current file to the current file.
Arguments are the file position in the non-current file,
record length, data length, version string length, header, and
optionally parts of the record that have already been read.
"""
if self._pos + tlen > self._limit:
return # Don't let this cause a cache flip
assert len(header) == 27, len(header)
if header[8] == 'n':
# Rewrite the header to drop the version data.
# This shortens the record.
tlen = 31 + oidlen + dlen
vlen = 0
vheader = None
# (oidlen:2, reserved:6, status:1, tlen:4,
# vlen:2, dlen:4, serial:8)
header = header[:9] + pack(">IHI", tlen, vlen, dlen) + header[-8:]
else:
assert header[8] == 'v'
f = self._f[not self._current]
if data is None:
f.seek(pos+27+oidlen)
data = f.read(dlen)
if len(data) != dlen:
return
l = [header, oid, data]
if vlen:
assert vheader is not None
l.append(vheader)
assert (vdata is None) == (vserial is None)
if vdata is None:
vdlen = unpack(">I", vheader[-4:])[0]
f.seek(pos+27+oidlen+dlen+vlen+4)
vdata = f.read(vdlen)
if len(vdata) != vdlen:
return
vserial = f.read(8)
if len(vserial) != 8:
return
l.append(vdata)
l.append(vserial)
else:
assert None is vheader is vdata is vserial, (
vlen, vheader, vdata, vserial)
l.append(header[9:13]) # copy of tlen
g = self._f[self._current]
g.seek(self._pos)
g.writelines(l)
assert g.tell() == self._pos + tlen
if self._current:
self._index[oid] = - self._pos
else:
self._index[oid] = self._pos
self._pos += tlen
self._trace(0x6A, oid, vlen and vheader[:-4] or '',
vlen and vserial or header[-8:], dlen)
def update(self, oid, serial, version, data):
self._acquire()
try:
self._trace(0x3A, oid, version, serial, len(data))
if version:
# We need to find and include non-version data
p = self._get(oid, None)
if p is None:
return self._store(oid, '', '', version, data, serial)
f = self._f[p < 0]
ap = abs(p)
seek = f.seek
read = f.read
seek(ap)
h = read(27)
oidlen = unpack(">H", h[:2])[0]
rec_oid = read(oidlen)
if len(h) == 27 and h[8] in 'nv' and rec_oid == oid:
tlen, vlen, dlen = unpack(">iHi", h[9:19])
else:
return self._store(oid, '', '', version, data, serial)
if tlen <= 0 or vlen < 0 or dlen <= 0 or vlen+dlen > tlen:
return self._store(oid, '', '', version, data, serial)
if dlen:
nvdata = read(dlen)
nvserial = h[19:]
else:
return self._store(oid, '', '', version, data, serial)
self._store(oid, nvdata, nvserial, version, data, serial)
else:
# Simple case, just store new data:
self._store(oid, data, serial, '', None, None)
finally:
self._release()
def modifiedInVersion(self, oid):
# This should return:
# - The version from the record for oid, if there is one.
# - '', if there is no version in the record and its status is 'v'.
# - None, if we don't know: no valid record or status is 'n'.
self._acquire()
try:
p = self._get(oid, None)
if p is None:
self._trace(0x40, oid)
return None
f = self._f[p < 0]
ap = abs(p)
seek = f.seek
read = f.read
seek(ap)
h = read(27)
oidlen = unpack(">H", h[:2])[0]
rec_oid = read(oidlen)
if len(h) == 27 and h[8] in 'nv' and rec_oid == oid:
tlen, vlen, dlen = unpack(">iHi", h[9:19])
else:
tlen = -1
if tlen <= 0 or vlen < 0 or dlen < 0 or vlen+dlen > tlen:
self.log("modifiedInVersion: bad record for oid %s "
"at position %d in cache file %d"
% (oid_repr(oid), ap, p < 0))
del self._index[oid]
return None
if h[8] == 'n':
self._trace(0x4A, oid)
return None
if not vlen:
self._trace(0x4C, oid)
return ''
seek(dlen, 1)
version = read(vlen)
self._trace(0x4E, oid, version)
return version
finally:
self._release()
def checkSize(self, size):
# Make sure we aren't going to exceed the target size.
# If we are, then flip the cache.
self._acquire()
try:
if self._pos + size > self._limit:
ltid = self._getLastTid()
current = not self._current
self._current = current
self._trace(0x70)
self.log("flipping cache files. new current = %d" % current)
# Delete the half of the index that's no longer valid
index = self._index
for oid in index.keys():
if (index[oid] < 0) == current:
del index[oid]
if self._p[current] is not None:
# Persistent cache file: remove the old file
# before opening the new one, because the old file
# may be owned by root (created before setuid()).
if self._f[current] is not None:
self._f[current].close()
try:
os.remove(self._p[current])
except:
pass
self._f[current] = open(self._p[current],'w+b')
else:
# Temporary cache file:
self._f[current] = tempfile.TemporaryFile(suffix='.zec')
header = magic
if ltid:
header += ltid
self._f[current].write(header +
'\0' * (headersize - len(header)))
self._pos = headersize
finally:
self._release()
def store(self, oid, p, s, version, pv, sv):
self._acquire()
if s:
self._trace(0x5A, oid, version, s, len(p))
else:
self._trace(0x5C, oid, version, sv, len(pv))
try:
self._store(oid, p, s, version, pv, sv)
finally:
self._release()
def _store(self, oid, p, s, version, pv, sv):
if not s:
p = ''
s = z64
tlen = 31 + len(oid) + len(p)
if version:
tlen = tlen + len(version) + 12 + len(pv)
vlen = len(version)
else:
vlen = 0
stlen = pack(">I", tlen)
# accumulate various data to write into a list
assert len(oid) < 2**16
assert vlen < 2**16
assert tlen < 2L**32
l = [pack(">H6x", len(oid)), 'v', stlen,
pack(">HI", vlen, len(p)), s, oid]
if p:
l.append(p)
if version:
l.extend([version,
pack(">I", len(pv)),
pv, sv])
l.append(stlen)
f = self._f[self._current]
f.seek(self._pos)
f.writelines(l) # write all list elements
if self._current:
self._index[oid] = - self._pos
else:
self._index[oid] = self._pos
self._pos += tlen
def _setup_trace(self):
# See if cache tracing is requested through $ZEO_CACHE_TRACE.
# A dash and the storage name are appended to get the filename.
# If not, or if we can't write to the trace file,
# disable tracing by setting self._trace to a dummy function.
self._tracefile = None
tfn = os.environ.get("ZEO_CACHE_TRACE")
if tfn:
tfn = tfn + "-" + self._storage
try:
self._tracefile = open(tfn, "ab")
self._trace(0x00)
except IOError, msg:
self._tracefile = None
self.log("cannot write tracefile %s (%s)" % (tfn, msg))
else:
self.log("opened tracefile %s" % tfn)
if self._tracefile is None:
def notrace(*args):
pass
self._trace = notrace
def _trace(self, code, oid='', version='', serial='', dlen=0,
# Remaining arguments are speed hacks
time_time=time.time, struct_pack=pack):
# The code argument is two hex digits; bits 0 and 7 must be zero.
# The first hex digit shows the operation, the second the outcome.
# If the second digit is in "02468" then it is a 'miss'.
# If it is in "ACE" then it is a 'hit'.
# This method has been carefully tuned to be as fast as possible.
# Note: when tracing is disabled, this method is hidden by a dummy.
if version:
code |= 0x80
self._tracefile.write(
struct_pack(">iiH8s",
time_time(),
(dlen+255) & 0x7fffff00 | code | self._current,
len(oid),
serial) + oid)
def read_index(self, serial, fileindex):
index = self._index
f = self._f[fileindex]
seek = f.seek
read = f.read
pos = headersize
count = 0
while 1:
f.seek(pos)
h = read(27)
if len(h) != 27:
# An empty read is expected, anything else is suspect
if h:
self.rilog("truncated header", pos, fileindex)
break
if h[8] in 'vni':
tlen, vlen, dlen = unpack(">iHi", h[9:19])
else:
tlen = -1
if tlen <= 0 or vlen < 0 or dlen < 0 or vlen + dlen > tlen:
self.rilog("invalid header data", pos, fileindex)
break
oidlen = unpack(">H", h[:2])[0]
oid = read(oidlen)
if h[8] == 'v' and vlen:
seek(dlen+vlen, 1)
vdlen = read(4)
if len(vdlen) != 4:
self.rilog("truncated record", pos, fileindex)
break
vdlen = unpack(">i", vdlen)[0]
if vlen + oidlen + dlen + 43 + vdlen != tlen:
self.rilog("inconsistent lengths", pos, fileindex)
break
seek(vdlen, 1)
vs = read(8)
if read(4) != h[9:13]:
self.rilog("inconsistent tlen", pos, fileindex)
break
else:
if h[8] in 'vn' and vlen == 0:
if oidlen + dlen + 31 != tlen:
self.rilog("inconsistent nv lengths", pos, fileindex)
seek(dlen, 1)
if read(4) != h[9:13]:
self.rilog("inconsistent nv tlen", pos, fileindex)
break
vs = None
if h[8] in 'vn':
if fileindex:
index[oid] = -pos
else:
index[oid] = pos
serial[oid] = h[-8:], vs
else:
if serial.has_key(oid):
# We have a record for this oid, but it was invalidated!
del serial[oid]
del index[oid]
pos = pos + tlen
count += 1
f.seek(pos)
try:
f.truncate()
except:
pass
if count:
self.log("read_index: cache file %d has %d records and %d bytes"
% (fileindex, count, pos))
return pos
def rilog(self, msg, pos, fileindex):
# Helper to log messages from read_index
self.log("read_index: %s at position %d in cache file %d"
% (msg, pos, fileindex))
def log(self, msg, level=zLOG.INFO):
# XXX Using the path of the current file means the tags
# won't match after a cache flip. But they'll be very similar.
zLOG.LOG("ZEC:%s" % self._p[self._current], level, msg)
...@@ -26,7 +26,8 @@ import threading ...@@ -26,7 +26,8 @@ import threading
import time import time
import types import types
from ZEO import ClientCache, ServerStub from ZEO import ServerStub
from ZEO.cache import ClientCache
from ZEO.TransactionBuffer import TransactionBuffer from ZEO.TransactionBuffer import TransactionBuffer
from ZEO.Exceptions import ClientStorageError, UnrecognizedResult, \ from ZEO.Exceptions import ClientStorageError, UnrecognizedResult, \
ClientDisconnected, AuthError ClientDisconnected, AuthError
...@@ -91,7 +92,7 @@ class ClientStorage(object): ...@@ -91,7 +92,7 @@ class ClientStorage(object):
# Classes we instantiate. A subclass might override. # Classes we instantiate. A subclass might override.
TransactionBufferClass = TransactionBuffer TransactionBufferClass = TransactionBuffer
ClientCacheClass = ClientCache.ClientCache ClientCacheClass = ClientCache
ConnectionManagerClass = ConnectionManager ConnectionManagerClass = ConnectionManager
StorageServerStubClass = ServerStub.StorageServer StorageServerStubClass = ServerStub.StorageServer
...@@ -252,10 +253,17 @@ class ClientStorage(object): ...@@ -252,10 +253,17 @@ class ClientStorage(object):
self._tbuf = self.TransactionBufferClass() self._tbuf = self.TransactionBufferClass()
self._db = None self._db = None
self._ltid = None # the last committed transaction
# _serials: stores (oid, serialno) as returned by server # _serials: stores (oid, serialno) as returned by server
# _seriald: _check_serials() moves from _serials to _seriald, # _seriald: _check_serials() moves from _serials to _seriald,
# which maps oid to serialno # which maps oid to serialno
# XXX If serial number matches transaction id, then there is
# no need to have all this extra infrastructure for handling
# serial numbers. The vote call can just return the tid.
# If there is a conflict error, we can't have a special method
# called just to propagate the error.
self._serials = [] self._serials = []
self._seriald = {} self._seriald = {}
...@@ -292,13 +300,15 @@ class ClientStorage(object): ...@@ -292,13 +300,15 @@ class ClientStorage(object):
# is executing. # is executing.
self._lock = threading.Lock() self._lock = threading.Lock()
t = self._ts = get_timestamp()
self._serial = `t`
self._oid = '\0\0\0\0\0\0\0\0'
# Decide whether to use non-temporary files # Decide whether to use non-temporary files
self._cache = self.ClientCacheClass(storage, cache_size, if client is not None:
client=client, var=var) dir = var or os.getcwd()
cache_path = os.path.join(dir, "%s-%s.zec" % (client, storage))
else:
cache_path = None
self._cache = self.ClientCacheClass(cache_path)
# XXX When should it be opened?
self._cache.open()
self._rpc_mgr = self.ConnectionManagerClass(addr, self, self._rpc_mgr = self.ConnectionManagerClass(addr, self,
tmin=min_disconnect_poll, tmin=min_disconnect_poll,
...@@ -312,9 +322,6 @@ class ClientStorage(object): ...@@ -312,9 +322,6 @@ class ClientStorage(object):
# doesn't succeed, call connect() to start a thread. # doesn't succeed, call connect() to start a thread.
if not self._rpc_mgr.attempt_connect(): if not self._rpc_mgr.attempt_connect():
self._rpc_mgr.connect() self._rpc_mgr.connect()
# If the connect hasn't occurred, run with cached data.
if not self._ready.isSet():
self._cache.open()
def _wait(self, timeout=None): def _wait(self, timeout=None):
if timeout is not None: if timeout is not None:
...@@ -555,7 +562,6 @@ class ClientStorage(object): ...@@ -555,7 +562,6 @@ class ClientStorage(object):
if ltid == last_inval_tid: if ltid == last_inval_tid:
log2(INFO, "No verification necessary " log2(INFO, "No verification necessary "
"(last_inval_tid up-to-date)") "(last_inval_tid up-to-date)")
self._cache.open()
self._server = server self._server = server
self._ready.set() self._ready.set()
return "no verification" return "no verification"
...@@ -569,7 +575,6 @@ class ClientStorage(object): ...@@ -569,7 +575,6 @@ class ClientStorage(object):
pair = server.getInvalidations(last_inval_tid) pair = server.getInvalidations(last_inval_tid)
if pair is not None: if pair is not None:
log2(INFO, "Recovering %d invalidations" % len(pair[1])) log2(INFO, "Recovering %d invalidations" % len(pair[1]))
self._cache.open()
self.invalidateTransaction(*pair) self.invalidateTransaction(*pair)
self._server = server self._server = server
self._ready.set() self._ready.set()
...@@ -581,7 +586,9 @@ class ClientStorage(object): ...@@ -581,7 +586,9 @@ class ClientStorage(object):
self._pickler = cPickle.Pickler(self._tfile, 1) self._pickler = cPickle.Pickler(self._tfile, 1)
self._pickler.fast = 1 # Don't use the memo self._pickler.fast = 1 # Don't use the memo
self._cache.verify(server.zeoVerify) # XXX should batch these operations for efficiency
for oid, tid, version in self._cache.contents():
server.verify(oid, version, tid)
self._pending_server = server self._pending_server = server
server.endZeoVerify() server.endZeoVerify()
return "full verification" return "full verification"
...@@ -600,8 +607,7 @@ class ClientStorage(object): ...@@ -600,8 +607,7 @@ class ClientStorage(object):
This is called by ConnectionManager when the connection is This is called by ConnectionManager when the connection is
closed or when certain problems with the connection occur. closed or when certain problems with the connection occur.
""" """
log2(PROBLEM, "Disconnected from storage: %s" log2(INFO, "Disconnected from storage: %s" % repr(self._server_addr))
% repr(self._server_addr))
self._connection = None self._connection = None
self._ready.clear() self._ready.clear()
self._server = disconnected_stub self._server = disconnected_stub
...@@ -671,10 +677,10 @@ class ClientStorage(object): ...@@ -671,10 +677,10 @@ class ClientStorage(object):
raise POSException.StorageTransactionError(self._transaction, raise POSException.StorageTransactionError(self._transaction,
trans) trans)
def abortVersion(self, version, transaction): def abortVersion(self, version, txn):
"""Storage API: clear any changes made by the given version.""" """Storage API: clear any changes made by the given version."""
self._check_trans(transaction) self._check_trans(txn)
oids = self._server.abortVersion(version, self._serial) tid, oids = self._server.abortVersion(version, id(txn))
# When a version aborts, invalidate the version and # When a version aborts, invalidate the version and
# non-version data. The non-version data should still be # non-version data. The non-version data should still be
# valid, but older versions of ZODB will change the # valid, but older versions of ZODB will change the
...@@ -686,28 +692,31 @@ class ClientStorage(object): ...@@ -686,28 +692,31 @@ class ClientStorage(object):
# we could just invalidate the version data. # we could just invalidate the version data.
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, '') self._tbuf.invalidate(oid, '')
return oids return tid, oids
def commitVersion(self, source, destination, transaction): def commitVersion(self, source, destination, txn):
"""Storage API: commit the source version in the destination.""" """Storage API: commit the source version in the destination."""
self._check_trans(transaction) self._check_trans(txn)
oids = self._server.commitVersion(source, destination, self._serial) tid, oids = self._server.commitVersion(source, destination, id(txn))
if destination: if destination:
# just invalidate our version data # just invalidate our version data
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, source) self._tbuf.invalidate(oid, source)
else: else:
# destination is '', so invalidate version and non-version # destination is "", so invalidate version and non-version
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, destination) self._tbuf.invalidate(oid, "")
return oids return tid, oids
def history(self, oid, version, length=1): def history(self, oid, version, length=1, filter=None):
"""Storage API: return a sequence of HistoryEntry objects. """Storage API: return a sequence of HistoryEntry objects.
This does not support the optional filter argument defined by This does not support the optional filter argument defined by
the Storage API. the Storage API.
""" """
if filter is not None:
log2(WARNING, "filter argument to history() ignored")
# XXX should I run filter on the results?
return self._server.history(oid, version, length) return self._server.history(oid, version, length)
def getSerial(self, oid): def getSerial(self, oid):
...@@ -725,11 +734,14 @@ class ClientStorage(object): ...@@ -725,11 +734,14 @@ class ClientStorage(object):
specified by the given object id and version, if they exist; specified by the given object id and version, if they exist;
otherwise a KeyError is raised. otherwise a KeyError is raised.
""" """
return self.loadEx(oid, version)[:2]
def loadEx(self, oid, version):
self._lock.acquire() # for atomic processing of invalidations self._lock.acquire() # for atomic processing of invalidations
try: try:
pair = self._cache.load(oid, version) t = self._cache.load(oid, version)
if pair: if t:
return pair return t
finally: finally:
self._lock.release() self._lock.release()
...@@ -745,25 +757,55 @@ class ClientStorage(object): ...@@ -745,25 +757,55 @@ class ClientStorage(object):
finally: finally:
self._lock.release() self._lock.release()
p, s, v, pv, sv = self._server.zeoLoad(oid) data, tid, ver = self._server.loadEx(oid, version)
self._lock.acquire() # for atomic processing of invalidations self._lock.acquire() # for atomic processing of invalidations
try: try:
if self._load_status: if self._load_status:
self._cache.checkSize(0) self._cache.store(oid, ver, tid, None, data)
self._cache.store(oid, p, s, v, pv, sv)
self._load_oid = None self._load_oid = None
finally: finally:
self._lock.release() self._lock.release()
finally: finally:
self._load_lock.release() self._load_lock.release()
if v and version and v == version: return data, tid, ver
return pv, sv
else: def loadBefore(self, oid, tid):
if s: self._lock.acquire()
return p, s try:
raise KeyError, oid # no non-version data for this t = self._cache.loadBefore(oid, tid)
if t is not None:
return t
finally:
self._lock.release()
t = self._server.loadBefore(oid, tid)
if t is None:
return None
data, start, end = t
if end is None:
# This method should not be used to get current data. It
# doesn't use the _load_lock, so it is possble to overlap
# this load with an invalidation for the same object.
# XXX If we call again, we're guaranteed to get the
# post-invalidation data. But if the data is still
# current, we'll still get end == None.
# Maybe the best thing to do is to re-run the test with
# the load lock in the case. That's slow performance, but
# I don't think real application code will ever care about
# it.
return data, start, end
self._lock.acquire()
try:
self._cache.store(oid, "", start, end, data)
finally:
self._lock.release()
return data, start, end
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
"""Storage API: return the version, if any, that modfied an object. """Storage API: return the version, if any, that modfied an object.
...@@ -815,6 +857,8 @@ class ClientStorage(object): ...@@ -815,6 +857,8 @@ class ClientStorage(object):
def _check_serials(self): def _check_serials(self):
"""Internal helper to move data from _serials to _seriald.""" """Internal helper to move data from _serials to _seriald."""
# XXX serials are always going to be the same, the only
# question is whether an exception has been raised.
if self._serials: if self._serials:
l = len(self._serials) l = len(self._serials)
r = self._serials[:l] r = self._serials[:l]
...@@ -825,18 +869,18 @@ class ClientStorage(object): ...@@ -825,18 +869,18 @@ class ClientStorage(object):
self._seriald[oid] = s self._seriald[oid] = s
return r return r
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, txn):
"""Storage API: store data for an object.""" """Storage API: store data for an object."""
self._check_trans(transaction) self._check_trans(txn)
self._server.storea(oid, serial, data, version, self._serial) self._server.storea(oid, serial, data, version, id(txn))
self._tbuf.store(oid, version, data) self._tbuf.store(oid, version, data)
return self._check_serials() return self._check_serials()
def tpc_vote(self, transaction): def tpc_vote(self, txn):
"""Storage API: vote on a transaction.""" """Storage API: vote on a transaction."""
if transaction is not self._transaction: if txn is not self._transaction:
return return
self._server.vote(self._serial) self._server.vote(id(txn))
return self._check_serials() return self._check_serials()
def tpc_begin(self, txn, tid=None, status=' '): def tpc_begin(self, txn, tid=None, status=' '):
...@@ -856,15 +900,8 @@ class ClientStorage(object): ...@@ -856,15 +900,8 @@ class ClientStorage(object):
self._transaction = txn self._transaction = txn
self._tpc_cond.release() self._tpc_cond.release()
if tid is None:
self._ts = get_timestamp(self._ts)
id = `self._ts`
else:
self._ts = TimeStamp(tid)
id = tid
try: try:
self._server.tpc_begin(id, txn.user, txn.description, self._server.tpc_begin(id(txn), txn.user, txn.description,
txn._extension, tid, status) txn._extension, tid, status)
except: except:
# Client may have disconnected during the tpc_begin(). # Client may have disconnected during the tpc_begin().
...@@ -872,7 +909,6 @@ class ClientStorage(object): ...@@ -872,7 +909,6 @@ class ClientStorage(object):
self.end_transaction() self.end_transaction()
raise raise
self._serial = id
self._tbuf.clear() self._tbuf.clear()
self._seriald.clear() self._seriald.clear()
del self._serials[:] del self._serials[:]
...@@ -881,18 +917,17 @@ class ClientStorage(object): ...@@ -881,18 +917,17 @@ class ClientStorage(object):
"""Internal helper to end a transaction.""" """Internal helper to end a transaction."""
# the right way to set self._transaction to None # the right way to set self._transaction to None
# calls notify() on _tpc_cond in case there are waiting threads # calls notify() on _tpc_cond in case there are waiting threads
self._ltid = self._serial
self._tpc_cond.acquire() self._tpc_cond.acquire()
self._transaction = None self._transaction = None
self._tpc_cond.notify() self._tpc_cond.notify()
self._tpc_cond.release() self._tpc_cond.release()
def lastTransaction(self): def lastTransaction(self):
return self._ltid return self._cache.getLastTid()
def tpc_abort(self, transaction): def tpc_abort(self, txn):
"""Storage API: abort a transaction.""" """Storage API: abort a transaction."""
if transaction is not self._transaction: if txn is not self._transaction:
return return
try: try:
# XXX Are there any transactions that should prevent an # XXX Are there any transactions that should prevent an
...@@ -900,7 +935,7 @@ class ClientStorage(object): ...@@ -900,7 +935,7 @@ class ClientStorage(object):
# all, yet you want to be sure that other abort logic is # all, yet you want to be sure that other abort logic is
# executed regardless. # executed regardless.
try: try:
self._server.tpc_abort(self._serial) self._server.tpc_abort(id(txn))
except ClientDisconnected: except ClientDisconnected:
log2(BLATHER, 'ClientDisconnected in tpc_abort() ignored') log2(BLATHER, 'ClientDisconnected in tpc_abort() ignored')
finally: finally:
...@@ -909,9 +944,9 @@ class ClientStorage(object): ...@@ -909,9 +944,9 @@ class ClientStorage(object):
del self._serials[:] del self._serials[:]
self.end_transaction() self.end_transaction()
def tpc_finish(self, transaction, f=None): def tpc_finish(self, txn, f=None):
"""Storage API: finish a transaction.""" """Storage API: finish a transaction."""
if transaction is not self._transaction: if txn is not self._transaction:
return return
self._load_lock.acquire() self._load_lock.acquire()
try: try:
...@@ -919,15 +954,16 @@ class ClientStorage(object): ...@@ -919,15 +954,16 @@ class ClientStorage(object):
raise ClientDisconnected( raise ClientDisconnected(
'Calling tpc_finish() on a disconnected transaction') 'Calling tpc_finish() on a disconnected transaction')
tid = self._server.tpc_finish(self._serial) tid = self._server.tpc_finish(id(txn))
self._lock.acquire() # for atomic processing of invalidations self._lock.acquire() # for atomic processing of invalidations
try: try:
self._update_cache() self._update_cache(tid)
if f is not None: if f is not None:
f() f(tid)
finally: finally:
self._lock.release() self._lock.release()
# XXX Shouldn't this cache call be made while holding the lock?
self._cache.setLastTid(tid) self._cache.setLastTid(tid)
r = self._check_serials() r = self._check_serials()
...@@ -936,7 +972,7 @@ class ClientStorage(object): ...@@ -936,7 +972,7 @@ class ClientStorage(object):
self._load_lock.release() self._load_lock.release()
self.end_transaction() self.end_transaction()
def _update_cache(self): def _update_cache(self, tid):
"""Internal helper to handle objects modified by a transaction. """Internal helper to handle objects modified by a transaction.
This iterates over the objects in the transaction buffer and This iterates over the objects in the transaction buffer and
...@@ -949,7 +985,6 @@ class ClientStorage(object): ...@@ -949,7 +985,6 @@ class ClientStorage(object):
if self._cache is None: if self._cache is None:
return return
self._cache.checkSize(self._tbuf.get_size())
try: try:
self._tbuf.begin_iterate() self._tbuf.begin_iterate()
except ValueError, msg: except ValueError, msg:
...@@ -965,18 +1000,17 @@ class ClientStorage(object): ...@@ -965,18 +1000,17 @@ class ClientStorage(object):
"client storage: %s" % msg) "client storage: %s" % msg)
if t is None: if t is None:
break break
oid, v, p = t oid, version, data = t
if p is None: # an invalidation self._cache.invalidate(oid, version, tid)
s = None # If data is None, we just invalidate.
else: if data is not None:
s = self._seriald[oid] s = self._seriald[oid]
if s == ResolvedSerial or s is None: if s != ResolvedSerial:
self._cache.invalidate(oid, v) assert s == tid, (s, tid)
else: self._cache.store(oid, version, s, None, data)
self._cache.update(oid, s, v, p)
self._tbuf.clear() self._tbuf.clear()
def transactionalUndo(self, trans_id, trans): def transactionalUndo(self, trans_id, txn):
"""Storage API: undo a transaction. """Storage API: undo a transaction.
This is executed in a transactional context. It has no effect This is executed in a transactional context. It has no effect
...@@ -985,24 +1019,11 @@ class ClientStorage(object): ...@@ -985,24 +1019,11 @@ class ClientStorage(object):
Zope uses this to implement undo unless it is not supported by Zope uses this to implement undo unless it is not supported by
a storage. a storage.
""" """
self._check_trans(trans) self._check_trans(txn)
oids = self._server.transactionalUndo(trans_id, self._serial) tid, oids = self._server.transactionalUndo(trans_id, id(txn))
for oid in oids: for oid in oids:
self._tbuf.invalidate(oid, '') self._tbuf.invalidate(oid, '')
return oids return tid, oids
def undo(self, transaction_id):
"""Storage API: undo a transaction, writing directly to the storage."""
if self._is_read_only:
raise POSException.ReadOnlyError()
oids = self._server.undo(transaction_id)
self._lock.acquire()
try:
for oid in oids:
self._cache.invalidate(oid, '')
finally:
self._lock.release()
return oids
def undoInfo(self, first=0, last=-20, specification=None): def undoInfo(self, first=0, last=-20, specification=None):
"""Storage API: return undo information.""" """Storage API: return undo information."""
...@@ -1059,15 +1080,15 @@ class ClientStorage(object): ...@@ -1059,15 +1080,15 @@ class ClientStorage(object):
try: try:
# versions maps version names to dictionary of invalidations # versions maps version names to dictionary of invalidations
versions = {} versions = {}
for oid, version in invs: for oid, version, tid in invs:
if oid == self._load_oid: if oid == self._load_oid:
self._load_status = 0 self._load_status = 0
self._cache.invalidate(oid, version=version) self._cache.invalidate(oid, version, tid)
versions.setdefault(version, {})[oid] = 1 versions.setdefault((version, tid), {})[oid] = tid
if self._db is not None: if self._db is not None:
for v, d in versions.items(): for (version, tid), d in versions.items():
self._db.invalidate(d, version=v) self._db.invalidate(tid, d, version=version)
finally: finally:
self._lock.release() self._lock.release()
...@@ -1099,7 +1120,8 @@ class ClientStorage(object): ...@@ -1099,7 +1120,8 @@ class ClientStorage(object):
for t in args: for t in args:
self._pickler.dump(t) self._pickler.dump(t)
return return
self._process_invalidations(args) self._process_invalidations([(oid, version, tid)
for oid, version in args])
# The following are for compatibility with protocol version 2.0.0 # The following are for compatibility with protocol version 2.0.0
...@@ -1110,36 +1132,10 @@ class ClientStorage(object): ...@@ -1110,36 +1132,10 @@ class ClientStorage(object):
end = endVerify end = endVerify
Invalidate = invalidateTrans Invalidate = invalidateTrans
try: def InvalidationLogIterator(fileobj):
StopIteration unpickler = cPickle.Unpickler(fileobj)
except NameError: while 1:
class StopIteration(Exception): oid, version = unpickler.load()
pass
class InvalidationLogIterator:
"""Helper class for reading invalidations in endVerify."""
def __init__(self, fileobj):
self._unpickler = cPickle.Unpickler(fileobj)
self.getitem_i = 0
def __iter__(self):
return self
def next(self):
oid, version = self._unpickler.load()
if oid is None: if oid is None:
raise StopIteration break
return oid, version yield oid, version, None
# The __getitem__() method is needed to support iteration
# in Python 2.1.
def __getitem__(self, i):
assert i == self.getitem_i
try:
obj = self.next()
except StopIteration:
raise IndexError, i
self.getitem_i += 1
return obj
try:
from Interface import Base
except ImportError:
class Base:
# a dummy interface for use when Zope's is unavailable
pass
class ICache(Base):
"""ZEO client cache.
__init__(storage, size, client, var)
All arguments optional.
storage -- name of storage
size -- max size of cache in bytes
client -- a string; if specified, cache is persistent.
var -- var directory to store cache files in
"""
def open():
"""Returns a sequence of object info tuples.
An object info tuple is a pair containing an object id and a
pair of serialnos, a non-version serialno and a version serialno:
oid, (serial, ver_serial)
This method builds an index of the cache and returns a
sequence used for cache validation.
"""
def close():
"""Closes the cache."""
def verify(func):
"""Call func on every object in cache.
func is called with three arguments
func(oid, serial, ver_serial)
"""
def invalidate(oid, version):
"""Remove object from cache."""
def load(oid, version):
"""Load object from cache.
Return None if object not in cache.
Return data, serialno if object is in cache.
"""
def store(oid, p, s, version, pv, sv):
"""Store a new object in the cache."""
def update(oid, serial, version, data):
"""Update an object already in the cache.
XXX This method is called to update objects that were modified by
a transaction. It's likely that it is already in the cache,
and it may be possible for the implementation to operate more
efficiently.
"""
def modifiedInVersion(oid):
"""Return the version an object is modified in.
'' signifies the trunk.
Returns None if the object is not in the cache.
"""
def checkSize(size):
"""Check if adding size bytes would exceed cache limit.
This method is often called just before store or update. The
size is a hint about the amount of data that is about to be
stored. The cache may want to evict some data to make space.
"""
...@@ -13,6 +13,18 @@ ...@@ -13,6 +13,18 @@
############################################################################## ##############################################################################
"""RPC stubs for interface exported by StorageServer.""" """RPC stubs for interface exported by StorageServer."""
##
# ZEO storage server.
# <p>
# Remote method calls can be synchronous or asynchronous. If the call
# is synchronous, the client thread blocks until the call returns. A
# single client can only have one synchronous request outstanding. If
# several threads share a single client, threads other than the caller
# will block only if the attempt to make another synchronous call.
# An asynchronous call does not cause the client thread to block. An
# exception raised by an asynchronous method is logged on the server,
# but is not returned to the client.
class StorageServer: class StorageServer:
"""An RPC stub class for the interface exported by ClientStorage. """An RPC stub class for the interface exported by ClientStorage.
...@@ -43,47 +55,174 @@ class StorageServer: ...@@ -43,47 +55,174 @@ class StorageServer:
def extensionMethod(self, name): def extensionMethod(self, name):
return ExtensionMethodWrapper(self.rpc, name).call return ExtensionMethodWrapper(self.rpc, name).call
##
# Register current connection with a storage and a mode.
# In effect, it is like an open call.
# @param storage_name a string naming the storage. This argument
# is primarily for backwards compatibility with servers
# that supported multiple storages.
# @param read_only boolean
# @exception ValueError unknown storage_name or already registered
# @exception ReadOnlyError storage is read-only and a read-write
# connectio was requested
def register(self, storage_name, read_only): def register(self, storage_name, read_only):
self.rpc.call('register', storage_name, read_only) self.rpc.call('register', storage_name, read_only)
##
# Return dictionary of meta-data about the storage.
# @defreturn dict
def get_info(self): def get_info(self):
return self.rpc.call('get_info') return self.rpc.call('get_info')
##
# Check whether the server requires authentication. Returns
# the name of the protocol.
# @defreturn string
def getAuthProtocol(self): def getAuthProtocol(self):
return self.rpc.call('getAuthProtocol') return self.rpc.call('getAuthProtocol')
##
# Return id of the last committed transaction
# @defreturn string
def lastTransaction(self): def lastTransaction(self):
# Not in protocol version 2.0.0; see __init__() # Not in protocol version 2.0.0; see __init__()
return self.rpc.call('lastTransaction') return self.rpc.call('lastTransaction')
##
# Return invalidations for all transactions after tid.
# @param tid transaction id
# @defreturn 2-tuple, (tid, list)
# @return tuple containing the last committed transaction
# and a list of oids that were invalidated. Returns
# None and an empty list if the server does not have
# the list of oids available.
def getInvalidations(self, tid): def getInvalidations(self, tid):
# Not in protocol version 2.0.0; see __init__() # Not in protocol version 2.0.0; see __init__()
return self.rpc.call('getInvalidations', tid) return self.rpc.call('getInvalidations', tid)
##
# Check whether serial numbers s and sv are current for oid.
# If one or both of the serial numbers are not current, the
# server will make an asynchronous invalidateVerify() call.
# @param oid object id
# @param s serial number on non-version data
# @param sv serial number of version data or None
# @defreturn async
def zeoVerify(self, oid, s, sv): def zeoVerify(self, oid, s, sv):
self.rpc.callAsync('zeoVerify', oid, s, sv) self.rpc.callAsync('zeoVerify', oid, s, sv)
##
# Check whether current serial number is valid for oid and version.
# If the serial number is not current, the server will make an
# asynchronous invalidateVerify() call.
# @param oid object id
# @param version name of version for oid
# @param serial client's current serial number
# @defreturn async
def verify(self, oid, version, serial):
self.rpc.callAsync('verify', oid, version, serial)
##
# Signal to the server that cache verification is done.
# @defreturn async
def endZeoVerify(self): def endZeoVerify(self):
self.rpc.callAsync('endZeoVerify') self.rpc.callAsync('endZeoVerify')
##
# Generate a new set of oids.
# @param n number of new oids to return
# @defreturn list
# @return list of oids
def new_oids(self, n=None): def new_oids(self, n=None):
if n is None: if n is None:
return self.rpc.call('new_oids') return self.rpc.call('new_oids')
else: else:
return self.rpc.call('new_oids', n) return self.rpc.call('new_oids', n)
##
# Pack the storage.
# @param t pack time
# @param wait optional, boolean. If true, the call will not
# return until the pack is complete.
def pack(self, t, wait=None): def pack(self, t, wait=None):
if wait is None: if wait is None:
self.rpc.call('pack', t) self.rpc.call('pack', t)
else: else:
self.rpc.call('pack', t, wait) self.rpc.call('pack', t, wait)
##
# Return current data for oid. Version data is returned if
# present.
# @param oid object id
# @defreturn 5-tuple
# @return 5-tuple, current non-version data, serial number,
# version name, version data, version data serial number
# @exception KeyError if oid is not found
def zeoLoad(self, oid): def zeoLoad(self, oid):
return self.rpc.call('zeoLoad', oid) return self.rpc.call('zeoLoad', oid)
##
# Return current data for oid along with tid if transaction that
# wrote the date.
# @param oid object id
# @param version string, name of version
# @defreturn 4-tuple
# @return data, serial number, transaction id, version,
# where version is the name of the version the data came
# from or "" for non-version data
# @exception KeyError if oid is not found
def loadEx(self, oid, version):
return self.rpc.call("loadEx", oid, version)
##
# Return non-current data along with transaction ids that identify
# the lifetime of the specific revision.
# @param oid object id
# @param tid a transaction id that provides an upper bound on
# the lifetime of the revision. That is, loadBefore
# returns the revision that was current before tid committed.
# @defreturn 4-tuple
# @return data, serial numbr, start transaction id, end transaction id
def loadBefore(self, oid, tid):
return self.rpc.call("loadBefore", oid, tid)
##
# Storage new revision of oid.
# @param oid object id
# @param serial serial number that this transaction read
# @param data new data record for oid
# @param version name of version or ""
# @param id id of current transaction
# @defreturn async
def storea(self, oid, serial, data, version, id): def storea(self, oid, serial, data, version, id):
self.rpc.callAsync('storea', oid, serial, data, version, id) self.rpc.callAsync('storea', oid, serial, data, version, id)
##
# Start two-phase commit for a transaction
# @param id id used by client to identify current transaction. The
# only purpose of this argument is to distinguish among multiple
# threads using a single ClientStorage.
# @param user name of user committing transaction (can be "")
# @param description string containing transaction metadata (can be "")
# @param ext dictionary of extended metadata (?)
# @param tid optional explicit tid to pass to underlying storage
# @param status optional status character, e.g "p" for pack
# @defreturn async
def tpc_begin(self, id, user, descr, ext, tid, status): def tpc_begin(self, id, user, descr, ext, tid, status):
return self.rpc.call('tpc_begin', id, user, descr, ext, tid, status) return self.rpc.call('tpc_begin', id, user, descr, ext, tid, status)
......
...@@ -235,6 +235,14 @@ class ZEOStorage: ...@@ -235,6 +235,14 @@ class ZEOStorage:
def getExtensionMethods(self): def getExtensionMethods(self):
return self._extensions return self._extensions
def loadEx(self, oid, version):
self.stats.loads += 1
return self.storage.loadEx(oid, version)
def loadBefore(self, oid, tid):
self.stats.loads += 1
return self.storage.loadBefore(oid, tid)
def zeoLoad(self, oid): def zeoLoad(self, oid):
self.stats.loads += 1 self.stats.loads += 1
v = self.storage.modifiedInVersion(oid) v = self.storage.modifiedInVersion(oid)
...@@ -260,12 +268,26 @@ class ZEOStorage: ...@@ -260,12 +268,26 @@ class ZEOStorage:
% (len(invlist), u64(invtid))) % (len(invlist), u64(invtid)))
return invtid, invlist return invtid, invlist
def verify(self, oid, version, tid):
try:
t = self.storage.getTid(oid)
except KeyError:
self.client.invalidateVerify((oid, ""))
else:
if tid != t:
# This will invalidate non-version data when the
# client only has invalid version data. Since this is
# an uncommon case, we avoid the cost of checking
# whether the serial number matches the current
# non-version data.
self.client.invalidateVerify((oid, version))
def zeoVerify(self, oid, s, sv): def zeoVerify(self, oid, s, sv):
if not self.verifying: if not self.verifying:
self.verifying = 1 self.verifying = 1
self.stats.verifying_clients += 1 self.stats.verifying_clients += 1
try: try:
os = self.storage.getSerial(oid) os = self.storage.getTid(oid)
except KeyError: except KeyError:
self.client.invalidateVerify((oid, '')) self.client.invalidateVerify((oid, ''))
# XXX It's not clear what we should do now. The KeyError # XXX It's not clear what we should do now. The KeyError
...@@ -344,7 +366,7 @@ class ZEOStorage: ...@@ -344,7 +366,7 @@ class ZEOStorage:
def undoLog(self, first, last): def undoLog(self, first, last):
return run_in_thread(self.storage.undoLog, first, last) return run_in_thread(self.storage.undoLog, first, last)
def tpc_begin(self, id, user, description, ext, tid, status): def tpc_begin(self, id, user, description, ext, tid=None, status=" "):
if self.read_only: if self.read_only:
raise ReadOnlyError() raise ReadOnlyError()
if self.transaction is not None: if self.transaction is not None:
...@@ -521,25 +543,25 @@ class ZEOStorage: ...@@ -521,25 +543,25 @@ class ZEOStorage:
return self.storage.tpc_vote(self.transaction) return self.storage.tpc_vote(self.transaction)
def _abortVersion(self, src): def _abortVersion(self, src):
oids = self.storage.abortVersion(src, self.transaction) tid, oids = self.storage.abortVersion(src, self.transaction)
inv = [(oid, src) for oid in oids] inv = [(oid, src) for oid in oids]
self.invalidated.extend(inv) self.invalidated.extend(inv)
return oids return tid, oids
def _commitVersion(self, src, dest): def _commitVersion(self, src, dest):
oids = self.storage.commitVersion(src, dest, self.transaction) tid, oids = self.storage.commitVersion(src, dest, self.transaction)
inv = [(oid, dest) for oid in oids] inv = [(oid, dest) for oid in oids]
self.invalidated.extend(inv) self.invalidated.extend(inv)
if dest: if dest:
inv = [(oid, src) for oid in oids] inv = [(oid, src) for oid in oids]
self.invalidated.extend(inv) self.invalidated.extend(inv)
return oids return tid, oids
def _transactionalUndo(self, trans_id): def _transactionalUndo(self, trans_id):
oids = self.storage.transactionalUndo(trans_id, self.transaction) tid, oids = self.storage.transactionalUndo(trans_id, self.transaction)
inv = [(oid, None) for oid in oids] inv = [(oid, None) for oid in oids]
self.invalidated.extend(inv) self.invalidated.extend(inv)
return oids return tid, oids
# When a delayed transaction is restarted, the dance is # When a delayed transaction is restarted, the dance is
# complicated. The restart occurs when one ZEOStorage instance # complicated. The restart occurs when one ZEOStorage instance
...@@ -854,6 +876,9 @@ class StorageServer: ...@@ -854,6 +876,9 @@ class StorageServer:
log("tid to old for invq %s < %s" % (u64(tid), u64(earliest_tid))) log("tid to old for invq %s < %s" % (u64(tid), u64(earliest_tid)))
return None, [] return None, []
# XXX this is wrong! must check against tid or we invalidate
# too much.
oids = {} oids = {}
for tid, L in self.invq: for tid, L in self.invq:
for key in L: for key in L:
......
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Disk-based client cache for ZEO.
ClientCache exposes an API used by the ZEO client storage. FileCache
stores objects one disk using a 2-tuple of oid and tid as key.
The upper cache's API is similar to a storage API with methods like
load(), store(), and invalidate(). It manages in-memory data
structures that allow it to map this richer API onto the simple
key-based API of the lower-level cache.
"""
import bisect
import logging
import os
import struct
import tempfile
import time
from sets import Set
from ZODB.utils import z64, u64
##
# A disk-based cache for ZEO clients.
# <p>
# This class provides an interface to a persistent, disk-based cache
# used by ZEO clients to store copies of database records from the
# server.
# <p>
# The details of the constructor as unspecified at this point.
# <p>
# Each entry in the cache is valid for a particular range of transaction
# ids. The lower bound is the transaction that wrote the data. The
# upper bound is the next transaction that wrote a revision of the
# object. If the data is current, the upper bound is stored as None;
# the data is considered current until an invalidate() call is made.
# <p>
# It is an error to call store() twice with the same object without an
# intervening invalidate() to set the upper bound on the first cache
# entry. <em>Perhaps it will be necessary to have a call the removes
# something from the cache outright, without keeping a non-current
# entry.</em>
# <h3>Cache verification</h3>
# <p>
# When the client is connected to the server, it receives
# invalidations every time an object is modified. Whe the client is
# disconnected, it must perform cache verification to make sure its
# cached data is synchronized with the storage's current state.
# <p>
# quick verification
# full verification
# <p>
class ClientCache:
"""A simple in-memory cache."""
##
# Do we put the constructor here?
# @param path path of persistent snapshot of cache state
# @param size maximum size of object data, in bytes
def __init__(self, path=None, size=None, trace=True):
self.path = path
self.size = size
self.log = logging.getLogger("zeo.cache")
if trace and path:
self._setup_trace()
else:
self._trace = self._notrace
# Last transaction seen by the cache, either via setLastTid()
# or by invalidate().
self.tid = None
# The cache stores objects in a dict mapping (oid, tid) pairs
# to Object() records (see below). The tid is the transaction
# id that wrote the object. An object record includes data,
# serialno, and end tid. It has auxillary data structures to
# compute the appropriate tid, given the oid and a transaction id
# representing an arbitrary point in history.
#
# The serialized form of the cache just stores the Object()
# records. The in-memory form can be reconstructed from these
# records.
# Maps oid to current tid. Used to find compute key for objects.
self.current = {}
# Maps oid to list of (start_tid, end_tid) pairs in sorted order.
# Used to find matching key for load of non-current data.
self.noncurrent = {}
# Map oid to version, tid pair. If there is no entry, the object
# is not modified in a version.
self.version = {}
# A double-linked list is used to manage the cache. It makes
# decisions about which objects to keep and which to evict.
self.fc = FileCache(size or 10**6, self.path, self)
def open(self):
self.fc.scan(self.install)
def install(self, f, ent):
# Called by cache storage layer to insert object
o = Object.fromFile(f, ent.key, header_only=True)
if o is None:
return
oid = o.key[0]
if o.version:
self.version[oid] = o.version, o.start_tid
elif o.end_tid is None:
self.current[oid] = o.start_tid
else:
L = self.noncurrent.setdefault(oid, [])
bisect.insort_left(L, (o.start_tid, o.end_tid))
def close(self):
self.fc.close()
##
# Set the last transaction seen by the cache.
# @param tid a transaction id
# @exception ValueError attempt to set a new tid less than the current tid
def setLastTid(self, tid):
self.fc.settid(tid)
##
# Return the last transaction seen by the cache.
# @return a transaction id
# @defreturn string
def getLastTid(self):
if self.fc.tid == z64:
return None
else:
return self.fc.tid
##
# Return the current data record for oid and version.
# @param oid object id
# @param version a version string
# @return data record, serial number, tid or None if the object is not
# in the cache
# @defreturn 3-tuple: (string, string, string)
def load(self, oid, version=""):
tid = None
if version:
p = self.version.get(oid)
if p is None:
return None
elif p[0] == version:
tid = p[1]
# Otherwise, we know the cache has version data but not
# for the requested version. Thus, we know it is safe
# to return the non-version data from the cache.
if tid is None:
tid = self.current.get(oid)
if tid is None:
self._trace(0x20, oid, version)
return None
o = self.fc.access((oid, tid))
if o is None:
return None
self._trace(0x22, oid, version, o.start_tid, o.end_tid, len(o.data))
return o.data, tid, o.version
##
# Return a non-current revision of oid that was current before tid.
# @param oid object id
# @param tid id of transaction that wrote next revision of oid
# @return data record, serial number, start tid, and end tid
# @defreturn 4-tuple: (string, string, string, string)
def loadBefore(self, oid, tid):
L = self.noncurrent.get(oid)
if L is None:
self._trace(0x24, oid, tid)
return None
# A pair with None as the second element will always be less
# than any pair with the same first tid.
i = bisect.bisect_left(L, (tid, None))
# The least element left of tid was written before tid. If
# there is no element, the cache doesn't have old enough data.
if i == 0:
self._trace(0x24, oid, tid)
return
lo, hi = L[i-1]
# XXX lo should always be less than tid
if not lo < tid <= hi:
self._trace(0x24, oid, tid)
return None
o = self.fc.access((oid, lo))
self._trace(0x26, oid, tid)
return o.data, o.start_tid, o.end_tid
##
# Return the version an object is modified in or None for an
# object that is not modified in a version.
# @param oid object id
# @return name of version in which the object is modified
# @defreturn string or None
def modifiedInVersion(self, oid):
p = self.version.get(oid)
if p is None:
return None
version, tid = p
return version
##
# Store a new data record in the cache.
# @param oid object id
# @param version name of version that oid was modified in. The cache
# only stores current version data, so end_tid should
# be None.
# @param start_tid the id of the transaction that wrote this revision
# @param end_tid the id of the transaction that created the next
# revision of oid. If end_tid is None, the data is
# current.
# @param data the actual data
# @exception ValueError tried to store non-current version data
def store(self, oid, version, start_tid, end_tid, data):
# It's hard for the client to avoid storing the same object
# more than once. One case is whether the client requests
# version data that doesn't exist. It checks the cache for
# the requested version, doesn't find it, then asks the server
# for that data. The server returns the non-version data,
# which may already by in the cache.
if (oid, start_tid) in self.fc:
return
o = Object((oid, start_tid), version, data, start_tid, end_tid)
if version:
if end_tid is not None:
raise ValueError("cache only stores current version data")
if oid in self.version:
if self.version[oid] != (version, start_tid):
raise ValueError("data already exists for version %r"
% self.version[oid][0])
self.version[oid] = version, start_tid
self._trace(0x50, oid, version, start_tid, dlen=len(data))
else:
if end_tid is None:
_cur_start = self.current.get(oid)
if _cur_start:
if _cur_start != start_tid:
raise ValueError(
"already have current data for oid")
else:
return
self.current[oid] = start_tid
self._trace(0x52, oid, version, start_tid, dlen=len(data))
else:
L = self.noncurrent.setdefault(oid, [])
p = start_tid, end_tid
if p in L:
return # duplicate store
bisect.insort_left(L, (start_tid, end_tid))
self._trace(0x54, oid, version, start_tid, end_tid,
dlen=len(data))
self.fc.add(o)
##
# Mark the current data for oid as non-current. If there is no
# current data for oid, do nothing.
# @param oid object id
# @param version name of version to invalidate.
# @param tid the id of the transaction that wrote a new revision of oid
def invalidate(self, oid, version, tid):
if tid > self.fc.tid:
self.fc.settid(tid)
if oid in self.version:
self._trace(0x1A, oid, version, tid)
dllversion, dlltid = self.version[oid]
assert not version or version == dllversion, (version, dllversion)
# remove() will call unlink() to delete from self.version
self.fc.remove((oid, dlltid))
# And continue on, we must also remove any non-version data
# from the cache. This is a bit of a failure of the current
# cache consistency approach as the new tid of the version
# data gets confused with the old tid of the non-version data.
# I could sort this out, but it seems simpler to punt and
# have the cache invalidation too much for versions.
if oid not in self.current:
self._trace(0x10, oid, version, tid)
return
cur_tid = self.current.pop(oid)
# XXX Want to fetch object without marking it as accessed
o = self.fc.access((oid, cur_tid))
if o is None:
# XXX is this possible?
return None
o.end_tid = tid
self.fc.update(o)
self._trace(0x1C, oid, version, tid)
L = self.noncurrent.setdefault(oid, [])
bisect.insort_left(L, (cur_tid, tid))
##
# Return the number of object revisions in the cache.
# XXX just return len(self.cache)?
def __len__(self):
n = len(self.current) + len(self.version)
if self.noncurrent:
n += sum(map(len, self.noncurrent))
return n
##
# Generates over, version, serial triples for all objects in the
# cache. This generator is used by cache verification.
def contents(self):
# XXX May need to materialize list instead of iterating,
# depends on whether the caller may change the cache.
for o in self.fc:
oid, tid = o.key
if oid in self.version:
obj = self.fc.access(o.key)
yield oid, tid, obj.version
else:
yield oid, tid, ""
def dump(self):
from ZODB.utils import oid_repr
print "cache size", len(self)
L = list(self.contents())
L.sort()
for oid, tid, version in L:
print oid_repr(oid), oid_repr(tid), repr(version)
print "dll contents"
L = list(self.fc)
L.sort(lambda x,y:cmp(x.key, y.key))
for x in L:
end_tid = x.end_tid or z64
print oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid)
print
def _evicted(self, o):
# Called by Object o to signal its eviction
oid, tid = o.key
if o.end_tid is None:
if o.version:
del self.version[oid]
else:
del self.current[oid]
else:
# XXX Although we use bisect to keep the list sorted,
# we never expect the list to be very long. So the
# brute force approach should normally be fine.
L = self.noncurrent[oid]
L.remove((o.start_tid, o.end_tid))
def _setup_trace(self):
tfn = self.path + ".trace"
self.tracefile = None
try:
self.tracefile = open(tfn, "ab")
self._trace(0x00)
except IOError, msg:
self.tracefile = None
self.log.warning("Could not write to trace file %s: %s",
tfn, msg)
def _notrace(self, *arg, **kwargs):
pass
def _trace(self,
code, oid="", version="", tid="", end_tid=z64, dlen=0,
# The next two are just speed hacks.
time_time=time.time, struct_pack=struct.pack):
# The code argument is two hex digits; bits 0 and 7 must be zero.
# The first hex digit shows the operation, the second the outcome.
# If the second digit is in "02468" then it is a 'miss'.
# If it is in "ACE" then it is a 'hit'.
# This method has been carefully tuned to be as fast as possible.
# Note: when tracing is disabled, this method is hidden by a dummy.
if version:
code |= 0x80
encoded = (dlen + 255) & 0x7fffff00 | code
if tid is None:
tid = z64
if end_tid is None:
end_tid = z64
try:
self.tracefile.write(
struct_pack(">iiH8s8s",
time_time(),
encoded,
len(oid),
tid, end_tid) + oid)
except:
print `tid`, `end_tid`
raise
##
# An Object stores the cached data for a single object.
# <p>
# The cached data includes the actual object data, the key, and three
# data fields that describe the validity period and version of the
# object. The key contains the oid and a redundant start_tid. The
# actual size of an object is variable, depending on the size of the
# data and whether it is in a version.
# <p>
# The serialized format does not include the key, because it is stored
# in the header used by the cache's storage format.
class Object(object):
__slots__ = (# pair, object id, txn id -- something usable as a dict key
# the second part of the part is equal to start_tid below
"key",
"start_tid", # string, id of txn that wrote the data
"end_tid", # string, id of txn that wrote next revision
# or None
"version", # string, name of version
"data", # string, the actual data record for the object
"size", # total size of serialized object
)
def __init__(self, key, version, data, start_tid, end_tid):
self.key = key
self.version = version
self.data = data
self.start_tid = start_tid
self.end_tid = end_tid
# The size of a the serialized object on disk, include the
# 14-byte header, the length of data and version, and a
# copy of the 8-byte oid.
if data is not None:
self.size = 22 + len(data) + len(version)
# The serialization format uses an end tid of "\0" * 8, the least
# 8-byte string, to represent None. It isn't possible for an
# end_tid to be 0, because it must always be strictly greater
# than the start_tid.
fmt = ">8shi"
def serialize(self, f):
# Write standard form of Object to file, f.
self.serialize_header(f)
f.write(self.data)
f.write(struct.pack(">8s", self.key[0]))
def serialize_header(self, f):
s = struct.pack(self.fmt, self.end_tid or "\0" * 8,
len(self.version), len(self.data))
f.write(s)
f.write(self.version)
def fromFile(cls, f, key, header_only=False):
s = f.read(struct.calcsize(cls.fmt))
if not s:
return None
oid, start_tid = key
end_tid, vlen, dlen = struct.unpack(cls.fmt, s)
if end_tid == z64:
end_tid = None
version = f.read(vlen)
if vlen != len(version):
raise ValueError("corrupted record, version")
if header_only:
data = None
else:
data = f.read(dlen)
if dlen != len(data):
raise ValueError("corrupted record, data")
s = f.read(8)
if struct.pack(">8s", s) != oid:
raise ValueError("corrupted record, oid")
return cls((oid, start_tid), version, data, start_tid, end_tid)
fromFile = classmethod(fromFile)
def sync(f):
f.flush()
if hasattr(os, 'fsync'):
os.fsync(f.fileno())
class Entry(object):
__slots__ = (# object key -- something usable as a dict key.
'key',
# Offset from start of file to the object's data
# record; this includes all overhead bytes (status
# byte, size bytes, etc). The size of the data
# record is stored in the file near the start of the
# record, but for efficiency we also keep size in a
# dict (filemap; see later).
'offset',
)
def __init__(self, key=None, offset=None):
self.key = key
self.offset = offset
magic = "ZEC3"
OBJECT_HEADER_SIZE = 1 + 4 + 16
##
# FileCache stores a cache in a single on-disk file.
#
# On-disk cache structure
#
# The file begins with a 12-byte header. The first four bytes are the
# file's magic number - ZEC3 - indicating zeo cache version 3. The
# next eight bytes are the last transaction id.
#
# The file is a contiguous sequence of blocks. All blocks begin with
# a one-byte status indicator:
#
# 'a'
# Allocated. The block holds an object; the next 4 bytes are >I
# format total block size.
#
# 'f'
# Free. The block is free; the next 4 bytes are >I format total
# block size.
#
# '1', '2', '3', '4'
# The block is free, and consists of 1, 2, 3 or 4 bytes total.
#
# 'Z'
# File header. The file starts with a magic number, currently
# 'ZEC3' and an 8-byte transaction id.
#
# "Total" includes the status byte, and size bytes. There are no
# empty (size 0) blocks.
# XXX This needs a lot more hair.
# The structure of an allocated block is more complicated:
#
# 1 byte allocation status ('a').
# 4 bytes block size, >I format.
# 16 bytes oid + tid, string.
# size-OBJECT_HEADER_SIZE bytes, the object pickle.
# The cache's currentofs goes around the file, circularly, forever.
# It's always the starting offset of some block.
#
# When a new object is added to the cache, it's stored beginning at
# currentofs, and currentofs moves just beyond it. As many contiguous
# blocks needed to make enough room for the new object are evicted,
# starting at currentofs. Exception: if currentofs is close enough
# to the end of the file that the new object can't fit in one
# contiguous chunk, currentofs is reset to 0 first.
# Do all possible to ensure that the bytes we wrote are really on
# disk.
class FileCache(object):
def __init__(self, maxsize, fpath, parent, reuse=True):
# Maximum total of object sizes we keep in cache.
self.maxsize = maxsize
# Current total of object sizes in cache.
self.currentsize = 0
self.parent = parent
self.tid = None
# Map offset in file to pair (data record size, Entry).
# Entry is None iff the block starting at offset is free.
# filemap always contains a complete account of what's in the
# file -- study method _verify_filemap for executable checking
# of the relevant invariants. An offset is at the start of a
# block iff it's a key in filemap.
self.filemap = {}
# Map key to Entry. There's one entry for each object in the
# cache file. After
# obj = key2entry[key]
# then
# obj.key == key
# is true.
self.key2entry = {}
# Always the offset into the file of the start of a block.
# New and relocated objects are always written starting at
# currentofs.
self.currentofs = 12
self.fpath = fpath
if not reuse or not fpath or not os.path.exists(fpath):
self.new = True
if fpath:
self.f = file(fpath, 'wb+')
else:
self.f = tempfile.TemporaryFile()
# Make sure the OS really saves enough bytes for the file.
self.f.seek(self.maxsize - 1)
self.f.write('x')
self.f.truncate()
# Start with one magic header block
self.f.seek(0)
self.f.write(magic)
self.f.write(z64)
# and one free block.
self.f.write('f' + struct.pack(">I", self.maxsize - 12))
self.sync()
self.filemap[12] = self.maxsize - 12, None
else:
self.new = False
self.f = None
# Statistics: _n_adds, _n_added_bytes,
# _n_evicts, _n_evicted_bytes
self.clearStats()
# Scan the current contents of the cache file, calling install
# for each object found in the cache. This method should only
# be called once to initialize the cache from disk.
def scan(self, install):
if self.new:
return
fsize = os.path.getsize(self.fpath)
self.f = file(self.fpath, 'rb+')
_magic = self.f.read(4)
if _magic != magic:
raise ValueError("unexpected magic number: %r" % _magic)
self.tid = self.f.read(8)
# Remember the largest free block. That seems a
# decent place to start currentofs.
max_free_size = max_free_offset = 0
ofs = 12
while ofs < fsize:
self.f.seek(ofs)
ent = None
status = self.f.read(1)
if status == 'a':
size, rawkey = struct.unpack(">I16s", self.f.read(20))
key = rawkey[:8], rawkey[8:]
assert key not in self.key2entry
self.key2entry[key] = ent = Entry(key, ofs)
install(self.f, ent)
elif status == 'f':
size, = struct.unpack(">I", self.f.read(4))
elif status in '1234':
size = int(status)
else:
assert 0, status
self.filemap[ofs] = size, ent
if ent is None and size > max_free_size:
max_free_size, max_free_offset = size, ofs
ofs += size
assert ofs == fsize
if __debug__:
self._verify_filemap()
self.currentofs = max_free_offset
def clearStats(self):
self._n_adds = self._n_added_bytes = 0
self._n_evicts = self._n_evicted_bytes = 0
self._n_removes = self._n_removed_bytes = 0
self._n_accesses = 0
def getStats(self):
return (self._n_adds, self._n_added_bytes,
self._n_evicts, self._n_evicted_bytes,
self._n_removes, self._n_removed_bytes,
self._n_accesses
)
def __len__(self):
return len(self.key2entry)
def __iter__(self):
return self.key2entry.itervalues()
def __contains__(self, key):
return key in self.key2entry
def sync(self):
sync(self.f)
def close(self):
if self.f:
self.sync()
self.f.close()
self.f = None
# Evict objects as necessary to free up at least nbytes bytes,
# starting at currentofs. If currentofs is closer than nbytes to
# the end of the file, currentofs is reset to 0. The number of
# bytes actually freed may be (and probably will be) greater than
# nbytes, and is _makeroom's return value. The file is not
# altered by _makeroom. filemap is updated to reflect the
# evictions, and it's the caller's responsibilty both to fiddle
# the file, and to update filemap, to account for all the space
# freed (starting at currentofs when _makeroom returns, and
# spanning the number of bytes retured by _makeroom).
def _makeroom(self, nbytes):
assert 0 < nbytes <= self.maxsize
if self.currentofs + nbytes > self.maxsize:
self.currentofs = 12
ofs = self.currentofs
while nbytes > 0:
size, e = self.filemap.pop(ofs)
if e is not None:
self._evictobj(e, size)
ofs += size
nbytes -= size
return ofs - self.currentofs
# Write Object obj, with data, to file starting at currentofs.
# nfreebytes are already available for overwriting, and it's
# guranteed that's enough. obj.offset is changed to reflect the
# new data record position, and filemap is updated to match.
def _writeobj(self, obj, nfreebytes):
size = OBJECT_HEADER_SIZE + obj.size
assert size <= nfreebytes
excess = nfreebytes - size
# If there's any excess (which is likely), we need to record a
# free block following the end of the data record. That isn't
# expensive -- it's all a contiguous write.
if excess == 0:
extra = ''
elif excess < 5:
extra = "01234"[excess]
else:
extra = 'f' + struct.pack(">I", excess)
self.f.seek(self.currentofs)
self.f.writelines(('a',
struct.pack(">I8s8s", size,
obj.key[0], obj.key[1])))
obj.serialize(self.f)
self.f.write(extra)
e = Entry(obj.key, self.currentofs)
self.key2entry[obj.key] = e
self.filemap[self.currentofs] = size, e
self.currentofs += size
if excess:
# We need to record the free block in filemap, but there's
# no need to advance currentofs beyond it. Instead it
# gives some breathing room for the next object to get
# written.
self.filemap[self.currentofs] = excess, None
def add(self, object):
size = OBJECT_HEADER_SIZE + object.size
if size > self.maxsize:
return
assert size <= self.maxsize
assert object.key not in self.key2entry
assert len(object.key[0]) == 8
assert len(object.key[1]) == 8
self._n_adds += 1
self._n_added_bytes += size
available = self._makeroom(size)
self._writeobj(object, available)
def _verify_filemap(self, display=False):
a = 12
f = self.f
while a < self.maxsize:
f.seek(a)
status = f.read(1)
if status in 'af':
size, = struct.unpack(">I", f.read(4))
else:
size = int(status)
if display:
if a == self.currentofs:
print '*****',
print "%c%d" % (status, size),
size2, obj = self.filemap[a]
assert size == size2
assert (obj is not None) == (status == 'a')
if obj is not None:
assert obj.offset == a
assert self.key2entry[obj.key] is obj
a += size
if display:
print
assert a == self.maxsize
def _evictobj(self, e, size):
self._n_evicts += 1
self._n_evicted_bytes += size
# Load the object header into memory so we know how to
# update the parent's in-memory data structures.
self.f.seek(e.offset + OBJECT_HEADER_SIZE)
o = Object.fromFile(self.f, e.key, header_only=True)
self.parent._evicted(o)
##
# Return object for key or None if not in cache.
def access(self, key):
self._n_accesses += 1
e = self.key2entry.get(key)
if e is None:
return None
offset = e.offset
size, e2 = self.filemap[offset]
assert e is e2
self.f.seek(offset + OBJECT_HEADER_SIZE)
return Object.fromFile(self.f, key)
##
# Remove object for key from cache, if present.
def remove(self, key):
# If an object is being explicitly removed, we need to load
# its header into memory and write a free block marker to the
# disk where the object was stored. We need to load the
# header to update the in-memory data structures held by
# ClientCache.
# XXX Or we could just keep the header in memory at all times.
e = self.key2entry.get(key)
if e is None:
return
offset = e.offset
size, e2 = self.filemap[offset]
self.f.seek(offset + OBJECT_HEADER_SIZE)
o = Object.fromFile(self.f, key, header_only=True)
self.f.seek(offset + OBJECT_HEADER_SIZE)
self.f.write('f')
self.f.flush()
self.parent._evicted(o)
self.filemap[offset] = size, None
##
# Update on-disk representation of obj.
#
# This method should be called when the object header is modified.
def update(self, obj):
e = self.key2entry[obj.key]
self.f.seek(e.offset + OBJECT_HEADER_SIZE)
obj.serialize_header(self.f)
def settid(self, tid):
if self.tid is not None:
if tid < self.tid:
raise ValueError(
"new last tid must be greater that previous one")
self.tid = tid
self.f.seek(4)
self.f.write(tid)
self.f.flush()
...@@ -128,15 +128,21 @@ def main(): ...@@ -128,15 +128,21 @@ def main():
# Read file, gathering statistics, and printing each record if verbose # Read file, gathering statistics, and printing each record if verbose
rt0 = time.time() rt0 = time.time()
# bycode -- map code to count of occurrences
bycode = {} bycode = {}
# records -- number of records
records = 0 records = 0
# version -- number of records with versions
versions = 0 versions = 0
t0 = te = None t0 = te = None
# datarecords -- number of records with dlen set
datarecords = 0 datarecords = 0
datasize = 0L datasize = 0L
file0 = file1 = 0 # oids -- maps oid to number of times it was loaded
oids = {} oids = {}
# bysize -- maps data size to number of loads
bysize = {} bysize = {}
# bysize -- maps data size to number of writes
bysizew = {} bysizew = {}
total_loads = 0 total_loads = 0
byinterval = {} byinterval = {}
...@@ -157,12 +163,12 @@ def main(): ...@@ -157,12 +163,12 @@ def main():
if not quiet: if not quiet:
print "Skipping 8 bytes at offset", offset-8 print "Skipping 8 bytes at offset", offset-8
continue continue
r = f_read(10) r = f_read(18)
if len(r) < 10: if len(r) < 10:
break break
offset += 10 offset += 10
records += 1 records += 1
oidlen, serial = struct_unpack(">H8s", r) oidlen, start_tid, end_tid = struct_unpack(">H8s8s", r)
oid = f_read(oidlen) oid = f_read(oidlen)
if len(oid) != oidlen: if len(oid) != oidlen:
break break
...@@ -187,11 +193,6 @@ def main(): ...@@ -187,11 +193,6 @@ def main():
if code & 0x80: if code & 0x80:
version = 'V' version = 'V'
versions += 1 versions += 1
current = code & 1
if current:
file1 += 1
else:
file0 += 1
code = code & 0x7e code = code & 0x7e
bycode[code] = bycode.get(code, 0) + 1 bycode[code] = bycode.get(code, 0) + 1
byinterval[code] = byinterval.get(code, 0) + 1 byinterval[code] = byinterval.get(code, 0) + 1
...@@ -199,22 +200,23 @@ def main(): ...@@ -199,22 +200,23 @@ def main():
if code & 0x70 == 0x20: # All loads if code & 0x70 == 0x20: # All loads
bysize[dlen] = d = bysize.get(dlen) or {} bysize[dlen] = d = bysize.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1 d[oid] = d.get(oid, 0) + 1
elif code == 0x3A: # Update elif code & 0x70 == 0x50: # All stores
bysizew[dlen] = d = bysizew.get(dlen) or {} bysizew[dlen] = d = bysizew.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1 d[oid] = d.get(oid, 0) + 1
if verbose: if verbose:
print "%s %d %02x %s %016x %1s %s" % ( print "%s %d %02x %s %016x %016x %1s %s" % (
time.ctime(ts)[4:-5], time.ctime(ts)[4:-5],
current, current,
code, code,
oid_repr(oid), oid_repr(oid),
U64(serial), U64(start_tid),
U64(end_tid),
version, version,
dlen and str(dlen) or "") dlen and str(dlen) or "")
if code & 0x70 == 0x20: if code & 0x70 == 0x20:
oids[oid] = oids.get(oid, 0) + 1 oids[oid] = oids.get(oid, 0) + 1
total_loads += 1 total_loads += 1
if code in (0x00, 0x70): if code == 0x00:
if not quiet: if not quiet:
dumpbyinterval(byinterval, h0, he) dumpbyinterval(byinterval, h0, he)
byinterval = {} byinterval = {}
...@@ -222,10 +224,7 @@ def main(): ...@@ -222,10 +224,7 @@ def main():
h0 = he = ts h0 = he = ts
if not quiet: if not quiet:
print time.ctime(ts)[4:-5], print time.ctime(ts)[4:-5],
if code == 0x00: print '='*20, "Restart", '='*20
print '='*20, "Restart", '='*20
else:
print '-'*20, "Flip->%d" % current, '-'*20
except KeyboardInterrupt: except KeyboardInterrupt:
print "\nInterrupted. Stats so far:\n" print "\nInterrupted. Stats so far:\n"
...@@ -248,8 +247,6 @@ def main(): ...@@ -248,8 +247,6 @@ def main():
print "First time: %s" % time.ctime(t0) print "First time: %s" % time.ctime(t0)
print "Last time: %s" % time.ctime(te) print "Last time: %s" % time.ctime(te)
print "Duration: %s seconds" % addcommas(te-t0) print "Duration: %s seconds" % addcommas(te-t0)
print "File stats: %s in file 0; %s in file 1" % (
addcommas(file0), addcommas(file1))
print "Data recs: %s (%.1f%%), average size %.1f KB" % ( print "Data recs: %s (%.1f%%), average size %.1f KB" % (
addcommas(datarecords), addcommas(datarecords),
100.0 * datarecords / records, 100.0 * datarecords / records,
...@@ -314,7 +311,7 @@ def dumpbyinterval(byinterval, h0, he): ...@@ -314,7 +311,7 @@ def dumpbyinterval(byinterval, h0, he):
if code & 0x70 == 0x20: if code & 0x70 == 0x20:
n = byinterval[code] n = byinterval[code]
loads += n loads += n
if code in (0x2A, 0x2C, 0x2E): if code in (0x22, 0x26):
hits += n hits += n
if not loads: if not loads:
return return
...@@ -333,7 +330,7 @@ def hitrate(bycode): ...@@ -333,7 +330,7 @@ def hitrate(bycode):
if code & 0x70 == 0x20: if code & 0x70 == 0x20:
n = bycode[code] n = bycode[code]
loads += n loads += n
if code in (0x2A, 0x2C, 0x2E): if code in (0x22, 0x26):
hits += n hits += n
if loads: if loads:
return 100.0 * hits / loads return 100.0 * hits / loads
...@@ -376,31 +373,18 @@ explain = { ...@@ -376,31 +373,18 @@ explain = {
0x00: "_setup_trace (initialization)", 0x00: "_setup_trace (initialization)",
0x10: "invalidate (miss)", 0x10: "invalidate (miss)",
0x1A: "invalidate (hit, version, writing 'n')", 0x1A: "invalidate (hit, version)",
0x1C: "invalidate (hit, writing 'i')", 0x1C: "invalidate (hit, saving non-current)",
0x20: "load (miss)", 0x20: "load (miss)",
0x22: "load (miss, version, status 'n')", 0x22: "load (hit)",
0x24: "load (miss, deleting index entry)", 0x24: "load (non-current, miss)",
0x26: "load (miss, no non-version data)", 0x26: "load (non-current, hit)",
0x28: "load (miss, version mismatch, no non-version data)",
0x2A: "load (hit, returning non-version data)",
0x2C: "load (hit, version mismatch, returning non-version data)",
0x2E: "load (hit, returning version data)",
0x3A: "update", 0x50: "store (version)",
0x52: "store (current, non-version)",
0x54: "store (non-current)",
0x40: "modifiedInVersion (miss)",
0x4A: "modifiedInVersion (hit, return None, status 'n')",
0x4C: "modifiedInVersion (hit, return '')",
0x4E: "modifiedInVersion (hit, return version)",
0x5A: "store (non-version data present)",
0x5C: "store (only version data present)",
0x6A: "_copytocurrent",
0x70: "checkSize (cache flip)",
} }
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -42,7 +42,7 @@ class TransUndoStorageWithCache: ...@@ -42,7 +42,7 @@ class TransUndoStorageWithCache:
t.note('undo1') t.note('undo1')
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t) tid, oids = self._storage.transactionalUndo(tid, t)
# Make sure this doesn't load invalid data into the cache # Make sure this doesn't load invalid data into the cache
self._storage.load(oid, '') self._storage.load(oid, '')
......
...@@ -71,7 +71,7 @@ class WorkerThread(TestThread): ...@@ -71,7 +71,7 @@ class WorkerThread(TestThread):
# self.storage.tpc_vote(self.trans) # self.storage.tpc_vote(self.trans)
rpc = self.storage._server.rpc rpc = self.storage._server.rpc
msgid = rpc._deferred_call('vote', self.storage._serial) msgid = rpc._deferred_call('vote', id(self.trans))
self.ready.set() self.ready.set()
rpc._deferred_wait(msgid) rpc._deferred_wait(msgid)
self.storage._check_serials() self.storage._check_serials()
...@@ -103,6 +103,51 @@ class CommitLockTests: ...@@ -103,6 +103,51 @@ class CommitLockTests:
self._storage.store(oid, ZERO, zodb_pickle(MinPO(1)), '', txn) self._storage.store(oid, ZERO, zodb_pickle(MinPO(1)), '', txn)
return oid, txn return oid, txn
def _begin_threads(self):
# Start a second transaction on a different connection without
# blocking the test thread. Returns only after each thread has
# set it's ready event.
self._storages = []
self._threads = []
for i in range(self.NUM_CLIENTS):
storage = self._duplicate_client()
txn = Transaction()
tid = self._get_timestamp()
t = WorkerThread(self, storage, txn)
self._threads.append(t)
t.start()
t.ready.wait()
# Close on the connections abnormally to test server response
if i == 0:
storage.close()
else:
self._storages.append((storage, txn))
def _finish_threads(self):
for t in self._threads:
t.cleanup()
def _duplicate_client(self):
"Open another ClientStorage to the same server."
# XXX argh it's hard to find the actual address
# The rpc mgr addr attribute is a list. Each element in the
# list is a socket domain (AF_INET, AF_UNIX, etc.) and an
# address.
addr = self._storage._addr
new = ZEO.ClientStorage.ClientStorage(addr, wait=1)
new.registerDB(DummyDB(), None)
return new
def _get_timestamp(self):
t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,))
return `t`
class CommitLockVoteTests(CommitLockTests):
def checkCommitLockVoteFinish(self): def checkCommitLockVoteFinish(self):
oid, txn = self._start_txn() oid, txn = self._start_txn()
self._storage.tpc_vote(txn) self._storage.tpc_vote(txn)
...@@ -141,15 +186,16 @@ class CommitLockTests: ...@@ -141,15 +186,16 @@ class CommitLockTests:
self._finish_threads() self._finish_threads()
self._cleanup() self._cleanup()
class CommitLockUndoTests(CommitLockTests):
def _get_trans_id(self): def _get_trans_id(self):
self._dostore() self._dostore()
L = self._storage.undoInfo() L = self._storage.undoInfo()
return L[0]['id'] return L[0]['id']
def _begin_undo(self, trans_id): def _begin_undo(self, trans_id, txn):
rpc = self._storage._server.rpc rpc = self._storage._server.rpc
return rpc._deferred_call('transactionalUndo', trans_id, return rpc._deferred_call('transactionalUndo', trans_id, id(txn))
self._storage._serial)
def _finish_undo(self, msgid): def _finish_undo(self, msgid):
return self._storage._server.rpc._deferred_wait(msgid) return self._storage._server.rpc._deferred_wait(msgid)
...@@ -157,7 +203,7 @@ class CommitLockTests: ...@@ -157,7 +203,7 @@ class CommitLockTests:
def checkCommitLockUndoFinish(self): def checkCommitLockUndoFinish(self):
trans_id = self._get_trans_id() trans_id = self._get_trans_id()
oid, txn = self._start_txn() oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id) msgid = self._begin_undo(trans_id, txn)
self._begin_threads() self._begin_threads()
...@@ -174,7 +220,7 @@ class CommitLockTests: ...@@ -174,7 +220,7 @@ class CommitLockTests:
def checkCommitLockUndoAbort(self): def checkCommitLockUndoAbort(self):
trans_id = self._get_trans_id() trans_id = self._get_trans_id()
oid, txn = self._start_txn() oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id) msgid = self._begin_undo(trans_id, txn)
self._begin_threads() self._begin_threads()
...@@ -190,7 +236,7 @@ class CommitLockTests: ...@@ -190,7 +236,7 @@ class CommitLockTests:
def checkCommitLockUndoClose(self): def checkCommitLockUndoClose(self):
trans_id = self._get_trans_id() trans_id = self._get_trans_id()
oid, txn = self._start_txn() oid, txn = self._start_txn()
msgid = self._begin_undo(trans_id) msgid = self._begin_undo(trans_id, txn)
self._begin_threads() self._begin_threads()
...@@ -201,46 +247,3 @@ class CommitLockTests: ...@@ -201,46 +247,3 @@ class CommitLockTests:
self._finish_threads() self._finish_threads()
self._cleanup() self._cleanup()
def _begin_threads(self):
# Start a second transaction on a different connection without
# blocking the test thread. Returns only after each thread has
# set it's ready event.
self._storages = []
self._threads = []
for i in range(self.NUM_CLIENTS):
storage = self._duplicate_client()
txn = Transaction()
tid = self._get_timestamp()
t = WorkerThread(self, storage, txn)
self._threads.append(t)
t.start()
t.ready.wait()
# Close on the connections abnormally to test server response
if i == 0:
storage.close()
else:
self._storages.append((storage, txn))
def _finish_threads(self):
for t in self._threads:
t.cleanup()
def _duplicate_client(self):
"Open another ClientStorage to the same server."
# XXX argh it's hard to find the actual address
# The rpc mgr addr attribute is a list. Each element in the
# list is a socket domain (AF_INET, AF_UNIX, etc.) and an
# address.
addr = self._storage._addr
new = ZEO.ClientStorage.ClientStorage(addr, wait=1)
new.registerDB(DummyDB(), None)
return new
def _get_timestamp(self):
t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,))
return `t`
...@@ -109,7 +109,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -109,7 +109,7 @@ class CommonSetupTearDown(StorageTestBase):
os.waitpid(pid, 0) os.waitpid(pid, 0)
for c in self.caches: for c in self.caches:
for i in 0, 1: for i in 0, 1:
path = "c1-%s-%d.zec" % (c, i) path = "%s-%s.zec" % (c, "1")
# On Windows before 2.3, we don't have a way to wait for # On Windows before 2.3, we don't have a way to wait for
# the spawned server(s) to close, and they inherited # the spawned server(s) to close, and they inherited
# file descriptors for our open files. So long as those # file descriptors for our open files. So long as those
...@@ -584,6 +584,9 @@ class InvqTests(CommonSetupTearDown): ...@@ -584,6 +584,9 @@ class InvqTests(CommonSetupTearDown):
revid = self._dostore(oid) revid = self._dostore(oid)
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
# sync() is needed to prevent invalidation for oid from arriving
# in the middle of the load() call.
perstorage.sync()
perstorage.load(oid, '') perstorage.load(oid, '')
perstorage.close() perstorage.close()
...@@ -853,7 +856,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -853,7 +856,7 @@ class TimeoutTests(CommonSetupTearDown):
unless = self.failUnless unless = self.failUnless
self._storage = storage = self.openClientStorage() self._storage = storage = self.openClientStorage()
# Assert that the zeo cache is empty # Assert that the zeo cache is empty
unless(not storage._cache._index) unless(not list(storage._cache.contents()))
# Create the object # Create the object
oid = storage.new_oid() oid = storage.new_oid()
obj = MinPO(7) obj = MinPO(7)
...@@ -872,7 +875,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -872,7 +875,7 @@ class TimeoutTests(CommonSetupTearDown):
# We expect finish to fail # We expect finish to fail
raises(ClientDisconnected, storage.tpc_finish, t) raises(ClientDisconnected, storage.tpc_finish, t)
# The cache should still be empty # The cache should still be empty
unless(not storage._cache._index) unless(not list(storage._cache.contents()))
# Load should fail since the object should not be in either the cache # Load should fail since the object should not be in either the cache
# or the server. # or the server.
raises(KeyError, storage.load, oid, '') raises(KeyError, storage.load, oid, '')
...@@ -883,7 +886,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -883,7 +886,7 @@ class TimeoutTests(CommonSetupTearDown):
unless = self.failUnless unless = self.failUnless
self._storage = storage = self.openClientStorage() self._storage = storage = self.openClientStorage()
# Assert that the zeo cache is empty # Assert that the zeo cache is empty
unless(not storage._cache._index) unless(not list(storage._cache.contents()))
# Create the object # Create the object
oid = storage.new_oid() oid = storage.new_oid()
obj = MinPO(7) obj = MinPO(7)
......
...@@ -39,7 +39,23 @@ from ZODB.POSException \ ...@@ -39,7 +39,23 @@ from ZODB.POSException \
# thought they added (i.e., the keys for which get_transaction().commit() # thought they added (i.e., the keys for which get_transaction().commit()
# did not raise any exception). # did not raise any exception).
class StressThread(TestThread): class FailableThread(TestThread):
# mixin class
# subclass must provide
# - self.stop attribute (an event)
# - self._testrun() method
def testrun(self):
try:
self._testrun()
except:
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
raise
class StressThread(FailableThread):
# Append integers startnum, startnum + step, startnum + 2*step, ... # Append integers startnum, startnum + step, startnum + 2*step, ...
# to 'tree' until Event stop is set. If sleep is given, sleep # to 'tree' until Event stop is set. If sleep is given, sleep
...@@ -57,7 +73,7 @@ class StressThread(TestThread): ...@@ -57,7 +73,7 @@ class StressThread(TestThread):
self.added_keys = [] self.added_keys = []
self.commitdict = commitdict self.commitdict = commitdict
def testrun(self): def _testrun(self):
cn = self.db.open() cn = self.db.open()
while not self.stop.isSet(): while not self.stop.isSet():
try: try:
...@@ -87,7 +103,7 @@ class StressThread(TestThread): ...@@ -87,7 +103,7 @@ class StressThread(TestThread):
key += self.step key += self.step
cn.close() cn.close()
class LargeUpdatesThread(TestThread): class LargeUpdatesThread(FailableThread):
# A thread that performs a lot of updates. It attempts to modify # A thread that performs a lot of updates. It attempts to modify
# more than 25 objects so that it can test code that runs vote # more than 25 objects so that it can test code that runs vote
...@@ -106,6 +122,15 @@ class LargeUpdatesThread(TestThread): ...@@ -106,6 +122,15 @@ class LargeUpdatesThread(TestThread):
self.commitdict = commitdict self.commitdict = commitdict
def testrun(self): def testrun(self):
try:
self._testrun()
except:
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
raise
def _testrun(self):
cn = self.db.open() cn = self.db.open()
while not self.stop.isSet(): while not self.stop.isSet():
try: try:
...@@ -162,7 +187,7 @@ class LargeUpdatesThread(TestThread): ...@@ -162,7 +187,7 @@ class LargeUpdatesThread(TestThread):
self.added_keys = keys_added.keys() self.added_keys = keys_added.keys()
cn.close() cn.close()
class VersionStressThread(TestThread): class VersionStressThread(FailableThread):
def __init__(self, testcase, db, stop, threadnum, commitdict, startnum, def __init__(self, testcase, db, stop, threadnum, commitdict, startnum,
step=2, sleep=None): step=2, sleep=None):
...@@ -177,6 +202,15 @@ class VersionStressThread(TestThread): ...@@ -177,6 +202,15 @@ class VersionStressThread(TestThread):
self.commitdict = commitdict self.commitdict = commitdict
def testrun(self): def testrun(self):
try:
self._testrun()
except:
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
raise
def _testrun(self):
commit = 0 commit = 0
key = self.startnum key = self.startnum
while not self.stop.isSet(): while not self.stop.isSet():
...@@ -302,7 +336,10 @@ class InvalidationTests: ...@@ -302,7 +336,10 @@ class InvalidationTests:
delay = self.MINTIME delay = self.MINTIME
start = time.time() start = time.time()
while time.time() - start <= self.MAXTIME: while time.time() - start <= self.MAXTIME:
time.sleep(delay) stop.wait(delay)
if stop.isSet():
# Some thread failed. Stop right now.
break
delay = 2.0 delay = 2.0
if len(commitdict) >= len(threads): if len(commitdict) >= len(threads):
break break
...@@ -406,6 +443,7 @@ class InvalidationTests: ...@@ -406,6 +443,7 @@ class InvalidationTests:
t1 = VersionStressThread(self, db1, stop, 1, cd, 1, 3) t1 = VersionStressThread(self, db1, stop, 1, cd, 1, 3)
t2 = VersionStressThread(self, db2, stop, 2, cd, 2, 3, 0.01) t2 = VersionStressThread(self, db2, stop, 2, cd, 2, 3, 0.01)
t3 = VersionStressThread(self, db2, stop, 3, cd, 3, 3, 0.01) t3 = VersionStressThread(self, db2, stop, 3, cd, 3, 3, 0.01)
## t1 = VersionStressThread(self, db2, stop, 3, cd, 1, 3, 0.01)
self.go(stop, cd, t1, t2, t3) self.go(stop, cd, t1, t2, t3)
cn.sync() cn.sync()
......
...@@ -28,7 +28,6 @@ from ZEO.StorageServer import StorageServer ...@@ -28,7 +28,6 @@ from ZEO.StorageServer import StorageServer
from ZEO.tests.ConnectionTests import CommonSetupTearDown from ZEO.tests.ConnectionTests import CommonSetupTearDown
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
from ZODB.tests.StorageTestBase import removefs
class AuthTest(CommonSetupTearDown): class AuthTest(CommonSetupTearDown):
__super_getServerConfig = CommonSetupTearDown.getServerConfig __super_getServerConfig = CommonSetupTearDown.getServerConfig
......
...@@ -101,20 +101,12 @@ class GenericTests( ...@@ -101,20 +101,12 @@ class GenericTests(
StorageTestBase.StorageTestBase, StorageTestBase.StorageTestBase,
# ZODB test mixin classes (in the same order as imported) # ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage, BasicStorage.BasicStorage,
VersionStorage.VersionStorage,
TransactionalUndoStorage.TransactionalUndoStorage,
TransactionalUndoVersionStorage.TransactionalUndoVersionStorage,
PackableStorage.PackableStorage, PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage, Synchronization.SynchronizedStorage,
ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage,
RevisionStorage.RevisionStorage,
MTStorage.MTStorage, MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage, ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported) # ZEO test mixin classes (in the same order as imported)
Cache.StorageWithCache, CommitLockTests.CommitLockVoteTests,
Cache.TransUndoStorageWithCache,
CommitLockTests.CommitLockTests,
ThreadTests.ThreadTests, ThreadTests.ThreadTests,
# Locally defined (see above) # Locally defined (see above)
MiscZEOTests MiscZEOTests
...@@ -167,8 +159,22 @@ class GenericTests( ...@@ -167,8 +159,22 @@ class GenericTests(
key = '%s:%s' % (self._storage._storage, self._storage._server_addr) key = '%s:%s' % (self._storage._storage, self._storage._server_addr)
self.assertEqual(self._storage.sortKey(), key) self.assertEqual(self._storage.sortKey(), key)
class FullGenericTests(
GenericTests,
Cache.StorageWithCache,
Cache.TransUndoStorageWithCache,
CommitLockTests.CommitLockUndoTests,
ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage,
PackableStorage.PackableUndoStorage,
RevisionStorage.RevisionStorage,
TransactionalUndoStorage.TransactionalUndoStorage,
TransactionalUndoVersionStorage.TransactionalUndoVersionStorage,
VersionStorage.VersionStorage,
):
"""Extend GenericTests with tests that MappingStorage can't pass."""
class FileStorageTests(GenericTests): class FileStorageTests(FullGenericTests):
"""Test ZEO backed by a FileStorage.""" """Test ZEO backed by a FileStorage."""
level = 2 level = 2
...@@ -180,7 +186,7 @@ class FileStorageTests(GenericTests): ...@@ -180,7 +186,7 @@ class FileStorageTests(GenericTests):
</filestorage> </filestorage>
""" % filename """ % filename
class BDBTests(FileStorageTests): class BDBTests(FullGenericTests):
"""ZEO backed by a Berkeley full storage.""" """ZEO backed by a Berkeley full storage."""
level = 2 level = 2
...@@ -192,67 +198,14 @@ class BDBTests(FileStorageTests): ...@@ -192,67 +198,14 @@ class BDBTests(FileStorageTests):
</fullstorage> </fullstorage>
""" % self._envdir """ % self._envdir
class MappingStorageTests(FileStorageTests): class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage.""" """ZEO backed by a Mapping storage."""
def getConfig(self): def getConfig(self):
return """<mappingstorage 1/>""" return """<mappingstorage 1/>"""
# Tests which MappingStorage can't possibly pass, because it doesn't # XXX There are still a bunch of tests that fail. Are there
# support versions or undo. # still test classes in GenericTests that shouldn't be there?
def checkVersions(self): pass
def checkVersionedStoreAndLoad(self): pass
def checkVersionedLoadErrors(self): pass
def checkVersionLock(self): pass
def checkVersionEmpty(self): pass
def checkUndoUnresolvable(self): pass
def checkUndoInvalidation(self): pass
def checkUndoInVersion(self): pass
def checkUndoCreationBranch2(self): pass
def checkUndoCreationBranch1(self): pass
def checkUndoConflictResolution(self): pass
def checkUndoCommitVersion(self): pass
def checkUndoAbortVersion(self): pass
def checkPackUndoLog(self): pass
def checkUndoLogMetadata(self): pass
def checkTwoObjectUndoAtOnce(self): pass
def checkTwoObjectUndoAgain(self): pass
def checkTwoObjectUndo(self): pass
def checkTransactionalUndoAfterPackWithObjectUnlinkFromRoot(self): pass
def checkTransactionalUndoAfterPack(self): pass
def checkSimpleTransactionalUndo(self): pass
def checkReadMethods(self): pass
def checkPackAfterUndoDeletion(self): pass
def checkPackAfterUndoManyTimes(self): pass
def checkPackVersions(self): pass
def checkPackUnlinkedFromRoot(self): pass
def checkPackOnlyOneObject(self): pass
def checkPackJustOldRevisions(self): pass
def checkPackEmptyStorage(self): pass
def checkPackAllRevisions(self): pass
def checkPackVersionsInPast(self): pass
def checkPackVersionReachable(self): pass
def checkNotUndoable(self): pass
def checkNewSerialOnCommitVersionToVersion(self): pass
def checkModifyAfterAbortVersion(self): pass
def checkLoadSerial(self): pass
def checkCreateObjectInVersionWithAbort(self): pass
def checkCommitVersionSerialno(self): pass
def checkCommitVersionInvalidation(self): pass
def checkCommitToOtherVersion(self): pass
def checkCommitToNonVersion(self): pass
def checkCommitLockUndoFinish(self): pass
def checkCommitLockUndoClose(self): pass
def checkCommitLockUndoAbort(self): pass
def checkCommitEmptyVersionInvalidation(self): pass
def checkCreationUndoneGetSerial(self): pass
def checkAbortVersionSerialno(self): pass
def checkAbortVersionInvalidation(self): pass
def checkAbortVersionErrors(self): pass
def checkAbortVersion(self): pass
def checkAbortOneVersionCommitTheOther(self): pass
def checkResolve(self): pass
def check4ExtStorageThread(self): pass
test_classes = [FileStorageTests, MappingStorageTests] test_classes = [FileStorageTests, MappingStorageTests]
......
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Basic unit tests for a multi-version client cache."""
import os
import tempfile
import unittest
import ZEO.cache
from ZODB.utils import p64
n1 = p64(1)
n2 = p64(2)
n3 = p64(3)
n4 = p64(4)
n5 = p64(5)
class CacheTests(unittest.TestCase):
def setUp(self):
self.cache = ZEO.cache.ClientCache()
self.cache.open()
def tearDown(self):
if self.cache.path:
os.remove(self.cache.path)
def testLastTid(self):
self.assertEqual(self.cache.getLastTid(), None)
self.cache.setLastTid(n2)
self.assertEqual(self.cache.getLastTid(), n2)
self.cache.invalidate(None, "", n1)
self.assertEqual(self.cache.getLastTid(), n2)
self.cache.invalidate(None, "", n3)
self.assertEqual(self.cache.getLastTid(), n3)
self.assertRaises(ValueError, self.cache.setLastTid, n2)
def testLoad(self):
data1 = "data for n1"
self.assertEqual(self.cache.load(n1, ""), None)
self.assertEqual(self.cache.load(n1, "version"), None)
self.cache.store(n1, "", n3, None, data1)
self.assertEqual(self.cache.load(n1, ""), (data1, n3, ""))
# The cache doesn't know whether version exists, because it
# only has non-version data.
self.assertEqual(self.cache.load(n1, "version"), None)
self.assertEqual(self.cache.modifiedInVersion(n1), None)
def testInvalidate(self):
data1 = "data for n1"
self.cache.store(n1, "", n3, None, data1)
self.cache.invalidate(n1, "", n4)
self.cache.invalidate(n2, "", n2)
self.assertEqual(self.cache.load(n1, ""), None)
self.assertEqual(self.cache.loadBefore(n1, n4),
(data1, n3, n4))
def testVersion(self):
data1 = "data for n1"
data1v = "data for n1 in version"
self.cache.store(n1, "version", n3, None, data1v)
self.assertEqual(self.cache.load(n1, ""), None)
self.assertEqual(self.cache.load(n1, "version"),
(data1v, n3, "version"))
self.assertEqual(self.cache.load(n1, "random"), None)
self.assertEqual(self.cache.modifiedInVersion(n1), "version")
self.cache.invalidate(n1, "version", n4)
self.assertEqual(self.cache.load(n1, "version"), None)
def testNonCurrent(self):
data1 = "data for n1"
data2 = "data for n2"
self.cache.store(n1, "", n4, None, data1)
self.cache.store(n1, "", n2, n3, data2)
# can't say anything about state before n2
self.assertEqual(self.cache.loadBefore(n1, n2), None)
# n3 is the upper bound of non-current record n2
self.assertEqual(self.cache.loadBefore(n1, n3), (data2, n2, n3))
# no data for between n2 and n3
self.assertEqual(self.cache.loadBefore(n1, n4), None)
self.cache.invalidate(n1, "", n5)
self.assertEqual(self.cache.loadBefore(n1, n5), (data1, n4, n5))
self.assertEqual(self.cache.loadBefore(n2, n4), None)
def testException(self):
self.assertRaises(ValueError,
self.cache.store,
n1, "version", n2, n3, "data")
self.cache.store(n1, "", n2, None, "data")
self.assertRaises(ValueError,
self.cache.store,
n1, "", n3, None, "data")
def testEviction(self):
# Manually override the current maxsize
maxsize = self.cache.size = self.cache.fc.maxsize = 3395 # 1245
self.cache.fc = ZEO.cache.FileCache(3395, None, self.cache)
# Trivial test of eviction code. Doesn't test non-current
# eviction.
data = ["z" * i for i in range(100)]
for i in range(50):
n = p64(i)
self.cache.store(n, "", n, None, data[i])
self.assertEquals(len(self.cache), i + 1)
self.assert_(self.cache.fc.currentsize < maxsize)
# The cache now uses 1225 bytes. The next insert
# should delete some objects.
n = p64(50)
self.cache.store(n, "", n, None, data[51])
self.assert_(len(self.cache) < 51)
self.assert_(self.cache.fc.currentsize <= maxsize)
# XXX Need to make sure eviction of non-current data
# and of version data are handled correctly.
def testSerialization(self):
self.cache.store(n1, "", n2, None, "data for n1")
self.cache.store(n2, "version", n2, None, "version data for n2")
self.cache.store(n3, "", n3, n4, "non-current data for n3")
self.cache.store(n3, "", n4, n5, "more non-current data for n3")
path = tempfile.mktemp()
# Copy data from self.cache into path, reaching into the cache
# guts to make the copy.
dst = open(path, "wb+")
src = self.cache.fc.f
src.seek(0)
dst.write(src.read(self.cache.fc.maxsize))
dst.close()
copy = ZEO.cache.ClientCache(path)
copy.open()
# Verify that internals of both objects are the same.
# Could also test that external API produces the same results.
eq = self.assertEqual
eq(copy.tid, self.cache.tid)
eq(len(copy), len(self.cache))
eq(copy.version, self.cache.version)
eq(copy.current, self.cache.current)
eq(copy.noncurrent, self.cache.noncurrent)
def test_suite():
return unittest.makeSuite(CacheTests)
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
############################################################################## ##############################################################################
"""Handy standard storage machinery """Handy standard storage machinery
$Id: BaseStorage.py,v 1.38 2003/12/23 14:37:13 jeremy Exp $ $Id: BaseStorage.py,v 1.39 2003/12/24 16:02:00 jeremy Exp $
""" """
import cPickle import cPickle
import threading import threading
...@@ -32,7 +32,6 @@ from ZODB.utils import z64 ...@@ -32,7 +32,6 @@ from ZODB.utils import z64
class BaseStorage(UndoLogCompatible): class BaseStorage(UndoLogCompatible):
_transaction=None # Transaction that is being committed _transaction=None # Transaction that is being committed
_serial=z64 # Transaction serial number
_tstatus=' ' # Transaction status, used for copying data _tstatus=' ' # Transaction status, used for copying data
_is_read_only = 0 _is_read_only = 0
...@@ -51,7 +50,7 @@ class BaseStorage(UndoLogCompatible): ...@@ -51,7 +50,7 @@ class BaseStorage(UndoLogCompatible):
t=time.time() t=time.time()
t=self._ts=apply(TimeStamp,(time.gmtime(t)[:5]+(t%60,))) t=self._ts=apply(TimeStamp,(time.gmtime(t)[:5]+(t%60,)))
self._serial=`t` self._tid = `t`
if base is None: if base is None:
self._oid='\0\0\0\0\0\0\0\0' self._oid='\0\0\0\0\0\0\0\0'
else: else:
...@@ -60,16 +59,19 @@ class BaseStorage(UndoLogCompatible): ...@@ -60,16 +59,19 @@ class BaseStorage(UndoLogCompatible):
def abortVersion(self, src, transaction): def abortVersion(self, src, transaction):
if transaction is not self._transaction: if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction) raise POSException.StorageTransactionError(self, transaction)
return [] return self._tid, []
def commitVersion(self, src, dest, transaction): def commitVersion(self, src, dest, transaction):
if transaction is not self._transaction: if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction) raise POSException.StorageTransactionError(self, transaction)
return [] return self._tid, []
def close(self): def close(self):
pass pass
def cleanup(self):
pass
def sortKey(self): def sortKey(self):
"""Return a string that can be used to sort storage instances. """Return a string that can be used to sort storage instances.
...@@ -85,7 +87,7 @@ class BaseStorage(UndoLogCompatible): ...@@ -85,7 +87,7 @@ class BaseStorage(UndoLogCompatible):
def getSize(self): def getSize(self):
return len(self)*300 # WAG! return len(self)*300 # WAG!
def history(self, oid, version, length=1): def history(self, oid, version, length=1, filter=None):
pass pass
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
...@@ -167,13 +169,13 @@ class BaseStorage(UndoLogCompatible): ...@@ -167,13 +169,13 @@ class BaseStorage(UndoLogCompatible):
now = time.time() now = time.time()
t = TimeStamp(*(time.gmtime(now)[:5] + (now % 60,))) t = TimeStamp(*(time.gmtime(now)[:5] + (now % 60,)))
self._ts = t = t.laterThan(self._ts) self._ts = t = t.laterThan(self._ts)
self._serial = `t` self._tid = `t`
else: else:
self._ts = TimeStamp(tid) self._ts = TimeStamp(tid)
self._serial = tid self._tid = tid
self._tstatus = status self._tstatus = status
self._begin(self._serial, user, desc, ext) self._begin(self._tid, user, desc, ext)
finally: finally:
self._lock_release() self._lock_release()
...@@ -203,10 +205,11 @@ class BaseStorage(UndoLogCompatible): ...@@ -203,10 +205,11 @@ class BaseStorage(UndoLogCompatible):
return return
try: try:
if f is not None: if f is not None:
f() f(self._tid)
u, d, e = self._ude u, d, e = self._ude
self._finish(self._serial, u, d, e) self._finish(self._tid, u, d, e)
self._clear_temp() self._clear_temp()
return self._tid
finally: finally:
self._ude = None self._ude = None
self._transaction = None self._transaction = None
...@@ -250,6 +253,48 @@ class BaseStorage(UndoLogCompatible): ...@@ -250,6 +253,48 @@ class BaseStorage(UndoLogCompatible):
raise POSException.Unsupported, ( raise POSException.Unsupported, (
"Retrieval of historical revisions is not supported") "Retrieval of historical revisions is not supported")
def loadBefore(self, oid, tid):
"""Return most recent revision of oid before tid committed."""
# XXX Is it okay for loadBefore() to return current data?
# There doesn't seem to be a good reason to forbid it, even
# though the typical use of this method will never find
# current data. But maybe we should call it loadByTid()?
n = 2
start_time = None
end_time = None
while start_time is None:
# The history() approach is a hack, because the dict
# returned by history() doesn't contain a tid. It
# contains a serialno, which is often the same, but isn't
# required to be. We'll pretend it is for now.
# A second problem is that history() doesn't say anything
# about whether the transaction status. If it falls before
# the pack time, we can't honor the MVCC request.
# Note: history() returns the most recent record first.
# XXX The filter argument to history() only appears to be
# supported by FileStorage. Perhaps it shouldn't be used.
L = self.history(oid, "", n, lambda d: not d["version"])
if not L:
return
for d in L:
if d["serial"] < tid:
start_time = d["serial"]
break
else:
end_time = d["serial"]
if len(L) < n:
break
n *= 2
if start_time is None:
return None
data = self.loadSerial(oid, start_time)
return data, start_time, end_time
def getExtensionMethods(self): def getExtensionMethods(self):
"""getExtensionMethods """getExtensionMethods
...@@ -314,7 +359,7 @@ class BaseStorage(UndoLogCompatible): ...@@ -314,7 +359,7 @@ class BaseStorage(UndoLogCompatible):
oid=r.oid oid=r.oid
if verbose: print oid_repr(oid), r.version, len(r.data) if verbose: print oid_repr(oid), r.version, len(r.data)
if restoring: if restoring:
self.restore(oid, r.serial, r.data, r.version, self.restore(oid, r.tid, r.data, r.version,
r.data_txn, transaction) r.data_txn, transaction)
else: else:
pre=preget(oid, None) pre=preget(oid, None)
......
...@@ -13,11 +13,19 @@ ...@@ -13,11 +13,19 @@
############################################################################## ##############################################################################
"""Database connection support """Database connection support
$Id: Connection.py,v 1.104 2003/12/10 20:02:15 shane Exp $""" $Id: Connection.py,v 1.105 2003/12/24 16:02:00 jeremy Exp $"""
import logging
import sys import sys
import threading import threading
from time import time from time import time
from types import ClassType
_marker = object()
def myhasattr(obj, attr):
# builtin hasattr() swallows exceptions
return getattr(obj, attr, _marker) is not _marker
from persistent import PickleCache from persistent import PickleCache
from zLOG import LOG, ERROR, BLATHER, WARNING from zLOG import LOG, ERROR, BLATHER, WARNING
...@@ -56,16 +64,19 @@ class Connection(ExportImport, object): ...@@ -56,16 +64,19 @@ class Connection(ExportImport, object):
The Connection manages movement of objects in and out of object storage. The Connection manages movement of objects in and out of object storage.
""" """
_tmp=None _tmp = None
_debug_info=() _debug_info = ()
_opened=None _opened = None
_reset_counter = 0 _code_timestamp = 0
_transaction = None _transaction = None
def __init__(self, version='', cache_size=400, def __init__(self, version='', cache_size=400,
cache_deactivate_after=60): cache_deactivate_after=60, mvcc=True):
"""Create a new Connection""" """Create a new Connection"""
self._version=version
self._log = logging.getLogger("zodb.conn")
self._version = version
self._cache = cache = PickleCache(self, cache_size) self._cache = cache = PickleCache(self, cache_size)
if version: if version:
# Caches for versions end up empty if the version # Caches for versions end up empty if the version
...@@ -97,6 +108,16 @@ class Connection(ExportImport, object): ...@@ -97,6 +108,16 @@ class Connection(ExportImport, object):
self._invalidated = d = {} self._invalidated = d = {}
self._invalid = d.has_key self._invalid = d.has_key
self._conflicts = {} self._conflicts = {}
self._noncurrent = {}
# If MVCC is enabled, then _mvcc is True and _txn_time stores
# the upper bound on transactions visible to this connection.
# That is, all object revisions must be written before _txn_time.
# If it is None, then the current revisions are acceptable.
# If the connection is in a version, mvcc will be disabled, because
# loadBefore() only returns non-version data.
self._mvcc = mvcc and not version
self._txn_time = None
def getTransaction(self): def getTransaction(self):
t = self._transaction t = self._transaction
...@@ -216,11 +237,12 @@ class Connection(ExportImport, object): ...@@ -216,11 +237,12 @@ class Connection(ExportImport, object):
# Call the close callbacks. # Call the close callbacks.
if self.__onCloseCallbacks is not None: if self.__onCloseCallbacks is not None:
for f in self.__onCloseCallbacks: for f in self.__onCloseCallbacks:
try: f() try:
except: f()
f=getattr(f, 'im_self', f) except: # except what?
LOG('ZODB',ERROR, 'Close callback failed for %s' % f, f = getattr(f, 'im_self', f)
error=sys.exc_info()) self._log.error("Close callback failed for %s", f,
sys.exc_info())
self.__onCloseCallbacks = None self.__onCloseCallbacks = None
self._storage = self._tmp = self.new_oid = self._opened = None self._storage = self._tmp = self.new_oid = self._opened = None
self._debug_info = () self._debug_info = ()
...@@ -303,8 +325,8 @@ class Connection(ExportImport, object): ...@@ -303,8 +325,8 @@ class Connection(ExportImport, object):
if tmp is None: return if tmp is None: return
src=self._storage src=self._storage
LOG('ZODB', BLATHER, self._log.debug("Commiting subtransaction of size %s",
'Commiting subtransaction of size %s' % src.getSize()) src.getSize())
self._storage=tmp self._storage=tmp
self._tmp=None self._tmp=None
...@@ -363,7 +385,7 @@ class Connection(ExportImport, object): ...@@ -363,7 +385,7 @@ class Connection(ExportImport, object):
def isReadOnly(self): def isReadOnly(self):
return self._storage.isReadOnly() return self._storage.isReadOnly()
def invalidate(self, oids): def invalidate(self, tid, oids):
"""Invalidate a set of oids. """Invalidate a set of oids.
This marks the oid as invalid, but doesn't actually invalidate This marks the oid as invalid, but doesn't actually invalidate
...@@ -372,6 +394,8 @@ class Connection(ExportImport, object): ...@@ -372,6 +394,8 @@ class Connection(ExportImport, object):
""" """
self._inv_lock.acquire() self._inv_lock.acquire()
try: try:
if self._txn_time is None:
self._txn_time = tid
self._invalidated.update(oids) self._invalidated.update(oids)
finally: finally:
self._inv_lock.release() self._inv_lock.release()
...@@ -381,13 +405,15 @@ class Connection(ExportImport, object): ...@@ -381,13 +405,15 @@ class Connection(ExportImport, object):
try: try:
self._cache.invalidate(self._invalidated) self._cache.invalidate(self._invalidated)
self._invalidated.clear() self._invalidated.clear()
self._txn_time = None
finally: finally:
self._inv_lock.release() self._inv_lock.release()
# Now is a good time to collect some garbage # Now is a good time to collect some garbage
self._cache.incrgc() self._cache.incrgc()
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
try: return self._db.modifiedInVersion(oid) try:
return self._db.modifiedInVersion(oid)
except KeyError: except KeyError:
return self._version return self._version
...@@ -411,54 +437,94 @@ class Connection(ExportImport, object): ...@@ -411,54 +437,94 @@ class Connection(ExportImport, object):
if self._storage is None: if self._storage is None:
msg = ("Shouldn't load state for %s " msg = ("Shouldn't load state for %s "
"when the connection is closed" % oid_repr(oid)) "when the connection is closed" % oid_repr(oid))
LOG('ZODB', ERROR, msg) self._log.error(msg)
raise RuntimeError(msg) raise RuntimeError(msg)
try: try:
# Avoid reading data from a transaction that committed self._setstate(obj)
# after the current transaction started, as that might
# lead to mixing of cached data from earlier transactions
# and new inconsistent data.
#
# Wait for check until after data is loaded from storage
# to avoid time-of-check to time-of-use race.
p, serial = self._storage.load(oid, self._version)
self._load_count = self._load_count + 1
invalid = self._is_invalidated(obj)
self._reader.setGhostState(obj, p)
obj._p_serial = serial
if invalid:
self._handle_independent(obj)
except ConflictError: except ConflictError:
raise raise
except: except:
LOG('ZODB', ERROR, self._log.error("Couldn't load state for %s", oid_repr(oid),
"Couldn't load state for %s" % oid_repr(oid), exc_info=sys.exc_info())
error=sys.exc_info())
raise raise
def _is_invalidated(self, obj): def _setstate(self, obj):
# Helper method for setstate() covers three cases: # Helper for setstate(), which provides logging of failures.
# returns false if obj is valid
# returns true if obj was invalidation, but is independent # The control flow is complicated here to avoid loading an
# otherwise, raises ConflictError for invalidated objects # object revision that we are sure we aren't going to use. As
# a result, invalidation tests occur before and after the
# load. We can only be sure about invalidations after the
# load.
# If an object has been invalidated, there are several cases
# to consider:
# 1. Check _p_independent()
# 2. Try MVCC
# 3. Raise ConflictError.
# Does anything actually use _p_independent()? It would simplify
# the code if we could drop support for it.
# There is a harmless data race with self._invalidated. A
# dict update could go on in another thread, but we don't care
# because we have to check again after the load anyway.
if (obj._p_oid in self._invalidated
and not myhasattr(obj, "_p_independent")):
# If the object has _p_independent(), we will handle it below.
if not (self._mvcc and self._setstate_noncurrent(obj)):
self.getTransaction().register(obj)
self._conflicts[obj._p_oid] = 1
raise ReadConflictError(object=obj)
p, serial = self._storage.load(obj._p_oid, self._version)
self._load_count += 1
self._inv_lock.acquire() self._inv_lock.acquire()
try: try:
if self._invalidated.has_key(obj._p_oid): invalid = obj._p_oid in self._invalidated
# Defer _p_independent() call until state is loaded.
ind = getattr(obj, "_p_independent", None)
if ind is not None:
# Defer _p_independent() call until state is loaded.
return 1
else:
self.getTransaction().register(obj)
self._conflicts[obj._p_oid] = 1
raise ReadConflictError(object=obj)
else:
return 0
finally: finally:
self._inv_lock.release() self._inv_lock.release()
if invalid:
if myhasattr(obj, "_p_independent"):
# This call will raise a ReadConflictError if something
# goes wrong
self._handle_independent(obj)
elif not (self._mvcc and self._setstate_noncurrent(obj)):
self.getTransaction().register(obj)
self._conflicts[obj._p_oid] = 1
raise ReadConflictError(object=obj)
self._reader.setGhostState(obj, p)
obj._p_serial = serial
def _setstate_noncurrent(self, obj):
"""Set state using non-current data.
Return True if state was available, False if not.
"""
try:
# Load data that was current before the commit at txn_time.
t = self._storage.loadBefore(obj._p_oid, self._txn_time)
except KeyError:
return False
if t is None:
return False
data, start, end = t
# The non-current transaction must have been written before
# txn_time. It must be current at txn_time, but could have
# been modified at txn_time.
# It's possible that end is None, if, e.g., the most recent
# invalidation was for version data.
assert start < self._txn_time <= end, \
(U64(start), U64(self._txn_time), U64(end))
self._noncurrent[obj._p_oid] = True
self._reader.setGhostState(obj, data)
obj._p_serial = start
def _handle_independent(self, obj): def _handle_independent(self, obj):
# Helper method for setstate() handles possibly independent objects # Helper method for setstate() handles possibly independent objects
# Call _p_independent(), if it returns True, setstate() wins. # Call _p_independent(), if it returns True, setstate() wins.
...@@ -499,7 +565,7 @@ class Connection(ExportImport, object): ...@@ -499,7 +565,7 @@ class Connection(ExportImport, object):
obj._p_changed = 0 obj._p_changed = 0
obj._p_serial = serial obj._p_serial = serial
except: except:
LOG('ZODB',ERROR, 'setklassstate failed', error=sys.exc_info()) self._log.error("setklassstate failed", exc_info=sys.exc_info())
raise raise
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
...@@ -590,11 +656,11 @@ class Connection(ExportImport, object): ...@@ -590,11 +656,11 @@ class Connection(ExportImport, object):
self._storage._creating[:0]=self._creating self._storage._creating[:0]=self._creating
del self._creating[:] del self._creating[:]
else: else:
def callback(): def callback(tid):
d = {} d = {}
for oid in self._modified: for oid in self._modified:
d[oid] = 1 d[oid] = 1
self._db.invalidate(d, self) self._db.invalidate(tid, d, self)
self._storage.tpc_finish(transaction, callback) self._storage.tpc_finish(transaction, callback)
self._conflicts.clear() self._conflicts.clear()
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
############################################################################## ##############################################################################
"""Database objects """Database objects
$Id: DB.py,v 1.57 2003/11/28 16:44:49 jim Exp $""" $Id: DB.py,v 1.58 2003/12/24 16:02:00 jeremy Exp $"""
__version__='$Revision: 1.57 $'[11:-2] __version__='$Revision: 1.58 $'[11:-2]
import cPickle, cStringIO, sys, POSException, UndoLogCompatible import cPickle, cStringIO, sys, POSException, UndoLogCompatible
from Connection import Connection from Connection import Connection
...@@ -74,7 +74,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -74,7 +74,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
self._version_cache_size=version_cache_size self._version_cache_size=version_cache_size
self._version_cache_deactivate_after = version_cache_deactivate_after self._version_cache_deactivate_after = version_cache_deactivate_after
self._miv_cache={} self._miv_cache = {}
# Setup storage # Setup storage
self._storage=storage self._storage=storage
...@@ -300,8 +300,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -300,8 +300,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
def importFile(self, file): def importFile(self, file):
raise NotImplementedError raise NotImplementedError
def invalidate(self, oids, connection=None, version='', def invalidate(self, tid, oids, connection=None, version=''):
rc=sys.getrefcount):
"""Invalidate references to a given oid. """Invalidate references to a given oid.
This is used to indicate that one of the connections has committed a This is used to indicate that one of the connections has committed a
...@@ -323,21 +322,21 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -323,21 +322,21 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
for cc in allocated: for cc in allocated:
if (cc is not connection and if (cc is not connection and
(not version or cc._version==version)): (not version or cc._version==version)):
if rc(cc) <= 3: if sys.getrefcount(cc) <= 3:
cc.close() cc.close()
cc.invalidate(oids) cc.invalidate(tid, oids)
temps=self._temps if self._temps:
if temps:
t=[] t=[]
for cc in temps: for cc in self._temps:
if rc(cc) > 3: if sys.getrefcount(cc) > 3:
if (cc is not connection and if (cc is not connection and
(not version or cc._version==version)): (not version or cc._version == version)):
cc.invalidate(oids) cc.invalidate(tid, oids)
t.append(cc) t.append(cc)
else: cc.close() else:
self._temps=t cc.close()
self._temps = t
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
h=hash(oid)%131 h=hash(oid)%131
...@@ -353,7 +352,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -353,7 +352,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
return len(self._storage) return len(self._storage)
def open(self, version='', transaction=None, temporary=0, force=None, def open(self, version='', transaction=None, temporary=0, force=None,
waitflag=1): waitflag=1, mvcc=True):
"""Return a object space (AKA connection) to work in """Return a object space (AKA connection) to work in
The optional version argument can be used to specify that a The optional version argument can be used to specify that a
...@@ -371,25 +370,25 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -371,25 +370,25 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
try: try:
if transaction is not None: if transaction is not None:
connections=transaction._connections connections = transaction._connections
if connections: if connections:
if connections.has_key(version) and not temporary: if connections.has_key(version) and not temporary:
return connections[version] return connections[version]
else: else:
transaction._connections=connections={} transaction._connections = connections = {}
transaction=transaction._connections transaction = transaction._connections
if temporary: if temporary:
# This is a temporary connection. # This is a temporary connection.
# We won't bother with the pools. This will be # We won't bother with the pools. This will be
# a one-use connection. # a one-use connection.
c=self.klass( c = self.klass(version=version,
version=version, cache_size=self._version_cache_size,
cache_size=self._version_cache_size) mvcc=mvcc)
c._setDB(self) c._setDB(self)
self._temps.append(c) self._temps.append(c)
if transaction is not None: transaction[id(c)]=c if transaction is not None:
transaction[id(c)] = c
return c return c
...@@ -430,18 +429,18 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -430,18 +429,18 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
if not pool: if not pool:
c=None c = None
if version: if version:
if self._version_pool_size > len(allocated) or force: if self._version_pool_size > len(allocated) or force:
c=self.klass( c = self.klass(version=version,
version=version, cache_size=self._version_cache_size,
cache_size=self._version_cache_size) mvcc=mvcc)
allocated.append(c) allocated.append(c)
pool.append(c) pool.append(c)
elif self._pool_size > len(allocated) or force: elif self._pool_size > len(allocated) or force:
c=self.klass( c = self.klass(version=version,
version=version, cache_size=self._cache_size,
cache_size=self._cache_size) mvcc=mvcc)
allocated.append(c) allocated.append(c)
pool.append(c) pool.append(c)
...@@ -456,7 +455,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -456,7 +455,7 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
pool_lock.release() pool_lock.release()
else: return else: return
elif len(pool)==1: elif len(pool) == 1:
# Taking last one, lock the pool # Taking last one, lock the pool
# Note that another thread might grab the lock # Note that another thread might grab the lock
# before us, so we might actually block, however, # before us, so we might actually block, however,
...@@ -470,14 +469,15 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -470,14 +469,15 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
# but it could be higher due to a race condition. # but it could be higher due to a race condition.
pool_lock.release() pool_lock.release()
c=pool[-1] c = pool[-1]
del pool[-1] del pool[-1]
c._setDB(self) c._setDB(self)
for pool, allocated in pooll: for pool, allocated in pooll:
for cc in pool: for cc in pool:
cc._incrgc() cc._incrgc()
if transaction is not None: transaction[version]=c if transaction is not None:
transaction[version] = c
return c return c
finally: self._r() finally: self._r()
...@@ -588,7 +588,8 @@ class DB(UndoLogCompatible.UndoLogCompatible, object): ...@@ -588,7 +588,8 @@ class DB(UndoLogCompatible.UndoLogCompatible, object):
d = {} d = {}
for oid in storage.undo(id): for oid in storage.undo(id):
d[oid] = 1 d[oid] = 1
self.invalidate(d) # XXX I think we need to remove old undo to use mvcc
self.invalidate(None, d)
def versionEmpty(self, version): def versionEmpty(self, version):
return self._storage.versionEmpty(version) return self._storage.versionEmpty(version)
...@@ -616,13 +617,13 @@ class CommitVersion: ...@@ -616,13 +617,13 @@ class CommitVersion:
def commit(self, reallyme, t): def commit(self, reallyme, t):
dest=self._dest dest=self._dest
oids = self._db._storage.commitVersion(self._version, dest, t) tid, oids = self._db._storage.commitVersion(self._version, dest, t)
oids = list2dict(oids) oids = list2dict(oids)
self._db.invalidate(oids, version=dest) self._db.invalidate(tid, oids, version=dest)
if dest: if dest:
# the code above just invalidated the dest version. # the code above just invalidated the dest version.
# now we need to invalidate the source! # now we need to invalidate the source!
self._db.invalidate(oids, version=self._version) self._db.invalidate(tid, oids, version=self._version)
class AbortVersion(CommitVersion): class AbortVersion(CommitVersion):
"""An object that will see to version abortion """An object that will see to version abortion
...@@ -631,9 +632,9 @@ class AbortVersion(CommitVersion): ...@@ -631,9 +632,9 @@ class AbortVersion(CommitVersion):
""" """
def commit(self, reallyme, t): def commit(self, reallyme, t):
version=self._version version = self._version
oids = self._db._storage.abortVersion(version, t) tid, oids = self._db._storage.abortVersion(version, t)
self._db.invalidate(list2dict(oids), version=version) self._db.invalidate(tid, list2dict(oids), version=version)
class TransactionalUndo(CommitVersion): class TransactionalUndo(CommitVersion):
...@@ -647,5 +648,5 @@ class TransactionalUndo(CommitVersion): ...@@ -647,5 +648,5 @@ class TransactionalUndo(CommitVersion):
# similarity of rhythm that I think it's justified. # similarity of rhythm that I think it's justified.
def commit(self, reallyme, t): def commit(self, reallyme, t):
oids = self._db._storage.transactionalUndo(self._version, t) tid, oids = self._db._storage.transactionalUndo(self._version, t)
self._db.invalidate(list2dict(oids)) self._db.invalidate(tid, list2dict(oids))
...@@ -45,14 +45,12 @@ There are three main data structures: ...@@ -45,14 +45,12 @@ There are three main data structures:
A record is a tuple: A record is a tuple:
oid, serial, pre, vdata, p, oid, pre, vdata, p, tid
where: where:
oid -- object id oid -- object id
serial -- object serial number
pre -- The previous record for this object (or None) pre -- The previous record for this object (or None)
vdata -- version data vdata -- version data
...@@ -62,6 +60,8 @@ where: ...@@ -62,6 +60,8 @@ where:
p -- the pickle data or None p -- the pickle data or None
tid -- the transaction id that wrote the record
The pickle data will be None for a record for an object created in The pickle data will be None for a record for an object created in
an aborted version. an aborted version.
...@@ -79,7 +79,7 @@ method:: ...@@ -79,7 +79,7 @@ method::
and call it to monitor the storage. and call it to monitor the storage.
""" """
__version__='$Revision: 1.22 $'[11:-2] __version__='$Revision: 1.23 $'[11:-2]
import base64, time, string import base64, time, string
from ZODB import POSException, BaseStorage, utils from ZODB import POSException, BaseStorage, utils
...@@ -93,12 +93,13 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -93,12 +93,13 @@ class DemoStorage(BaseStorage.BaseStorage):
BaseStorage.BaseStorage.__init__(self, name, base) BaseStorage.BaseStorage.__init__(self, name, base)
# We use a BTree because the items are sorted! # We use a BTree because the items are sorted!
self._data=OOBTree.OOBTree() self._data = OOBTree.OOBTree()
self._index={} self._index = {}
self._vindex={} self._vindex = {}
self._base=base self._base = base
self._size=0 self._size = 0
self._quota=quota self._quota = quota
self._ltid = None
self._clear_temp() self._clear_temp()
if base is not None and base.versions(): if base is not None and base.versions():
raise POSException.StorageError, ( raise POSException.StorageError, (
...@@ -113,7 +114,7 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -113,7 +114,7 @@ class DemoStorage(BaseStorage.BaseStorage):
s=100 s=100
for tid, (p, u, d, e, t) in self._data.items(): for tid, (p, u, d, e, t) in self._data.items():
s=s+16+24+12+4+16+len(u)+16+len(d)+16+len(e)+16 s=s+16+24+12+4+16+len(u)+16+len(d)+16+len(e)+16
for oid, serial, pre, vdata, p in t: for oid, pre, vdata, p, tid in t:
s=s+16+24+24+4+4+(p and (16+len(p)) or 4) s=s+16+24+24+4+4+(p and (16+len(p)) or 4)
if vdata: s=s+12+16+len(vdata[0])+4 if vdata: s=s+12+16+len(vdata[0])+4
...@@ -139,16 +140,16 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -139,16 +140,16 @@ class DemoStorage(BaseStorage.BaseStorage):
oids = [] oids = []
for r in v.values(): for r in v.values():
oid, serial, pre, (version, nv), p = r oid, pre, (version, nv), p, tid = r
oids.append(oid) oids.append(oid)
if nv: if nv:
oid, serial, pre, vdata, p = nv oid, pre, vdata, p, tid = nv
self._tindex.append([oid, serial, r, None, p]) self._tindex.append([oid, r, None, p, self._tid])
else: else:
# effectively, delete the thing # effectively, delete the thing
self._tindex.append([oid, None, r, None, None]) self._tindex.append([oid, r, None, None, self._tid])
return oids return self._tid, oids
finally: self._lock_release() finally: self._lock_release()
...@@ -168,53 +169,60 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -168,53 +169,60 @@ class DemoStorage(BaseStorage.BaseStorage):
if v is None: if v is None:
return return
newserial = self._serial newserial = self._tid
tindex = self._tindex tindex = self._tindex
oids = [] oids = []
for r in v.values(): for r in v.values():
oid, serial, pre, vdata, p = r oid, pre, vdata, p, tid = r
assert vdata is not None assert vdata is not None
oids.append(oid) oids.append(oid)
if dest: if dest:
new_vdata = dest, vdata[1] new_vdata = dest, vdata[1]
else: else:
new_vdata = None new_vdata = None
tindex.append([oid, newserial, r, new_vdata, p]) tindex.append([oid, r, new_vdata, p, self._tid])
return oids return self._tid, oids
finally: finally:
self._lock_release() self._lock_release()
def load(self, oid, version): def loadEx(self, oid, version):
self._lock_acquire() self._lock_acquire()
try: try:
try: try:
oid, serial, pre, vdata, p = self._index[oid] oid, pre, vdata, p, tid = self._index[oid]
except KeyError: except KeyError:
if self._base: if self._base:
return self._base.load(oid, '') return self._base.load(oid, '')
raise KeyError, oid raise KeyError, oid
ver = ""
if vdata: if vdata:
oversion, nv = vdata oversion, nv = vdata
if oversion != version: if oversion != version:
if nv: if nv:
oid, serial, pre, vdata, p = nv # Return the current txn's tid with the non-version
# data.
oid, pre, vdata, p, skiptid = nv
else: else:
raise KeyError, oid raise KeyError, oid
ver = oversion
if p is None: if p is None:
raise KeyError, oid raise KeyError, oid
return p, serial return p, tid, ver
finally: self._lock_release() finally: self._lock_release()
def load(self, oid, version):
return self.loadEx(oid, version)[:2]
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
self._lock_acquire() self._lock_acquire()
try: try:
try: try:
oid, serial, pre, vdata, p = self._index[oid] oid, pre, vdata, p, tid = self._index[oid]
if vdata: return vdata[0] if vdata: return vdata[0]
return '' return ''
except: return '' except: return ''
...@@ -231,15 +239,15 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -231,15 +239,15 @@ class DemoStorage(BaseStorage.BaseStorage):
# Hm, nothing here, check the base version: # Hm, nothing here, check the base version:
if self._base: if self._base:
try: try:
p, oserial = self._base.load(oid, '') p, tid = self._base.load(oid, '')
except KeyError: except KeyError:
pass pass
else: else:
old = oid, oserial, None, None, p old = oid, None, None, p, tid
nv=None nv=None
if old: if old:
oid, oserial, pre, vdata, p = old oid, pre, vdata, p, tid = old
if vdata: if vdata:
if vdata[0] != version: if vdata[0] != version:
...@@ -249,12 +257,11 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -249,12 +257,11 @@ class DemoStorage(BaseStorage.BaseStorage):
else: else:
nv=old nv=old
if serial != oserial: if serial != tid:
raise POSException.ConflictError( raise POSException.ConflictError(
oid=oid, serials=(oserial, serial), data=data) oid=oid, serials=(tid, serial), data=data)
serial=self._serial r = [oid, old, version and (version, nv) or None, data, self._tid]
r=[oid, serial, old, version and (version, nv) or None, data]
self._tindex.append(r) self._tindex.append(r)
s=self._tsize s=self._tsize
...@@ -268,15 +275,21 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -268,15 +275,21 @@ class DemoStorage(BaseStorage.BaseStorage):
has been exceeded.<br>Have a nice day.''') has been exceeded.<br>Have a nice day.''')
finally: self._lock_release() finally: self._lock_release()
return serial return self._tid
def supportsUndo(self): return 1 def supportsUndo(self):
def supportsVersions(self): return 1 return 1
def supportsVersions(self):
return 1
def _clear_temp(self): def _clear_temp(self):
self._tindex = [] self._tindex = []
self._tsize = self._size + 160 self._tsize = self._size + 160
def lastTransaction(self):
return self._ltid
def _begin(self, tid, u, d, e): def _begin(self, tid, u, d, e):
self._tsize = self._size + 120 + len(u) + len(d) + len(e) self._tsize = self._size + 120 + len(u) + len(d) + len(e)
...@@ -285,11 +298,11 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -285,11 +298,11 @@ class DemoStorage(BaseStorage.BaseStorage):
self._data[tid] = None, user, desc, ext, tuple(self._tindex) self._data[tid] = None, user, desc, ext, tuple(self._tindex)
for r in self._tindex: for r in self._tindex:
oid, serial, pre, vdata, p = r oid, pre, vdata, p, tid = r
old = self._index.get(oid) old = self._index.get(oid)
# If the object had version data, remove the version data. # If the object had version data, remove the version data.
if old is not None: if old is not None:
oldvdata = old[3] oldvdata = old[2]
if oldvdata: if oldvdata:
v = self._vindex[oldvdata[0]] v = self._vindex[oldvdata[0]]
del v[oid] del v[oid]
...@@ -306,6 +319,7 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -306,6 +319,7 @@ class DemoStorage(BaseStorage.BaseStorage):
if v is None: if v is None:
v = self._vindex[version] = {} v = self._vindex[version] = {}
v[oid] = r v[oid] = r
self._ltid = self._tid
def undo(self, transaction_id): def undo(self, transaction_id):
self._lock_acquire() self._lock_acquire()
...@@ -324,7 +338,7 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -324,7 +338,7 @@ class DemoStorage(BaseStorage.BaseStorage):
oids=[] oids=[]
for r in t: for r in t:
oid, serial, pre, vdata, p = r oid, pre, vdata, p, tid = r
if pre: if pre:
index[oid] = pre index[oid] = pre
...@@ -337,7 +351,7 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -337,7 +351,7 @@ class DemoStorage(BaseStorage.BaseStorage):
if v: del v[oid] if v: del v[oid]
# Add new version data (from pre): # Add new version data (from pre):
oid, serial, prepre, vdata, p = pre oid, prepre, vdata, p, tid = pre
if vdata: if vdata:
version=vdata[0] version=vdata[0]
v=vindex.get(version, None) v=vindex.get(version, None)
...@@ -404,17 +418,17 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -404,17 +418,17 @@ class DemoStorage(BaseStorage.BaseStorage):
def _build_indexes(self, stop='\377\377\377\377\377\377\377\377'): def _build_indexes(self, stop='\377\377\377\377\377\377\377\377'):
# Rebuild index structures from transaction data # Rebuild index structures from transaction data
index={} index = {}
vindex={} vindex = {}
_data=self._data for tid, (p, u, d, e, t) in self._data.items():
for tid, (p, u, d, e, t) in _data.items(): if tid >= stop:
if tid >= stop: break break
for r in t: for r in t:
oid, serial, pre, vdata, p = r oid, pre, vdata, p, tid = r
old=index.get(oid, None) old=index.get(oid, None)
if old is not None: if old is not None:
oldvdata=old[3] oldvdata=old[2]
if oldvdata: if oldvdata:
v=vindex[oldvdata[0]] v=vindex[oldvdata[0]]
del v[oid] del v[oid]
...@@ -439,54 +453,56 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -439,54 +453,56 @@ class DemoStorage(BaseStorage.BaseStorage):
try: try:
stop=`TimeStamp(*time.gmtime(t)[:5]+(t%60,))` stop=`TimeStamp(*time.gmtime(t)[:5]+(t%60,))`
_data=self._data
# Build indexes up to the pack time: # Build indexes up to the pack time:
index, vindex = self._build_indexes(stop) index, vindex = self._build_indexes(stop)
# Now build an index of *only* those objects reachable # Now build an index of *only* those objects reachable
# from the root. # from the root.
rootl=['\0\0\0\0\0\0\0\0'] rootl = ['\0\0\0\0\0\0\0\0']
pop=rootl.pop pindex = {}
pindex={}
referenced=pindex.has_key
while rootl: while rootl:
oid=pop() oid = rootl.pop()
if referenced(oid): continue if oid in pindex:
continue
# Scan non-version pickle for references # Scan non-version pickle for references
r=index.get(oid, None) r = index.get(oid, None)
if r is None: if r is None:
if self._base: if self._base:
p, s = self._base.load(oid, '') p, s = self._base.load(oid, '')
referencesf(p, rootl) referencesf(p, rootl)
else: else:
pindex[oid]=r pindex[oid] = r
oid, serial, pre, vdata, p = r oid, pre, vdata, p, tid = r
referencesf(p, rootl) referencesf(p, rootl)
if vdata: if vdata:
nv=vdata[1] nv = vdata[1]
if nv: if nv:
oid, serial, pre, vdata, p = nv oid, pre, vdata, p, tid = nv
referencesf(p, rootl) referencesf(p, rootl)
# Now we're ready to do the actual packing. # Now we're ready to do the actual packing.
# We'll simply edit the transaction data in place. # We'll simply edit the transaction data in place.
# We'll defer deleting transactions till the end # We'll defer deleting transactions till the end
# to avoid messing up the BTree items. # to avoid messing up the BTree items.
deleted=[] deleted = []
for tid, (p, u, d, e, t) in _data.items(): for tid, (p, u, d, e, records) in self._data.items():
if tid >= stop: break if tid >= stop:
o=[] break
for r in t: o = []
c=pindex.get(r[0]) for r in records:
c = pindex.get(r[0])
if c is None: if c is None:
# GC this record, no longer referenced # GC this record, no longer referenced
continue continue
elif c is not r: if c == r:
# This is the most recent revision.
o.append(r)
else:
# This record is not the indexed record, # This record is not the indexed record,
# so it may not be current. Let's see. # so it may not be current. Let's see.
oid, serial, pre, vdata, p = r vdata = r[3]
if vdata: if vdata:
# Version record are current *only* if they # Version record are current *only* if they
# are indexed # are indexed
...@@ -494,7 +510,7 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -494,7 +510,7 @@ class DemoStorage(BaseStorage.BaseStorage):
else: else:
# OK, this isn't a version record, so it may be the # OK, this isn't a version record, so it may be the
# non-version record for the indexed record. # non-version record for the indexed record.
oid, serial, pre, vdata, p = c vdata = c[3]
if vdata: if vdata:
if vdata[1] != r: if vdata[1] != r:
# This record is not the non-version # This record is not the non-version
...@@ -505,25 +521,25 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -505,25 +521,25 @@ class DemoStorage(BaseStorage.BaseStorage):
# so this record can not be the non-version # so this record can not be the non-version
# record for it. # record for it.
continue continue
o.append(r) o.append(r)
if o: if o:
if len(o) != len(t): if len(o) != len(records):
_data[tid] = 1, u, d, e, tuple(o) # Reset data self._data[tid] = 1, u, d, e, tuple(o) # Reset data
else: else:
deleted.append(tid) deleted.append(tid)
# Now delete empty transactions # Now delete empty transactions
for tid in deleted: for tid in deleted:
del _data[tid] del self._data[tid]
# Now reset previous pointers for "current" records: # Now reset previous pointers for "current" records:
for r in pindex.values(): for r in pindex.values():
r[2] = None # Previous record r[1] = None # Previous record
if r[3] and r[3][1]: # vdata if r[2] and r[2][1]: # vdata
# If this record contains version data and # If this record contains version data and
# non-version data, then clear it out. # non-version data, then clear it out.
r[3][1][2] = None r[2][1][2] = None
# Finally, rebuild indexes from transaction data: # Finally, rebuild indexes from transaction data:
self._index, self._vindex = self._build_indexes() self._index, self._vindex = self._build_indexes()
...@@ -541,21 +557,22 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -541,21 +557,22 @@ class DemoStorage(BaseStorage.BaseStorage):
for tid, (p, u, d, e, t) in self._data.items(): for tid, (p, u, d, e, t) in self._data.items():
o.append(" %s %s" % (TimeStamp(tid), p)) o.append(" %s %s" % (TimeStamp(tid), p))
for r in t: for r in t:
oid, serial, pre, vdata, p = r oid, pre, vdata, p, tid = r
oid=utils.u64(oid) oid = utils.oid_repr(oid)
if serial is not None: serial=str(TimeStamp(serial)) tid = utils.oid_repr(tid)
## if serial is not None: serial=str(TimeStamp(serial))
pre=id(pre) pre=id(pre)
if vdata and vdata[1]: vdata=vdata[0], id(vdata[1]) if vdata and vdata[1]: vdata=vdata[0], id(vdata[1])
if p: p='' if p: p=''
o.append(' %s: %s' % o.append(' %s: %s' %
(id(r), `(oid, serial, pre, vdata, p)`)) (id(r), `(oid, pre, vdata, p, tid)`))
o.append('\nIndex:') o.append('\nIndex:')
items=self._index.items() items=self._index.items()
items.sort() items.sort()
for oid, r in items: for oid, r in items:
if r: r=id(r) if r: r=id(r)
o.append(' %s: %s' % (utils.u64(oid), r)) o.append(' %s: %s' % (utils.oid_repr(oid), r))
o.append('\nVersion Index:') o.append('\nVersion Index:')
items=self._vindex.items() items=self._vindex.items()
...@@ -566,7 +583,6 @@ class DemoStorage(BaseStorage.BaseStorage): ...@@ -566,7 +583,6 @@ class DemoStorage(BaseStorage.BaseStorage):
vitems.sort() vitems.sort()
for oid, r in vitems: for oid, r in vitems:
if r: r=id(r) if r: r=id(r)
o.append(' %s: %s' % (utils.u64(oid), r)) o.append(' %s: %s' % (utils.oid_repr(oid), r))
return string.join(o,'\n') return string.join(o,'\n')
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Storage implementation using a log written to a single file.
$Revision: 1.2 $
"""
import base64
from cPickle import Pickler, Unpickler, loads
import errno
import os
import struct
import sys
import time
from types import StringType, DictType
from struct import pack, unpack
# Not all platforms have fsync
fsync = getattr(os, "fsync", None)
from ZODB import BaseStorage, ConflictResolution, POSException
from ZODB.POSException \
import UndoError, POSKeyError, MultipleUndoErrors, VersionLockError
from persistent.TimeStamp import TimeStamp
from ZODB.lock_file import LockFile
from ZODB.utils import p64, u64, cp, z64
from ZODB.FileStorage.fspack import FileStoragePacker
from ZODB.FileStorage.format \
import FileStorageFormatter, DataHeader, TxnHeader, DATA_HDR, \
DATA_HDR_LEN, TRANS_HDR, TRANS_HDR_LEN, CorruptedDataError, \
DATA_VERSION_HDR_LEN
try:
from ZODB.fsIndex import fsIndex
except ImportError:
def fsIndex():
return {}
from zLOG import LOG, BLATHER, WARNING, ERROR, PANIC
t32 = 1L << 32
packed_version = "FS21"
def blather(message, *data):
LOG('ZODB FS', BLATHER, "%s blather: %s\n" % (packed_version,
message % data))
def warn(message, *data):
LOG('ZODB FS', WARNING, "%s warn: %s\n" % (packed_version,
message % data))
def error(message, *data, **kwargs):
LOG('ZODB FS', ERROR, "%s ERROR: %s\n" % (packed_version,
message % data), **kwargs)
def nearPanic(message, *data):
LOG('ZODB FS', PANIC, "%s ERROR: %s\n" % (packed_version,
message % data))
def panic(message, *data):
message = message % data
LOG('ZODB FS', PANIC, "%s ERROR: %s\n" % (packed_version, message))
raise CorruptedTransactionError(message)
class FileStorageError(POSException.StorageError):
pass
class PackError(FileStorageError):
pass
class FileStorageFormatError(FileStorageError):
"""Invalid file format
The format of the given file is not valid.
"""
class CorruptedFileStorageError(FileStorageError,
POSException.StorageSystemError):
"""Corrupted file storage."""
class CorruptedTransactionError(CorruptedFileStorageError):
pass
class FileStorageQuotaError(FileStorageError,
POSException.StorageSystemError):
"""File storage quota exceeded."""
class TempFormatter(FileStorageFormatter):
"""Helper class used to read formatted FileStorage data."""
def __init__(self, afile):
self._file = afile
class FileStorage(BaseStorage.BaseStorage,
ConflictResolution.ConflictResolvingStorage,
FileStorageFormatter):
# default pack time is 0
_packt = z64
_records_before_save = 10000
def __init__(self, file_name, create=False, read_only=False, stop=None,
quota=None):
if read_only:
self._is_read_only = 1
if create:
raise ValueError("can't create a read-only file")
elif stop is not None:
raise ValueError("time-travel only supported in read-only mode")
if stop is None:
stop='\377'*8
# Lock the database and set up the temp file.
if not read_only:
# Create the lock file
self._lock_file = LockFile(file_name + '.lock')
self._tfile = open(file_name + '.tmp', 'w+b')
self._tfmt = TempFormatter(self._tfile)
else:
self._tfile = None
self._file_name = file_name
BaseStorage.BaseStorage.__init__(self, file_name)
(index, vindex, tindex, tvindex,
oid2tid, toid2tid, toid2tid_delete) = self._newIndexes()
self._initIndex(index, vindex, tindex, tvindex,
oid2tid, toid2tid, toid2tid_delete)
# Now open the file
self._file = None
if not create:
try:
self._file = open(file_name, read_only and 'rb' or 'r+b')
except IOError, exc:
if exc.errno == errno.EFBIG:
# The file is too big to open. Fail visibly.
raise
if exc.errno == errno.ENOENT:
# The file doesn't exist. Create it.
create = 1
# If something else went wrong, it's hard to guess
# what the problem was. If the file does not exist,
# create it. Otherwise, fail.
if os.path.exists(file_name):
raise
else:
create = 1
if self._file is None and create:
if os.path.exists(file_name):
os.remove(file_name)
self._file = open(file_name, 'w+b')
self._file.write(packed_version)
r = self._restore_index()
if r is not None:
self._used_index = 1 # Marker for testing
index, vindex, start, maxoid, ltid = r
self._initIndex(index, vindex, tindex, tvindex,
oid2tid, toid2tid, toid2tid_delete)
self._pos, self._oid, tid = read_index(
self._file, file_name, index, vindex, tindex, stop,
ltid=ltid, start=start, maxoid=maxoid,
read_only=read_only,
)
else:
self._used_index = 0 # Marker for testing
self._pos, self._oid, tid = read_index(
self._file, file_name, index, vindex, tindex, stop,
read_only=read_only,
)
self._save_index()
self._records_before_save = max(self._records_before_save,
len(self._index))
self._ltid = tid
# self._pos should always point just past the last
# transaction. During 2PC, data is written after _pos.
# invariant is restored at tpc_abort() or tpc_finish().
self._ts = tid = TimeStamp(tid)
t = time.time()
t = TimeStamp(*time.gmtime(t)[:5] + (t % 60,))
if tid > t:
warn("%s Database records in the future", file_name);
if tid.timeTime() - t.timeTime() > 86400*30:
# a month in the future? This is bogus, use current time
self._ts = t
self._quota = quota
# tid cache statistics.
self._oid2tid_nlookups = self._oid2tid_nhits = 0
def _initIndex(self, index, vindex, tindex, tvindex,
oid2tid, toid2tid, toid2tid_delete):
self._index=index
self._vindex=vindex
self._tindex=tindex
self._tvindex=tvindex
self._index_get=index.get
self._vindex_get=vindex.get
# .store() needs to compare the passed-in serial to the
# current tid in the database. _oid2tid caches the oid ->
# current tid mapping for non-version data (if the current
# record for oid is version data, the oid is not a key in
# _oid2tid). The point is that otherwise seeking into the
# storage is needed to extract the current tid, and that's
# an expensive operation. For example, if a transaction
# stores 4000 objects, and each random seek + read takes 7ms
# (that was approximately true on Linux and Windows tests in
# mid-2003), that's 28 seconds just to find the old tids.
# XXX Probably better to junk this and redefine _index as mapping
# XXX oid to (offset, tid) pair, via a new memory-efficient
# XXX BTree type.
self._oid2tid = oid2tid
# oid->tid map to transactionally add to _oid2tid.
self._toid2tid = toid2tid
# Set of oids to transactionally delete from _oid2tid (e.g.,
# oids reverted by undo, or for which the most recent record
# becomes version data).
self._toid2tid_delete = toid2tid_delete
def __len__(self):
return len(self._index)
def _newIndexes(self):
# hook to use something other than builtin dict
return fsIndex(), {}, {}, {}, {}, {}, {}
_saved = 0
def _save_index(self):
"""Write the database index to a file to support quick startup."""
index_name = self.__name__ + '.index'
tmp_name = index_name + '.index_tmp'
f=open(tmp_name,'wb')
p=Pickler(f,1)
info={'index': self._index, 'pos': self._pos,
'oid': self._oid, 'vindex': self._vindex}
p.dump(info)
f.flush()
f.close()
try:
try:
os.remove(index_name)
except OSError:
pass
os.rename(tmp_name, index_name)
except: pass
self._saved += 1
def _clear_index(self):
index_name = self.__name__ + '.index'
if os.path.exists(index_name):
try:
os.remove(index_name)
except OSError:
pass
def _sane(self, index, pos):
"""Sanity check saved index data by reading the last undone trans
Basically, we read the last not undone transaction and
check to see that the included records are consistent
with the index. Any invalid record records or inconsistent
object positions cause zero to be returned.
"""
r = self._check_sanity(index, pos)
if not r:
warn("Ignoring index for %s", self._file_name)
return r
def _check_sanity(self, index, pos):
if pos < 100:
return 0 # insane
self._file.seek(0, 2)
if self._file.tell() < pos:
return 0 # insane
ltid = None
max_checked = 5
checked = 0
while checked < max_checked:
self._file.seek(pos - 8)
rstl = self._file.read(8)
tl = u64(rstl)
pos = pos - tl - 8
if pos < 4:
return 0 # insane
h = self._read_txn_header(pos)
if not ltid:
ltid = h.tid
if h.tlen != tl:
return 0 # inconsistent lengths
if h.status == 'u':
continue # undone trans, search back
if h.status not in ' p':
return 0 # insane
if tl < h.headerlen():
return 0 # insane
tend = pos + tl
opos = pos + h.headerlen()
if opos == tend:
continue # empty trans
while opos < tend and checked < max_checked:
# Read the data records for this transaction
h = self._read_data_header(opos)
if opos + h.recordlen() > tend or h.tloc != pos:
return 0
if index.get(h.oid, 0) != opos:
return 0 # insane
checked += 1
opos = opos + h.recordlen()
return ltid
def _restore_index(self):
"""Load database index to support quick startup."""
file_name=self.__name__
index_name=file_name+'.index'
try: f=open(index_name,'rb')
except: return None
p=Unpickler(f)
try:
info=p.load()
except:
exc, err = sys.exc_info()[:2]
warn("Failed to load database index: %s: %s" %
(exc, err))
return None
index = info.get('index')
pos = info.get('pos')
oid = info.get('oid')
vindex = info.get('vindex')
if index is None or pos is None or oid is None or vindex is None:
return None
pos = long(pos)
if isinstance(index, DictType) and not self._is_read_only:
# Convert to fsIndex
newindex = fsIndex()
if type(newindex) is not type(index):
# And we have fsIndex
newindex.update(index)
# Now save the index
f = open(index_name, 'wb')
p = Pickler(f, 1)
info['index'] = newindex
p.dump(info)
f.close()
# Now call this method again to get the new data
return self._restore_index()
tid = self._sane(index, pos)
if not tid:
return None
return index, vindex, pos, oid, tid
def close(self):
self._file.close()
if hasattr(self,'_lock_file'):
self._lock_file.close()
if self._tfile:
self._tfile.close()
try:
self._save_index()
except:
# Log the error and continue
LOG("ZODB FS", ERROR, "Error saving index on close()",
error=sys.exc_info())
# Return tid of most recent record for oid if that's in the
# _oid2tid cache. Else return None. It's important to use this
# instead of indexing _oid2tid directly so that cache statistics
# can be logged.
def _get_cached_tid(self, oid):
self._oid2tid_nlookups += 1
result = self._oid2tid.get(oid)
if result is not None:
self._oid2tid_nhits += 1
# Log a msg every ~8000 tries, and prevent overflow.
if self._oid2tid_nlookups & 0x1fff == 0:
if self._oid2tid_nlookups >> 30:
# In older Pythons, we may overflow if we keep it an int.
self._oid2tid_nlookups = long(self._oid2tid_nlookups)
self._oid2tid_nhits = long(self._oid2tid_nhits)
blather("_oid2tid size %s lookups %s hits %s rate %.1f%%",
len(self._oid2tid),
self._oid2tid_nlookups,
self._oid2tid_nhits,
100.0 * self._oid2tid_nhits /
self._oid2tid_nlookups)
return result
def abortVersion(self, src, transaction):
return self.commitVersion(src, '', transaction, abort=True)
def commitVersion(self, src, dest, transaction, abort=False):
# We are going to commit by simply storing back pointers.
if self._is_read_only:
raise POSException.ReadOnlyError()
if not (src and isinstance(src, StringType)
and isinstance(dest, StringType)):
raise POSException.VersionCommitError('Invalid source version')
if src == dest:
raise POSException.VersionCommitError(
"Can't commit to same version: %s" % repr(src))
if dest and abort:
raise POSException.VersionCommitError(
"Internal error, can't abort to a version")
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
self._lock_acquire()
try:
return self._commitVersion(src, dest, transaction, abort)
finally:
self._lock_release()
def _commitVersion(self, src, dest, transaction, abort=False):
# call after checking arguments and acquiring lock
srcpos = self._vindex_get(src, 0)
spos = p64(srcpos)
# middle holds bytes 16:34 of a data record:
# pos of transaction, len of version name, data length
# commit version never writes data, so data length is always 0
middle = struct.pack(">8sH8s", p64(self._pos), len(dest), z64)
if dest:
sd = p64(self._vindex_get(dest, 0))
heredelta = 66 + len(dest)
else:
sd = ''
heredelta = 50
here = self._pos + (self._tfile.tell() + self._thl)
oids = []
current_oids = {}
while srcpos:
h = self._read_data_header(srcpos)
if self._index.get(h.oid) == srcpos:
# This is a current record!
self._tindex[h.oid] = here
oids.append(h.oid)
self._tfile.write(h.oid + self._tid + spos + middle)
if dest:
self._tvindex[dest] = here
self._tfile.write(p64(h.pnv) + sd + dest)
sd = p64(here)
self._tfile.write(abort and p64(h.pnv) or spos)
# data backpointer to src data
here += heredelta
current_oids[h.oid] = 1
else:
# Hm. This is a non-current record. Is there a
# current record for this oid?
if not current_oids.has_key(h.oid):
break
srcpos = h.vprev
spos = p64(srcpos)
self._toid2tid_delete.update(current_oids)
return self._tid, oids
def getSize(self):
return self._pos
def _lookup_pos(self, oid):
try:
return self._index[oid]
except KeyError:
raise POSKeyError(oid)
except TypeError:
raise TypeError("invalid oid %r" % (oid,))
def loadEx(self, oid, version):
# A variant of load() that also returns a transaction id.
# ZEO wants this for managing its cache.
self._lock_acquire()
try:
pos = self._lookup_pos(oid)
h = self._read_data_header(pos, oid)
if h.version and h.version != version:
# Return data and tid from pnv (non-version data).
# If we return the old record's transaction id, then
# it will look to the cache like old data is current.
# The tid for the current data must always be greater
# than any non-current data.
data = self._loadBack_impl(oid, h.pnv)[0]
return data, h.tid, ""
if h.plen:
data = self._file.read(h.plen)
return data, h.tid, h.version
else:
# Get the data from the backpointer, but tid from
# currnt txn.
data, _, _, _ = self._loadBack_impl(oid, h.back)
th = self._read_txn_header(h.tloc)
return data, h.tid, h.version
finally:
self._lock_release()
def load(self, oid, version):
self._lock_acquire()
try:
pos = self._lookup_pos(oid)
h = self._read_data_header(pos, oid)
if h.version and h.version != version:
data = self._loadBack_impl(oid, h.pnv)[0]
return data, h.tid
if h.plen:
return self._file.read(h.plen), h.tid
else:
data = self._loadBack_impl(oid, h.back)[0]
return data, h.tid
finally:
self._lock_release()
def loadSerial(self, oid, serial):
# loadSerial must always return non-version data, because it
# is used by conflict resolution.
self._lock_acquire()
try:
pos = self._lookup_pos(oid)
while 1:
h = self._read_data_header(pos, oid)
if h.tid == serial:
break
pos = h.prev
if not pos:
raise POSKeyError(oid)
if h.version:
return self._loadBack_impl(oid, h.pnv)[0]
if h.plen:
return self._file.read(h.plen)
else:
return self._loadBack_impl(oid, h.back)[0]
finally:
self._lock_release()
def loadBefore(self, oid, tid):
pos = self._lookup_pos(oid)
end_tid = None
while True:
h = self._read_data_header(pos, oid)
if h.version:
# Just follow the pnv pointer to the previous
# non-version data.
if not h.pnv:
# Object was created in version. There is no
# before data to find.
return None
pos = h.pnv
# The end_tid for the non-version data is not affected
# by versioned data records.
continue
if h.tid < tid:
break
pos = h.prev
end_tid = h.tid
if not pos:
return None
if h.back:
data, _, _, _ = self._loadBack_impl(oid, h.back)
return data, h.tid, end_tid
else:
return self._file.read(h.plen), h.tid, end_tid
def modifiedInVersion(self, oid):
self._lock_acquire()
try:
pos = self._lookup_pos(oid)
h = self._read_data_header(pos, oid)
return h.version
finally:
self._lock_release()
def store(self, oid, serial, data, version, transaction):
if self._is_read_only:
raise POSException.ReadOnlyError()
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
self._lock_acquire()
try:
old = self._index_get(oid, 0)
cached_tid = None
pnv = None
if old:
cached_tid = self._get_cached_tid(oid)
if cached_tid is None:
h = self._read_data_header(old, oid)
if h.version:
if h.version != version:
raise VersionLockError(oid, h.version)
pnv = h.pnv
cached_tid = h.tid
if serial != cached_tid:
rdata = self.tryToResolveConflict(oid, cached_tid,
serial, data)
if rdata is None:
raise POSException.ConflictError(
oid=oid, serials=(cached_tid, serial), data=data)
else:
data = rdata
pos = self._pos
here = pos + self._tfile.tell() + self._thl
self._tindex[oid] = here
new = DataHeader(oid, self._tid, old, pos, len(version),
len(data))
if version:
# Link to last record for this version:
pv = (self._tvindex.get(version, 0)
or self._vindex.get(version, 0))
if pnv is None:
pnv = old
new.setVersion(version, pnv, pv)
self._tvindex[version] = here
self._toid2tid_delete[oid] = 1
else:
self._toid2tid[oid] = self._tid
self._tfile.write(new.asString())
self._tfile.write(data)
# Check quota
if self._quota is not None and here > self._quota:
raise FileStorageQuotaError(
"The storage quota has been exceeded.")
if old and serial != cached_tid:
return ConflictResolution.ResolvedSerial
else:
return self._tid
finally:
self._lock_release()
def _data_find(self, tpos, oid, data):
# Return backpointer to oid in data record for in transaction at tpos.
# It should contain a pickle identical to data. Returns 0 on failure.
# Must call with lock held.
self._file.seek(tpos)
h = self._file.read(TRANS_HDR_LEN)
tid, tl, status, ul, dl, el = struct.unpack(TRANS_HDR, h)
self._file.read(ul + dl + el)
tend = tpos + tl + 8
pos = self._file.tell()
while pos < tend:
h = self._read_data_header(pos)
if h.oid == oid:
# Make sure this looks like the right data record
if h.plen == 0:
# This is also a backpointer. Gotta trust it.
return pos
if h.plen != len(data):
# The expected data doesn't match what's in the
# backpointer. Something is wrong.
error("Mismatch between data and backpointer at %d", pos)
return 0
_data = self._file.read(h.plen)
if data != _data:
return 0
return pos
pos += h.recordlen()
self._file.seek(pos)
return 0
def restore(self, oid, serial, data, version, prev_txn, transaction):
# A lot like store() but without all the consistency checks. This
# should only be used when we /know/ the data is good, hence the
# method name. While the signature looks like store() there are some
# differences:
#
# - serial is the serial number of /this/ revision, not of the
# previous revision. It is used instead of self._tid, which is
# ignored.
#
# - Nothing is returned
#
# - data can be None, which indicates a George Bailey object
# (i.e. one who's creation has been transactionally undone).
#
# prev_txn is a backpointer. In the original database, it's possible
# that the data was actually living in a previous transaction. This
# can happen for transactional undo and other operations, and is used
# as a space saving optimization. Under some circumstances the
# prev_txn may not actually exist in the target database (i.e. self)
# for example, if it's been packed away. In that case, the prev_txn
# should be considered just a hint, and is ignored if the transaction
# doesn't exist.
if self._is_read_only:
raise POSException.ReadOnlyError()
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
self._lock_acquire()
try:
prev_pos = 0
if prev_txn is not None:
prev_txn_pos = self._txn_find(prev_txn, 0)
if prev_txn_pos:
prev_pos = self._data_find(prev_txn_pos, oid, data)
old = self._index_get(oid, 0)
# Calculate the file position in the temporary file
here = self._pos + self._tfile.tell() + self._thl
# And update the temp file index
self._tindex[oid] = here
if prev_pos:
# If there is a valid prev_pos, don't write data.
data = None
if data is None:
dlen = 0
else:
dlen = len(data)
# Write the recovery data record
new = DataHeader(oid, serial, old, self._pos, len(version), dlen)
if version:
pnv = self._restore_pnv(oid, old, version, prev_pos) or old
vprev = self._tvindex.get(version, 0)
if not vprev:
vprev = self._vindex.get(version, 0)
new.setVersion(version, pnv, vprev)
self._tvindex[version] = here
self._toid2tid_delete[oid] = 1
else:
self._toid2tid[oid] = serial
self._tfile.write(new.asString())
# Finally, write the data or a backpointer.
if data is None:
if prev_pos:
self._tfile.write(p64(prev_pos))
else:
# Write a zero backpointer, which indicates an
# un-creation transaction.
self._tfile.write(z64)
else:
self._tfile.write(data)
finally:
self._lock_release()
def _restore_pnv(self, oid, prev, version, bp):
# Find a valid pnv (previous non-version) pointer for this version.
# If there is no previous record, there can't be a pnv.
if not prev:
return None
# Load the record pointed to be prev
h = self._read_data_header(prev, oid)
if h.version:
return h.pnv
if h.back:
# XXX Not sure the following is always true:
# The previous record is not for this version, yet we
# have a backpointer to it. The current record must
# be an undo of an abort or commit, so the backpointer
# must be to a version record with a pnv.
h2 = self._read_data_header(h.back, oid)
if h2.version:
return h2.pnv
return None
def supportsUndo(self):
return 1
def supportsVersions(self):
return 1
def _clear_temp(self):
self._tindex.clear()
self._tvindex.clear()
self._toid2tid.clear()
self._toid2tid_delete.clear()
if self._tfile is not None:
self._tfile.seek(0)
def _begin(self, tid, u, d, e):
self._nextpos = 0
self._thl = TRANS_HDR_LEN + len(u) + len(d) + len(e)
if self._thl > 65535:
# one of u, d, or e may be > 65535
# We have to check lengths here because struct.pack
# doesn't raise an exception on overflow!
if len(u) > 65535:
raise FileStorageError('user name too long')
if len(d) > 65535:
raise FileStorageError('description too long')
if len(e) > 65535:
raise FileStorageError('too much extension data')
def tpc_vote(self, transaction):
self._lock_acquire()
try:
if transaction is not self._transaction:
return
dlen = self._tfile.tell()
if not dlen:
return # No data in this trans
self._tfile.seek(0)
user, descr, ext = self._ude
self._file.seek(self._pos)
tl = self._thl + dlen
try:
h = TxnHeader(self._tid, tl, "c", len(user),
len(descr), len(ext))
h.user = user
h.descr = descr
h.ext = ext
self._file.write(h.asString())
cp(self._tfile, self._file, dlen)
self._file.write(p64(tl))
self._file.flush()
except:
# Hm, an error occured writing out the data. Maybe the
# disk is full. We don't want any turd at the end.
self._file.truncate(self._pos)
raise
self._nextpos = self._pos + (tl + 8)
finally:
self._lock_release()
# Keep track of the number of records that we've written
_records_written = 0
def _finish(self, tid, u, d, e):
nextpos=self._nextpos
if nextpos:
file=self._file
# Clear the checkpoint flag
file.seek(self._pos+16)
file.write(self._tstatus)
file.flush()
if fsync is not None: fsync(file.fileno())
self._pos = nextpos
self._index.update(self._tindex)
self._vindex.update(self._tvindex)
self._oid2tid.update(self._toid2tid)
for oid in self._toid2tid_delete.keys():
try:
del self._oid2tid[oid]
except KeyError:
pass
# Update the number of records that we've written
# +1 for the transaction record
self._records_written += len(self._tindex) + 1
if self._records_written >= self._records_before_save:
self._save_index()
self._records_written = 0
self._records_before_save = max(self._records_before_save,
len(self._index))
self._ltid = tid
def _abort(self):
if self._nextpos:
self._file.truncate(self._pos)
self._nextpos=0
def supportsTransactionalUndo(self):
return 1
def _undoDataInfo(self, oid, pos, tpos):
"""Return the tid, data pointer, data, and version for the oid
record at pos"""
if tpos:
pos = tpos - self._pos - self._thl
tpos = self._tfile.tell()
h = self._tfmt._read_data_header(pos, oid)
afile = self._tfile
else:
h = self._read_data_header(pos, oid)
afile = self._file
if h.oid != oid:
raise UndoError("Invalid undo transaction id", oid)
if h.plen:
data = afile.read(h.plen)
else:
data = ''
pos = h.back
if tpos:
self._tfile.seek(tpos) # Restore temp file to end
return h.tid, pos, data, h.version
def getTid(self, oid):
self._lock_acquire()
try:
result = self._get_cached_tid(oid)
if result is None:
pos = self._lookup_pos(oid)
result = self._getTid(oid, pos)
return result
finally:
self._lock_release()
def _getTid(self, oid, pos):
self._file.seek(pos)
h = self._file.read(16)
assert oid == h[:8]
return h[8:]
def _getVersion(self, oid, pos):
h = self._read_data_header(pos, oid)
if h.version:
return h.version, h.pnv
else:
return "", None
def _transactionalUndoRecord(self, oid, pos, tid, pre, version):
"""Get the indo information for a data record
Return a 5-tuple consisting of a pickle, data pointer,
version, packed non-version data pointer, and current
position. If the pickle is true, then the data pointer must
be 0, but the pickle can be empty *and* the pointer 0.
"""
copy = 1 # Can we just copy a data pointer
# First check if it is possible to undo this record.
tpos = self._tindex.get(oid, 0)
ipos = self._index.get(oid, 0)
tipos = tpos or ipos
if tipos != pos:
# Eek, a later transaction modified the data, but,
# maybe it is pointing at the same data we are.
ctid, cdataptr, cdata, cver = self._undoDataInfo(oid, ipos, tpos)
# Versions of undone record and current record *must* match!
if cver != version:
raise UndoError('Current and undone versions differ', oid)
if cdataptr != pos:
# We aren't sure if we are talking about the same data
try:
if (
# The current record wrote a new pickle
cdataptr == tipos
or
# Backpointers are different
self._loadBackPOS(oid, pos) !=
self._loadBackPOS(oid, cdataptr)
):
if pre and not tpos:
copy = 0 # we'll try to do conflict resolution
else:
# We bail if:
# - We don't have a previous record, which should
# be impossible.
raise UndoError("no previous record", oid)
except KeyError:
# LoadBack gave us a key error. Bail.
raise UndoError("_loadBack() failed", oid)
# Return the data that should be written in the undo record.
if not pre:
# There is no previous revision, because the object creation
# is being undone.
return "", 0, "", "", ipos
version, snv = self._getVersion(oid, pre)
if copy:
# we can just copy our previous-record pointer forward
return "", pre, version, snv, ipos
try:
bdata = self._loadBack_impl(oid, pre)[0]
except KeyError:
# couldn't find oid; what's the real explanation for this?
raise UndoError("_loadBack() failed for %s", oid)
data = self.tryToResolveConflict(oid, ctid, tid, bdata, cdata)
if data:
return data, 0, version, snv, ipos
raise UndoError("Some data were modified by a later transaction", oid)
# undoLog() returns a description dict that includes an id entry.
# The id is opaque to the client, but contains the transaction id.
# The transactionalUndo() implementation does a simple linear
# search through the file (from the end) to find the transaction.
def undoLog(self, first=0, last=-20, filter=None):
if last < 0:
last = first - last + 1
self._lock_acquire()
try:
if self._packt is None:
raise UndoError(
'Undo is currently disabled for database maintenance.<p>')
us = UndoSearch(self._file, self._pos, self._packt,
first, last, filter)
while not us.finished():
# Hold lock for batches of 20 searches, so default search
# parameters will finish without letting another thread run.
for i in range(20):
if us.finished():
break
us.search()
# Give another thread a chance, so that a long undoLog()
# operation doesn't block all other activity.
self._lock_release()
self._lock_acquire()
return us.results
finally:
self._lock_release()
def transactionalUndo(self, transaction_id, transaction):
"""Undo a transaction, given by transaction_id.
Do so by writing new data that reverses the action taken by
the transaction.
Usually, we can get by with just copying a data pointer, by
writing a file position rather than a pickle. Sometimes, we
may do conflict resolution, in which case we actually copy
new data that results from resolution.
"""
if self._is_read_only:
raise POSException.ReadOnlyError()
if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction)
self._lock_acquire()
try:
return self._txn_undo(transaction_id)
finally:
self._lock_release()
def _txn_undo(self, transaction_id):
# Find the right transaction to undo and call _txn_undo_write().
tid = base64.decodestring(transaction_id + '\n')
assert len(tid) == 8
tpos = self._txn_find(tid, 1)
tindex = self._txn_undo_write(tpos)
self._tindex.update(tindex)
# Arrange to clear the affected oids from the oid2tid cache.
# It's too painful to try to update them to correct current
# values instead.
self._toid2tid_delete.update(tindex)
return self._tid, tindex.keys()
def _txn_find(self, tid, stop_at_pack):
pos = self._pos
while pos > 39:
self._file.seek(pos - 8)
pos = pos - u64(self._file.read(8)) - 8
self._file.seek(pos)
h = self._file.read(TRANS_HDR_LEN)
_tid = h[:8]
if _tid == tid:
return pos
if stop_at_pack:
# check the status field of the transaction header
if h[16] == 'p' or _tid < self._packt:
break
raise UndoError("Invalid transaction id")
def _txn_undo_write(self, tpos):
# a helper function to write the data records for transactional undo
otloc = self._pos
here = self._pos + self._tfile.tell() + self._thl
base = here - self._tfile.tell()
# Let's move the file pointer back to the start of the txn record.
th = self._read_txn_header(tpos)
if th.status != " ":
raise UndoError('non-undoable transaction')
tend = tpos + th.tlen
pos = tpos + th.headerlen()
tindex = {}
# keep track of failures, cause we may succeed later
failures = {}
# Read the data records for this transaction
while pos < tend:
h = self._read_data_header(pos)
if h.oid in failures:
del failures[h.oid] # second chance!
assert base + self._tfile.tell() == here, (here, base,
self._tfile.tell())
try:
p, prev, v, snv, ipos = self._transactionalUndoRecord(
h.oid, pos, h.tid, h.prev, h.version)
except UndoError, v:
# Don't fail right away. We may be redeemed later!
failures[h.oid] = v
else:
new = DataHeader(h.oid, self._tid, ipos, otloc, len(v),
len(p))
if v:
vprev = self._tvindex.get(v, 0) or self._vindex.get(v, 0)
new.setVersion(v, snv, vprev)
self._tvindex[v] = here
# XXX This seek shouldn't be necessary, but some other
# bit of code is messig with the file pointer.
assert self._tfile.tell() == here - base, (here, base,
self._tfile.tell())
self._tfile.write(new.asString())
if p:
self._tfile.write(p)
else:
self._tfile.write(p64(prev))
tindex[h.oid] = here
here += new.recordlen()
pos += h.recordlen()
if pos > tend:
raise UndoError("non-undoable transaction")
if failures:
raise MultipleUndoErrors(failures.items())
return tindex
def versionEmpty(self, version):
if not version:
# The interface is silent on this case. I think that this should
# be an error, but Barry thinks this should return 1 if we have
# any non-version data. This would be excruciatingly painful to
# test, so I must be right. ;)
raise POSException.VersionError(
'The version must be an non-empty string')
self._lock_acquire()
try:
index=self._index
file=self._file
seek=file.seek
read=file.read
srcpos=self._vindex_get(version, 0)
t=tstatus=None
while srcpos:
seek(srcpos)
oid=read(8)
if index[oid]==srcpos: return 0
h=read(50) # serial, prev(oid), tloc, vlen, plen, pnv, pv
tloc=h[16:24]
if t != tloc:
# We haven't checked this transaction before,
# get its status.
t=tloc
seek(u64(t)+16)
tstatus=read(1)
if tstatus != 'u': return 1
spos=h[-8:]
srcpos=u64(spos)
return 1
finally: self._lock_release()
def versions(self, max=None):
r=[]
a=r.append
keys=self._vindex.keys()
if max is not None: keys=keys[:max]
for version in keys:
if self.versionEmpty(version): continue
a(version)
if max and len(r) >= max: return r
return r
def history(self, oid, version=None, size=1, filter=None):
self._lock_acquire()
try:
r = []
pos = self._lookup_pos(oid)
wantver = version
while 1:
if len(r) >= size: return r
h = self._read_data_header(pos)
if h.version:
if wantver is not None and h.version != wantver:
if h.prev:
pos = h.prev
continue
else:
return r
else:
version = ""
wantver = None
th = self._read_txn_header(h.tloc)
user_name = self._file.read(th.ulen)
description = self._file.read(th.dlen)
if th.elen:
d = loads(self._file.read(th.elen))
else:
d = {}
d.update({"time": TimeStamp(h.tid).timeTime(),
"user_name": user_name,
"description": description,
"tid": h.tid,
"version": h.version,
"size": h.plen,
})
if filter is None or filter(d):
r.append(d)
if h.prev:
pos = h.prev
else:
return r
finally:
self._lock_release()
def _redundant_pack(self, file, pos):
assert pos > 8, pos
file.seek(pos - 8)
p = u64(file.read(8))
file.seek(pos - p + 8)
return file.read(1) not in ' u'
def pack(self, t, referencesf):
"""Copy data from the current database file to a packed file
Non-current records from transactions with time-stamp strings less
than packtss are ommitted. As are all undone records.
Also, data back pointers that point before packtss are resolved and
the associated data are copied, since the old records are not copied.
"""
if self._is_read_only:
raise POSException.ReadOnlyError()
stop=`TimeStamp(*time.gmtime(t)[:5]+(t%60,))`
if stop==z64: raise FileStorageError, 'Invalid pack time'
# If the storage is empty, there's nothing to do.
if not self._index:
return
# Record pack time so we don't undo while packing
self._lock_acquire()
try:
if self._packt != z64:
# Already packing.
raise FileStorageError, 'Already packing'
self._packt = None
finally:
self._lock_release()
p = FileStoragePacker(self._file_name, stop,
self._lock_acquire, self._lock_release,
self._commit_lock_acquire,
self._commit_lock_release)
try:
opos = p.pack()
if opos is None:
return
oldpath = self._file_name + ".old"
self._lock_acquire()
try:
self._file.close()
try:
if os.path.exists(oldpath):
os.remove(oldpath)
os.rename(self._file_name, oldpath)
except Exception:
self._file = open(self._file_name, 'r+b')
raise
# OK, we're beyond the point of no return
os.rename(self._file_name + '.pack', self._file_name)
self._file = open(self._file_name, 'r+b')
self._initIndex(p.index, p.vindex, p.tindex, p.tvindex,
p.oid2tid, p.toid2tid,
p.toid2tid_delete)
self._pos = opos
self._save_index()
finally:
self._lock_release()
finally:
if p.locked:
self._commit_lock_release()
self._lock_acquire()
self._packt = z64
self._lock_release()
def iterator(self, start=None, stop=None):
return FileIterator(self._file_name, start, stop)
def lastTransaction(self):
"""Return transaction id for last committed transaction"""
return self._ltid
def lastTid(self, oid):
"""Return last serialno committed for object oid.
If there is no serialno for this oid -- which can only occur
if it is a new object -- return None.
"""
try:
return self.getTid(oid)
except KeyError:
return None
def cleanup(self):
"""Remove all files created by this storage."""
for ext in '', '.old', '.tmp', '.lock', '.index', '.pack':
try:
os.remove(self._file_name + ext)
except OSError, e:
if e.errno != errno.ENOENT:
raise
def shift_transactions_forward(index, vindex, tindex, file, pos, opos):
"""Copy transactions forward in the data file
This might be done as part of a recovery effort
"""
# Cache a bunch of methods
seek=file.seek
read=file.read
write=file.write
index_get=index.get
vindex_get=vindex.get
# Initialize,
pv=z64
p1=opos
p2=pos
offset=p2-p1
# Copy the data in two stages. In the packing stage,
# we skip records that are non-current or that are for
# unreferenced objects. We also skip undone transactions.
#
# After the packing stage, we copy everything but undone
# transactions, however, we have to update various back pointers.
# We have to have the storage lock in the second phase to keep
# data from being changed while we're copying.
pnv=None
while 1:
# Read the transaction record
seek(pos)
h=read(TRANS_HDR_LEN)
if len(h) < TRANS_HDR_LEN: break
tid, stl, status, ul, dl, el = unpack(TRANS_HDR,h)
if status=='c': break # Oops. we found a checkpoint flag.
tl=u64(stl)
tpos=pos
tend=tpos+tl
otpos=opos # start pos of output trans
thl=ul+dl+el
h2=read(thl)
if len(h2) != thl:
raise PackError(opos)
# write out the transaction record
seek(opos)
write(h)
write(h2)
thl=TRANS_HDR_LEN+thl
pos=tpos+thl
opos=otpos+thl
while pos < tend:
# Read the data records for this transaction
seek(pos)
h=read(DATA_HDR_LEN)
oid,serial,sprev,stloc,vlen,splen = unpack(DATA_HDR, h)
plen=u64(splen)
dlen=DATA_HDR_LEN+(plen or 8)
if vlen:
dlen=dlen+(16+vlen)
pnv=u64(read(8))
# skip position of previous version record
seek(8,1)
version=read(vlen)
pv=p64(vindex_get(version, 0))
if status != 'u': vindex[version]=opos
tindex[oid]=opos
if plen: p=read(plen)
else:
p=read(8)
p=u64(p)
if p >= p2: p=p-offset
elif p >= p1:
# Ick, we're in trouble. Let's bail
# to the index and hope for the best
p=index_get(oid, 0)
p=p64(p)
# WRITE
seek(opos)
sprev=p64(index_get(oid, 0))
write(pack(DATA_HDR,
oid,serial,sprev,p64(otpos),vlen,splen))
if vlen:
if not pnv: write(z64)
else:
if pnv >= p2: pnv=pnv-offset
elif pnv >= p1:
pnv=index_get(oid, 0)
write(p64(pnv))
write(pv)
write(version)
write(p)
opos=opos+dlen
pos=pos+dlen
# skip the (intentionally redundant) transaction length
pos=pos+8
if status != 'u':
index.update(tindex) # Record the position
tindex.clear()
write(stl)
opos=opos+8
return opos
def search_back(file, pos):
seek=file.seek
read=file.read
seek(0,2)
s=p=file.tell()
while p > pos:
seek(p-8)
l=u64(read(8))
if l <= 0: break
p=p-l-8
return p, s
def recover(file_name):
file=open(file_name, 'r+b')
index={}
vindex={}
tindex={}
pos, oid, tid = read_index(
file, file_name, index, vindex, tindex, recover=1)
if oid is not None:
print "Nothing to recover"
return
opos=pos
pos, sz = search_back(file, pos)
if pos < sz:
npos = shift_transactions_forward(
index, vindex, tindex, file, pos, opos,
)
file.truncate(npos)
print "Recovered file, lost %s, ended up with %s bytes" % (
pos-opos, npos)
def read_index(file, name, index, vindex, tindex, stop='\377'*8,
ltid=z64, start=4L, maxoid=z64, recover=0, read_only=0):
"""Scan the entire file storage and recreate the index.
Returns file position, max oid, and last transaction id. It also
stores index information in the three dictionary arguments.
Arguments:
file -- a file object (the Data.fs)
name -- the name of the file (presumably file.name)
index -- dictionary, oid -> data record
vindex -- dictionary, oid -> data record for version data
tindex -- dictionary, oid -> data record
XXX tindex is cleared before return, so it will be empty
There are several default arguments that affect the scan or the
return values. XXX should document them.
The file position returned is the position just after the last
valid transaction record. The oid returned is the maximum object
id in the data. The transaction id is the tid of the last
transaction.
"""
read = file.read
seek = file.seek
seek(0, 2)
file_size=file.tell()
fmt = TempFormatter(file)
if file_size:
if file_size < start: raise FileStorageFormatError, file.name
seek(0)
if read(4) != packed_version:
raise FileStorageFormatError, name
else:
if not read_only:
file.write(packed_version)
return 4L, maxoid, ltid
index_get=index.get
pos=start
seek(start)
tid='\0'*7+'\1'
while 1:
# Read the transaction record
h=read(TRANS_HDR_LEN)
if not h: break
if len(h) != TRANS_HDR_LEN:
if not read_only:
warn('%s truncated at %s', name, pos)
seek(pos)
file.truncate()
break
tid, tl, status, ul, dl, el = unpack(TRANS_HDR,h)
if el < 0: el=t32-el
if tid <= ltid:
warn("%s time-stamp reduction at %s", name, pos)
ltid = tid
if pos+(tl+8) > file_size or status=='c':
# Hm, the data were truncated or the checkpoint flag wasn't
# cleared. They may also be corrupted,
# in which case, we don't want to totally lose the data.
if not read_only:
warn("%s truncated, possibly due to damaged records at %s",
name, pos)
_truncate(file, name, pos)
break
if status not in ' up':
warn('%s has invalid status, %s, at %s', name, status, pos)
if tl < (TRANS_HDR_LEN+ul+dl+el):
# We're in trouble. Find out if this is bad data in the
# middle of the file, or just a turd that Win 9x dropped
# at the end when the system crashed.
# Skip to the end and read what should be the transaction length
# of the last transaction.
seek(-8, 2)
rtl=u64(read(8))
# Now check to see if the redundant transaction length is
# reasonable:
if file_size - rtl < pos or rtl < TRANS_HDR_LEN:
nearPanic('%s has invalid transaction header at %s', name, pos)
if not read_only:
warn("It appears that there is invalid data at the end of "
"the file, possibly due to a system crash. %s "
"truncated to recover from bad data at end."
% name)
_truncate(file, name, pos)
break
else:
if recover: return pos, None, None
panic('%s has invalid transaction header at %s', name, pos)
if tid >= stop:
break
tpos=pos
tend=tpos+tl
if status=='u':
# Undone transaction, skip it
seek(tend)
h=u64(read(8))
if h != tl:
if recover: return tpos, None, None
panic('%s has inconsistent transaction length at %s',
name, pos)
pos=tend+8
continue
pos = tpos+ TRANS_HDR_LEN + ul + dl + el
while pos < tend:
# Read the data records for this transaction
h = fmt._read_data_header(pos)
dlen = h.recordlen()
tindex[h.oid] = pos
if h.version:
vindex[h.version] = pos
if pos + dlen > tend or h.tloc != tpos:
if recover:
return tpos, None, None
panic("%s data record exceeds transaction record at %s",
name, pos)
if index_get(h.oid, 0) != h.prev:
if prev:
if recover: return tpos, None, None
error("%s incorrect previous pointer at %s", name, pos)
else:
warn("%s incorrect previous pointer at %s", name, pos)
pos=pos+dlen
if pos != tend:
if recover: return tpos, None, None
panic("%s data records don't add up at %s",name,tpos)
# Read the (intentionally redundant) transaction length
seek(pos)
h = u64(read(8))
if h != tl:
if recover: return tpos, None, None
panic("%s redundant transaction length check failed at %s",
name, pos)
pos=pos+8
if tindex: # avoid the pathological empty transaction case
_maxoid = max(tindex.keys()) # in 2.2, just max(tindex)
maxoid = max(_maxoid, maxoid)
index.update(tindex)
tindex.clear()
return pos, maxoid, ltid
def _truncate(file, name, pos):
file.seek(0, 2)
file_size = file.tell()
try:
i = 0
while 1:
oname='%s.tr%s' % (name, i)
if os.path.exists(oname):
i += 1
else:
warn("Writing truncated data from %s to %s", name, oname)
o = open(oname,'wb')
file.seek(pos)
cp(file, o, file_size-pos)
o.close()
break
except:
error("couldn\'t write truncated data for %s", name,
error=sys.exc_info())
raise POSException.StorageSystemError, (
"Couldn't save truncated data")
file.seek(pos)
file.truncate()
class Iterator:
"""A General simple iterator that uses the Python for-loop index protocol
"""
__index=-1
__current=None
def __getitem__(self, i):
__index=self.__index
while i > __index:
__index=__index+1
self.__current=self.next(__index)
self.__index=__index
return self.__current
class FileIterator(Iterator, FileStorageFormatter):
"""Iterate over the transactions in a FileStorage file.
"""
_ltid = z64
_file = None
def __init__(self, file, start=None, stop=None):
if isinstance(file, str):
file = open(file, 'rb')
self._file = file
if file.read(4) != packed_version:
raise FileStorageFormatError, file.name
file.seek(0,2)
self._file_size = file.tell()
self._pos = 4L
assert start is None or isinstance(start, str)
assert stop is None or isinstance(stop, str)
if start:
self._skip_to_start(start)
self._stop = stop
def __len__(self):
# Define a bogus __len__() to make the iterator work
# with code like builtin list() and tuple() in Python 2.1.
# There's a lot of C code that expects a sequence to have
# an __len__() but can cope with any sort of mistake in its
# implementation. So just return 0.
return 0
# This allows us to pass an iterator as the `other' argument to
# copyTransactionsFrom() in BaseStorage. The advantage here is that we
# can create the iterator manually, e.g. setting start and stop, and then
# just let copyTransactionsFrom() do its thing.
def iterator(self):
return self
def close(self):
file = self._file
if file is not None:
self._file = None
file.close()
def _skip_to_start(self, start):
# Scan through the transaction records doing almost no sanity
# checks.
while 1:
self._file.seek(self._pos)
h = self._file.read(16)
if len(h) < 16:
return
tid, stl = unpack(">8s8s", h)
if tid >= start:
return
tl = u64(stl)
try:
self._pos += tl + 8
except OverflowError:
self._pos = long(self._pos) + tl + 8
if __debug__:
# Sanity check
self._file.seek(self._pos - 8, 0)
rtl = self._file.read(8)
if rtl != stl:
pos = self._file.tell() - 8
panic("%s has inconsistent transaction length at %s "
"(%s != %s)",
self._file.name, pos, u64(rtl), u64(stl))
def next(self, index=0):
if self._file is None:
# A closed iterator. XXX: Is IOError the best we can do? For
# now, mimic a read on a closed file.
raise IOError, 'iterator is closed'
pos = self._pos
while 1:
# Read the transaction record
try:
h = self._read_txn_header(pos)
except CorruptedDataError, err:
# If buf is empty, we've reached EOF.
if not err.buf:
break
raise
if h.tid <= self._ltid:
warn("%s time-stamp reduction at %s", self._file.name, pos)
self._ltid = h.tid
if self._stop is not None and h.tid > self._stop:
raise IndexError, index
if h.status == "c":
# Assume we've hit the last, in-progress transaction
raise IndexError, index
if pos + h.tlen + 8 > self._file_size:
# Hm, the data were truncated or the checkpoint flag wasn't
# cleared. They may also be corrupted,
# in which case, we don't want to totally lose the data.
warn("%s truncated, possibly due to damaged records at %s",
self._file.name, pos)
break
if h.status not in " up":
warn('%s has invalid status, %s, at %s', self._file.name,
h.status, pos)
if h.tlen < h.headerlen():
# We're in trouble. Find out if this is bad data in
# the middle of the file, or just a turd that Win 9x
# dropped at the end when the system crashed. Skip to
# the end and read what should be the transaction
# length of the last transaction.
self._file.seek(-8, 2)
rtl = u64(self._file.read(8))
# Now check to see if the redundant transaction length is
# reasonable:
if self._file_size - rtl < pos or rtl < TRANS_HDR_LEN:
nearPanic("%s has invalid transaction header at %s",
self._file.name, pos)
warn("It appears that there is invalid data at the end of "
"the file, possibly due to a system crash. %s "
"truncated to recover from bad data at end."
% self._file.name)
break
else:
warn("%s has invalid transaction header at %s",
self._file.name, pos)
break
tpos = pos
tend = tpos + h.tlen
if h.status != "u":
pos = tpos + h.headerlen()
user = self._file.read(h.ulen)
description = self._file.read(h.dlen)
e = {}
if h.elen:
try:
e = loads(self._file.read(h.elen))
except:
pass
result = RecordIterator(h.tid, h.status, user, description,
e, pos, tend, self._file, tpos)
# Read the (intentionally redundant) transaction length
self._file.seek(tend)
rtl = u64(self._file.read(8))
if rtl != h.tlen:
warn("%s redundant transaction length check failed at %s",
self._file.name, tend)
break
self._pos = tend + 8
return result
raise IndexError, index
class RecordIterator(Iterator, BaseStorage.TransactionRecord,
FileStorageFormatter):
"""Iterate over the transactions in a FileStorage file."""
def __init__(self, tid, status, user, desc, ext, pos, tend, file, tpos):
self.tid = tid
self.status = status
self.user = user
self.description = desc
self._extension = ext
self._pos = pos
self._tend = tend
self._file = file
self._tpos = tpos
def next(self, index=0):
pos = self._pos
while pos < self._tend:
# Read the data records for this transaction
h = self._read_data_header(pos)
dlen = h.recordlen()
if pos + dlen > self._tend or h.tloc != self._tpos:
warn("%s data record exceeds transaction record at %s",
file.name, pos)
break
self._pos = pos + dlen
prev_txn = None
if h.plen:
data = self._file.read(h.plen)
else:
if h.back == 0:
# If the backpointer is 0, then this transaction
# undoes the object creation. It either aborts
# the version that created the object or undid the
# transaction that created it. Return None
# instead of a pickle to indicate this.
data = None
else:
data, tid = self._loadBackTxn(h.oid, h.back, False)
# XXX looks like this only goes one link back, should
# it go to the original data like BDBFullStorage?
prev_txn = self.getTxnFromData(h.oid, h.back)
r = Record(h.oid, h.tid, h.version, data, prev_txn)
return r
raise IndexError, index
class Record(BaseStorage.DataRecord):
"""An abstract database record."""
def __init__(self, *args):
self.oid, self.tid, self.version, self.data, self.data_txn = args
class UndoSearch:
def __init__(self, file, pos, packt, first, last, filter=None):
self.file = file
self.pos = pos
self.packt = packt
self.first = first
self.last = last
self.filter = filter
self.i = 0
self.results = []
self.stop = 0
def finished(self):
"""Return True if UndoSearch has found enough records."""
# BAW: Why 39 please? This makes no sense (see also below).
return self.i >= self.last or self.pos < 39 or self.stop
def search(self):
"""Search for another record."""
dict = self._readnext()
if dict is not None and (self.filter is None or self.filter(dict)):
if self.i >= self.first:
self.results.append(dict)
self.i += 1
def _readnext(self):
"""Read the next record from the storage."""
self.file.seek(self.pos - 8)
self.pos -= u64(self.file.read(8)) + 8
self.file.seek(self.pos)
h = self.file.read(TRANS_HDR_LEN)
tid, tl, status, ul, dl, el = struct.unpack(TRANS_HDR, h)
if tid < self.packt or status == 'p':
self.stop = 1
return None
if status != ' ':
return None
d = u = ''
if ul:
u = self.file.read(ul)
if dl:
d = self.file.read(dl)
e = {}
if el:
try:
e = loads(self.file.read(el))
except:
pass
d = {'id': base64.encodestring(tid).rstrip(),
'time': TimeStamp(tid).timeTime(),
'user_name': u,
'description': d}
d.update(e)
return d
# this is a package
from ZODB.FileStorage.FileStorage \
import FileStorage, RecordIterator, FileIterator, packed_version
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
#
# File-based ZODB storage
#
# Files are arranged as follows.
#
# - The first 4 bytes are a file identifier.
#
# - The rest of the file consists of a sequence of transaction
# "records".
#
# A transaction record consists of:
#
# - 8-byte transaction id, which is also a time stamp.
#
# - 8-byte transaction record length - 8.
#
# - 1-byte status code
#
# - 2-byte length of user name
#
# - 2-byte length of description
#
# - 2-byte length of extension attributes
#
# - user name
#
# - description
#
# - extension attributes
#
# * A sequence of data records
#
# - 8-byte redundant transaction length -8
#
# A data record consists of
#
# - 8-byte oid.
#
# - 8-byte tid, which matches the transaction id in the transaction record.
#
# - 8-byte previous-record file-position.
#
# - 8-byte beginning of transaction record file position.
#
# - 2-byte version length
#
# - 8-byte data length
#
# ? 8-byte position of non-version data
# (if version length > 0)
#
# ? 8-byte position of previous record in this version
# (if version length > 0)
#
# ? version string
# (if version length > 0)
#
# ? data
# (data length > 0)
#
# ? 8-byte position of data record containing data
# (data length == 0)
#
# Note that the lengths and positions are all big-endian.
# Also, the object ids time stamps are big-endian, so comparisons
# are meaningful.
#
# Version handling
#
# There isn't a separate store for versions. Each record has a
# version field, indicating what version it is in. The records in a
# version form a linked list. Each record that has a non-empty
# version string has a pointer to the previous record in the version.
# Version back pointers are retained *even* when versions are
# committed or aborted or when transactions are undone.
#
# There is a notion of "current" version records, which are the
# records in a version that are the current records for their
# respective objects. When a version is comitted, the current records
# are committed to the destination version. When a version is
# aborted, the current records are aborted.
#
# When committing or aborting, we search backward through the linked
# list until we find a record for an object that does not have a
# current record in the version. If we find a record for which the
# non-version pointer is the same as the previous pointer, then we
# forget that the corresponding object had a current record in the
# version. This strategy allows us to avoid searching backward through
# previously committed or aborted version records.
#
# Of course, we ignore records in undone transactions when committing
# or aborting.
#
# Backpointers
#
# When we commit or abort a version, we don't copy (or delete)
# and data. Instead, we write records with back pointers.
#
# A version record *never* has a back pointer to a non-version
# record, because we never abort to a version. A non-version record
# may have a back pointer to a version record or to a non-version
# record.
import struct
from ZODB.POSException import POSKeyError
from ZODB.referencesf import referencesf
from ZODB.utils import p64, u64, z64, oid_repr, t32
from zLOG import LOG, BLATHER, WARNING, ERROR, PANIC
class CorruptedError(Exception):
pass
class CorruptedDataError(CorruptedError):
def __init__(self, oid=None, buf=None, pos=None):
self.oid = oid
self.buf = buf
self.pos = pos
def __str__(self):
if self.oid:
msg = "Error reading oid %s. Found %r" % (oid_repr(self.oid),
self.buf)
else:
msg = "Error reading unknown oid. Found %r" % self.buf
if self.pos:
msg += " at %d" % self.pos
return msg
# the struct formats for the headers
TRANS_HDR = ">8sQcHHH"
DATA_HDR = ">8s8sQQHQ"
# constants to support various header sizes
TRANS_HDR_LEN = 23
DATA_HDR_LEN = 42
DATA_VERSION_HDR_LEN = 58
assert struct.calcsize(TRANS_HDR) == TRANS_HDR_LEN
assert struct.calcsize(DATA_HDR) == DATA_HDR_LEN
class FileStorageFormatter(object):
"""Mixin class that can read and write the low-level format."""
# subclasses must provide _file
_metadata_size = 4L
_format_version = "21"
def _read_num(self, pos):
"""Read an 8-byte number."""
self._file.seek(pos)
return u64(self._file.read(8))
def _read_data_header(self, pos, oid=None):
"""Return a DataHeader object for data record at pos.
If ois is not None, raise CorruptedDataError if oid passed
does not match oid in file.
If there is version data, reads the version part of the header.
If there is no pickle data, reads the back pointer.
"""
self._file.seek(pos)
s = self._file.read(DATA_HDR_LEN)
if len(s) != DATA_HDR_LEN:
raise CorruptedDataError(oid, s, pos)
h = DataHeaderFromString(s)
if oid is not None and oid != h.oid:
raise CorruptedDataError(oid, s, pos)
if h.vlen:
s = self._file.read(16 + h.vlen)
h.parseVersion(s)
if not h.plen:
h.back = u64(self._file.read(8))
return h
def _write_version_header(self, file, pnv, vprev, version):
s = struct.pack(">8s8s", pnv, vprev)
file.write(s + version)
def _read_txn_header(self, pos, tid=None):
self._file.seek(pos)
s = self._file.read(TRANS_HDR_LEN)
if len(s) != TRANS_HDR_LEN:
raise CorruptedDataError(tid, s, pos)
h = TxnHeaderFromString(s)
if tid is not None and tid != h.tid:
raise CorruptedDataError(tid, s, pos)
h.user = self._file.read(h.ulen)
h.descr = self._file.read(h.dlen)
h.ext = self._file.read(h.elen)
return h
def _loadBack_impl(self, oid, back, fail=True):
# shared implementation used by various _loadBack methods
#
# If the backpointer ultimately resolves to 0:
# If fail is True, raise KeyError for zero backpointer.
# If fail is False, return the empty data from the record
# with no backpointer.
while 1:
if not back:
# If backpointer is 0, object does not currently exist.
raise POSKeyError(oid)
h = self._read_data_header(back)
if h.plen:
return self._file.read(h.plen), h.tid, back, h.tloc
if h.back == 0 and not fail:
return None, h.tid, back, h.tloc
back = h.back
def _loadBackTxn(self, oid, back, fail=True):
"""Return data and txn id for backpointer."""
return self._loadBack_impl(oid, back, fail)[:2]
def _loadBackPOS(self, oid, back):
return self._loadBack_impl(oid, back)[2]
def getTxnFromData(self, oid, back):
"""Return transaction id for data at back."""
h = self._read_data_header(back, oid)
return h.tid
def fail(self, pos, msg, *args):
s = ("%s:%s:" + msg) % ((self._name, pos) + args)
LOG("FS pack", ERROR, s)
raise CorruptedError(s)
def checkTxn(self, th, pos):
if th.tid <= self.ltid:
self.fail(pos, "time-stamp reduction: %s <= %s",
oid_repr(th.tid), oid_repr(self.ltid))
self.ltid = th.tid
if th.status == "c":
self.fail(pos, "transaction with checkpoint flag set")
if not th.status in " pu": # recognize " ", "p", and "u" as valid
self.fail(pos, "invalid transaction status: %r", th.status)
if th.tlen < th.headerlen():
self.fail(pos, "invalid transaction header: "
"txnlen (%d) < headerlen(%d)", th.tlen, th.headerlen())
def checkData(self, th, tpos, dh, pos):
if dh.tloc != tpos:
self.fail(pos, "data record does not point to transaction header"
": %d != %d", dh.tloc, tpos)
if pos + dh.recordlen() > tpos + th.tlen:
self.fail(pos, "data record size exceeds transaction size: "
"%d > %d", pos + dh.recordlen(), tpos + th.tlen)
if dh.prev >= pos:
self.fail(pos, "invalid previous pointer: %d", dh.prev)
if dh.back:
if dh.back >= pos:
self.fail(pos, "invalid back pointer: %d", dh.prev)
if dh.plen:
self.fail(pos, "data record has back pointer and data")
def DataHeaderFromString(s):
return DataHeader(*struct.unpack(DATA_HDR, s))
class DataHeader(object):
"""Header for a data record."""
__slots__ = (
"oid", "tid", "prev", "tloc", "vlen", "plen", "back",
# These three attributes are only defined when vlen > 0
"pnv", "vprev", "version")
def __init__(self, oid, tid, prev, tloc, vlen, plen):
self.back = 0 # default
self.version = "" # default
self.oid = oid
self.tid = tid
self.prev = prev
self.tloc = tloc
self.vlen = vlen
self.plen = plen
def asString(self):
s = struct.pack(DATA_HDR, self.oid, self.tid, self.prev,
self.tloc, self.vlen, self.plen)
if self.version:
v = struct.pack(">QQ", self.pnv, self.vprev)
return s + v + self.version
else:
return s
def setVersion(self, version, pnv, vprev):
self.version = version
self.vlen = len(version)
self.pnv = pnv
self.vprev = vprev
def parseVersion(self, buf):
pnv, vprev = struct.unpack(">QQ", buf[:16])
self.pnv = pnv
self.vprev = vprev
self.version = buf[16:]
def recordlen(self):
rlen = DATA_HDR_LEN + (self.plen or 8)
if self.version:
rlen += 16 + self.vlen
return rlen
def TxnHeaderFromString(s):
return TxnHeader(*struct.unpack(TRANS_HDR, s))
class TxnHeader(object):
"""Header for a transaction record."""
__slots__ = ("tid", "tlen", "status", "user", "descr", "ext",
"ulen", "dlen", "elen")
def __init__(self, tid, tlen, status, ulen, dlen, elen):
self.tid = tid
self.tlen = tlen
self.status = status
self.ulen = ulen
self.dlen = dlen
self.elen = elen
if elen < 0:
self.elen = t32 - elen
def asString(self):
s = struct.pack(TRANS_HDR, self.tid, self.tlen, self.status,
self.ulen, self.dlen, self.elen)
return "".join([s, self.user, self.descr, self.ext])
def headerlen(self):
return TRANS_HDR_LEN + self.ulen + self.dlen + self.elen
from ZODB.FileStorage import FileIterator
from ZODB.FileStorage.format \
import TRANS_HDR, TRANS_HDR_LEN, DATA_HDR, DATA_HDR_LEN
from ZODB.TimeStamp import TimeStamp
from ZODB.utils import u64
from ZODB.tests.StorageTestBase import zodb_unpickle
from cPickle import Unpickler
from cStringIO import StringIO
import md5
import struct
import types
def get_pickle_metadata(data):
# ZODB's data records contain two pickles. The first is the class
# of the object, the second is the object.
if data.startswith('(c'):
# Don't actually unpickle a class, because it will attempt to
# load the class. Just break open the pickle and get the
# module and class from it.
modname, classname, rest = data.split('\n', 2)
modname = modname[2:]
return modname, classname
f = StringIO(data)
u = Unpickler(f)
try:
class_info = u.load()
except Exception, err:
print "Error", err
return '', ''
if isinstance(class_info, types.TupleType):
if isinstance(class_info[0], types.TupleType):
modname, classname = class_info[0]
else:
modname, classname = class_info
else:
# XXX not sure what to do here
modname = repr(class_info)
classname = ''
return modname, classname
def fsdump(path, file=None, with_offset=1):
i = 0
iter = FileIterator(path)
for trans in iter:
if with_offset:
print >> file, "Trans #%05d tid=%016x time=%s offset=%d" % \
(i, u64(trans.tid), str(TimeStamp(trans.tid)), trans._pos)
else:
print >> file, "Trans #%05d tid=%016x time=%s" % \
(i, u64(trans.tid), str(TimeStamp(trans.tid)))
print >> file, "\tstatus=%s user=%s description=%s" % \
(`trans.status`, trans.user, trans.description)
j = 0
for rec in trans:
if rec.data is None:
fullclass = "undo or abort of object creation"
else:
modname, classname = get_pickle_metadata(rec.data)
dig = md5.new(rec.data).hexdigest()
fullclass = "%s.%s" % (modname, classname)
# special case for testing purposes
if fullclass == "ZODB.tests.MinPO.MinPO":
obj = zodb_unpickle(rec.data)
fullclass = "%s %s" % (fullclass, obj.value)
if rec.version:
version = "version=%s " % rec.version
else:
version = ''
if rec.data_txn:
# XXX It would be nice to print the transaction number
# (i) but it would be too expensive to keep track of.
bp = "bp=%016x" % u64(rec.data_txn)
else:
bp = ""
print >> file, " data #%05d oid=%016x %sclass=%s %s" % \
(j, u64(rec.oid), version, fullclass, bp)
j += 1
print >> file
i += 1
iter.close()
def fmt(p64):
# Return a nicely formatted string for a packaged 64-bit value
return "%016x" % u64(p64)
class Dumper:
"""A very verbose dumper for debuggin FileStorage problems."""
# XXX Should revise this class to use FileStorageFormatter.
def __init__(self, path, dest=None):
self.file = open(path, "rb")
self.dest = dest
def dump(self):
fid = self.file.read(4)
print >> self.dest, "*" * 60
print >> self.dest, "file identifier: %r" % fid
while self.dump_txn():
pass
def dump_txn(self):
pos = self.file.tell()
h = self.file.read(TRANS_HDR_LEN)
if not h:
return False
tid, tlen, status, ul, dl, el = struct.unpack(TRANS_HDR, h)
end = pos + tlen
print >> self.dest, "=" * 60
print >> self.dest, "offset: %d" % pos
print >> self.dest, "end pos: %d" % end
print >> self.dest, "transaction id: %s" % fmt(tid)
print >> self.dest, "trec len: %d" % tlen
print >> self.dest, "status: %r" % status
user = descr = extra = ""
if ul:
user = self.file.read(ul)
if dl:
descr = self.file.read(dl)
if el:
extra = self.file.read(el)
print >> self.dest, "user: %r" % user
print >> self.dest, "description: %r" % descr
print >> self.dest, "len(extra): %d" % el
while self.file.tell() < end:
self.dump_data(pos)
stlen = self.file.read(8)
print >> self.dest, "redundant trec len: %d" % u64(stlen)
return 1
def dump_data(self, tloc):
pos = self.file.tell()
h = self.file.read(DATA_HDR_LEN)
assert len(h) == DATA_HDR_LEN
oid, revid, prev, tloc, vlen, dlen = struct.unpack(DATA_HDR, h)
print >> self.dest, "-" * 60
print >> self.dest, "offset: %d" % pos
print >> self.dest, "oid: %s" % fmt(oid)
print >> self.dest, "revid: %s" % fmt(revid)
print >> self.dest, "previous record offset: %d" % prev
print >> self.dest, "transaction offset: %d" % tloc
if vlen:
pnv = self.file.read(8)
sprevdata = self.file.read(8)
version = self.file.read(vlen)
print >> self.dest, "version: %r" % version
print >> self.dest, "non-version data offset: %d" % u64(pnv)
print >> self.dest, \
"previous version data offset: %d" % u64(sprevdata)
print >> self.dest, "len(data): %d" % dlen
self.file.read(dlen)
if not dlen:
sbp = self.file.read(8)
print >> self.dest, "backpointer: %d" % u64(sbp)
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""FileStorage helper to perform pack.
A storage contains an ordered set of object revisions. When a storage
is packed, object revisions that are not reachable as of the pack time
are deleted. The notion of reachability is complicated by
backpointers -- object revisions that point to earlier revisions of
the same object.
An object revisions is reachable at a certain time if it is reachable
from the revision of the root at that time or if it is reachable from
a backpointer after that time.
"""
# This module contains code backported from ZODB4 from the
# zodb.storage.file package. It's been edited heavily to work with
# ZODB3 code and storage layout.
import os
import struct
from types import StringType
from ZODB.referencesf import referencesf
from ZODB.utils import p64, u64, z64, oid_repr
from zLOG import LOG, BLATHER, WARNING, ERROR, PANIC
from ZODB.fsIndex import fsIndex
from ZODB.FileStorage.format \
import FileStorageFormatter, CorruptedDataError, DataHeader, \
TRANS_HDR_LEN
class DataCopier(FileStorageFormatter):
"""Mixin class for copying transactions into a storage.
The restore() and pack() methods share a need to copy data records
and update pointers to data in earlier transaction records. This
class provides the shared logic.
The mixin extends the FileStorageFormatter with a copy() method.
It also requires that the concrete class provides the following
attributes:
_file -- file with earlier destination data
_tfile -- destination file for copied data
_packt -- p64() representation of latest pack time
_pos -- file pos of destination transaction
_tindex -- maps oid to data record file pos
_tvindex -- maps version name to data record file pos
_tindex and _tvindex are updated by copy().
The copy() method does not do any locking.
"""
def _txn_find(self, tid, stop_at_pack):
# _pos always points just past the last transaction
pos = self._pos
while pos > 4:
self._file.seek(pos - 8)
pos = pos - u64(self._file.read(8)) - 8
self._file.seek(pos)
h = self._file.read(TRANS_HDR_LEN)
_tid = h[:8]
if _tid == tid:
return pos
if stop_at_pack:
if h[16] == 'p':
break
raise UndoError(None, "Invalid transaction id")
def _data_find(self, tpos, oid, data):
# Return backpointer to oid in data record for in transaction at tpos.
# It should contain a pickle identical to data. Returns 0 on failure.
# Must call with lock held.
h = self._read_txn_header(tpos)
tend = tpos + h.tlen
pos = self._file.tell()
while pos < tend:
h = self._read_data_header(pos)
if h.oid == oid:
# Make sure this looks like the right data record
if h.plen == 0:
# This is also a backpointer. Gotta trust it.
return pos
if h.plen != len(data):
# The expected data doesn't match what's in the
# backpointer. Something is wrong.
error("Mismatch between data and backpointer at %d", pos)
return 0
_data = self._file.read(h.plen)
if data != _data:
return 0
return pos
pos += h.recordlen()
return 0
def _restore_pnv(self, oid, prev, version, bp):
# Find a valid pnv (previous non-version) pointer for this version.
# If there is no previous record, there can't be a pnv.
if not prev:
return None
pnv = None
h = self._read_data_header(prev, oid)
# If the previous record is for a version, it must have
# a valid pnv.
if h.version:
return h.pnv
elif bp:
# XXX Not sure the following is always true:
# The previous record is not for this version, yet we
# have a backpointer to it. The current record must
# be an undo of an abort or commit, so the backpointer
# must be to a version record with a pnv.
h2 = self._read_data_header(bp, oid)
if h2.version:
return h2.pnv
else:
warn("restore could not find previous non-version data "
"at %d or %d", prev, bp)
return None
def _resolve_backpointer(self, prev_txn, oid, data):
prev_pos = 0
if prev_txn is not None:
prev_txn_pos = self._txn_find(prev_txn, 0)
if prev_txn_pos:
prev_pos = self._data_find(prev_txn_pos, oid, data)
return prev_pos
def copy(self, oid, serial, data, version, prev_txn,
txnpos, datapos):
prev_pos = self._resolve_backpointer(prev_txn, oid, data)
old = self._index.get(oid, 0)
# Calculate the pos the record will have in the storage.
here = datapos
# And update the temp file index
self._tindex[oid] = here
if prev_pos:
# If there is a valid prev_pos, don't write data.
data = None
if data is None:
dlen = 0
else:
dlen = len(data)
# Write the recovery data record
h = DataHeader(oid, serial, old, txnpos, len(version), dlen)
if version:
h.version = version
pnv = self._restore_pnv(oid, old, version, prev_pos)
if pnv is not None:
h.pnv = pnv
else:
h.pnv = old
# Link to the last record for this version
h.vprev = self._tvindex.get(version, 0)
if not h.vprev:
h.vprev = self._vindex.get(version, 0)
self._tvindex[version] = here
self._tfile.write(h.asString())
# Write the data or a backpointer
if data is None:
if prev_pos:
self._tfile.write(p64(prev_pos))
else:
# Write a zero backpointer, which indicates an
# un-creation transaction.
self._tfile.write(z64)
else:
self._tfile.write(data)
class GC(FileStorageFormatter):
def __init__(self, file, eof, packtime):
self._file = file
self._name = file.name
self.eof = eof
self.packtime = packtime
# packpos: position of first txn header after pack time
self.packpos = None
self.oid2curpos = fsIndex() # maps oid to current data record position
self.oid2verpos = fsIndex() # maps oid to current version data
# The set of reachable revisions of each object.
#
# This set as managed using two data structures. The first is
# an fsIndex mapping oids to one data record pos. Since only
# a few objects will have more than one revision, we use this
# efficient data structure to handle the common case. The
# second is a dictionary mapping objects to lists of
# positions; it is used to handle the same number of objects
# for which we must keep multiple revisions.
self.reachable = fsIndex()
self.reach_ex = {}
# keep ltid for consistency checks during initial scan
self.ltid = z64
def isReachable(self, oid, pos):
"""Return 1 if revision of `oid` at `pos` is reachable."""
rpos = self.reachable.get(oid)
if rpos is None:
return 0
if rpos == pos:
return 1
return pos in self.reach_ex.get(oid, [])
def findReachable(self):
self.buildPackIndex()
self.findReachableAtPacktime([z64])
self.findReachableFromFuture()
# These mappings are no longer needed and may consume a lot
# of space.
del self.oid2verpos
del self.oid2curpos
def buildPackIndex(self):
pos = 4L
while pos < self.eof:
th = self._read_txn_header(pos)
if th.tid > self.packtime:
break
self.checkTxn(th, pos)
tpos = pos
end = pos + th.tlen
pos += th.headerlen()
while pos < end:
dh = self._read_data_header(pos)
self.checkData(th, tpos, dh, pos)
if dh.version:
self.oid2verpos[dh.oid] = pos
else:
self.oid2curpos[dh.oid] = pos
pos += dh.recordlen()
tlen = self._read_num(pos)
if tlen != th.tlen:
self.fail(pos, "redundant transaction length does not "
"match initial transaction length: %d != %d",
u64(s), th.tlen)
pos += 8
self.packpos = pos
def findReachableAtPacktime(self, roots):
"""Mark all objects reachable from the oids in roots as reachable."""
todo = list(roots)
while todo:
oid = todo.pop()
if self.reachable.has_key(oid):
continue
L = []
pos = self.oid2curpos.get(oid)
if pos is not None:
L.append(pos)
todo.extend(self.findrefs(pos))
pos = self.oid2verpos.get(oid)
if pos is not None:
L.append(pos)
todo.extend(self.findrefs(pos))
if not L:
continue
pos = L.pop()
self.reachable[oid] = pos
if L:
self.reach_ex[oid] = L
def findReachableFromFuture(self):
# In this pass, the roots are positions of object revisions.
# We add a pos to extra_roots when there is a backpointer to a
# revision that was not current at the packtime. The
# non-current revision could refer to objects that were
# otherwise unreachable at the packtime.
extra_roots = []
pos = self.packpos
while pos < self.eof:
th = self._read_txn_header(pos)
self.checkTxn(th, pos)
tpos = pos
end = pos + th.tlen
pos += th.headerlen()
while pos < end:
dh = self._read_data_header(pos)
self.checkData(th, tpos, dh, pos)
if dh.back and dh.back < self.packpos:
if self.reachable.has_key(dh.oid):
L = self.reach_ex.setdefault(dh.oid, [])
if dh.back not in L:
L.append(dh.back)
extra_roots.append(dh.back)
else:
self.reachable[dh.oid] = dh.back
if dh.version and dh.pnv:
if self.reachable.has_key(dh.oid):
L = self.reach_ex.setdefault(dh.oid, [])
if dh.pnv not in L:
L.append(dh.pnv)
extra_roots.append(dh.pnv)
else:
self.reachable[dh.oid] = dh.back
pos += dh.recordlen()
tlen = self._read_num(pos)
if tlen != th.tlen:
self.fail(pos, "redundant transaction length does not "
"match initial transaction length: %d != %d",
u64(s), th.tlen)
pos += 8
for pos in extra_roots:
refs = self.findrefs(pos)
self.findReachableAtPacktime(refs)
def findrefs(self, pos):
"""Return a list of oids referenced as of packtime."""
dh = self._read_data_header(pos)
# Chase backpointers until we get to the record with the refs
while dh.back:
dh = self._read_data_header(dh.back)
if dh.plen:
return referencesf(self._file.read(dh.plen))
else:
return []
class PackCopier(DataCopier):
# PackCopier has to cope with _file and _tfile being the
# same file. The copy() implementation is written assuming
# that they are different, so that using one object doesn't
# mess up the file pointer for the other object.
# PackCopier overrides _resolve_backpointer() and _restore_pnv()
# to guarantee that they keep the file pointer for _tfile in
# the right place.
def __init__(self, f, index, vindex, tindex, tvindex):
self._file = f
self._tfile = f
self._index = index
self._vindex = vindex
self._tindex = tindex
self._tvindex = tvindex
self._pos = None
def setTxnPos(self, pos):
self._pos = pos
def _resolve_backpointer(self, prev_txn, oid, data):
pos = self._tfile.tell()
try:
return DataCopier._resolve_backpointer(self, prev_txn, oid, data)
finally:
self._tfile.seek(pos)
def _restore_pnv(self, oid, prev, version, bp):
pos = self._tfile.tell()
try:
return DataCopier._restore_pnv(self, oid, prev, version, bp)
finally:
self._tfile.seek(pos)
class FileStoragePacker(FileStorageFormatter):
def __init__(self, path, stop, la, lr, cla, clr):
self._name = path
self._file = open(path, "rb")
self._stop = stop
self._packt = None
self.locked = 0
self._file.seek(0, 2)
self.file_end = self._file.tell()
self._file.seek(0)
self.gc = GC(self._file, self.file_end, self._stop)
# The packer needs to acquire the parent's commit lock
# during the copying stage, so the two sets of lock acquire
# and release methods are passed to the constructor.
self._lock_acquire = la
self._lock_release = lr
self._commit_lock_acquire = cla
self._commit_lock_release = clr
# The packer will use several indexes.
# index: oid -> pos
# vindex: version -> pos of XXX
# tindex: oid -> pos, for current txn
# tvindex: version -> pos of XXX, for current txn
# oid2tid: not used by the packer
self.index = fsIndex()
self.vindex = {}
self.tindex = {}
self.tvindex = {}
self.oid2tid = {}
self.toid2tid = {}
self.toid2tid_delete = {}
# Index for non-version data. This is a temporary structure
# to reduce I/O during packing
self.nvindex = fsIndex()
def pack(self):
# Pack copies all data reachable at the pack time or later.
#
# Copying occurs in two phases. In the first phase, txns
# before the pack time are copied if the contain any reachable
# data. In the second phase, all txns after the pack time
# are copied.
#
# Txn and data records contain pointers to previous records.
# Because these pointers are stored as file offsets, they
# must be updated when we copy data.
# XXX Need to add sanity checking to pack
self.gc.findReachable()
# Setup the destination file and copy the metadata.
# XXX rename from _tfile to something clearer
self._tfile = open(self._name + ".pack", "w+b")
self._file.seek(0)
self._tfile.write(self._file.read(self._metadata_size))
self._copier = PackCopier(self._tfile, self.index, self.vindex,
self.tindex, self.tvindex)
ipos, opos = self.copyToPacktime()
assert ipos == self.gc.packpos
if ipos == opos:
# pack didn't free any data. there's no point in continuing.
self._tfile.close()
os.remove(self._name + ".pack")
return None
self._commit_lock_acquire()
self.locked = 1
self._lock_acquire()
try:
self._file.seek(0, 2)
self.file_end = self._file.tell()
finally:
self._lock_release()
if ipos < self.file_end:
self.copyRest(ipos)
# OK, we've copied everything. Now we need to wrap things up.
pos = self._tfile.tell()
self._tfile.flush()
self._tfile.close()
self._file.close()
return pos
def copyToPacktime(self):
offset = 0L # the amount of space freed by packing
pos = self._metadata_size
new_pos = pos
while pos < self.gc.packpos:
th = self._read_txn_header(pos)
new_tpos, pos = self.copyDataRecords(pos, th)
if new_tpos:
new_pos = self._tfile.tell() + 8
tlen = new_pos - new_tpos - 8
# Update the transaction length
self._tfile.seek(new_tpos + 8)
self._tfile.write(p64(tlen))
self._tfile.seek(new_pos - 8)
self._tfile.write(p64(tlen))
tlen = self._read_num(pos)
if tlen != th.tlen:
self.fail(pos, "redundant transaction length does not "
"match initial transaction length: %d != %d",
u64(s), th.tlen)
pos += 8
return pos, new_pos
def fetchBackpointer(self, oid, back):
"""Return data and refs backpointer `back` to object `oid.
If `back` is 0 or ultimately resolves to 0, return None
and None. In this case, the transaction undoes the object
creation.
"""
if back == 0:
return None
data, tid = self._loadBackTxn(oid, back, 0)
return data
def copyDataRecords(self, pos, th):
"""Copy any current data records between pos and tend.
Returns position of txn header in output file and position
of next record in the input file.
If any data records are copied, also write txn header (th).
"""
copy = 0
new_tpos = 0L
tend = pos + th.tlen
pos += th.headerlen()
while pos < tend:
h = self._read_data_header(pos)
if not self.gc.isReachable(h.oid, pos):
pos += h.recordlen()
continue
pos += h.recordlen()
# If we are going to copy any data, we need to copy
# the transaction header. Note that we will need to
# patch up the transaction length when we are done.
if not copy:
th.status = "p"
s = th.asString()
new_tpos = self._tfile.tell()
self._tfile.write(s)
new_pos = new_tpos + len(s)
copy = 1
if h.plen:
data = self._file.read(h.plen)
else:
# If a current record has a backpointer, fetch
# refs and data from the backpointer. We need
# to write the data in the new record.
data = self.fetchBackpointer(h.oid, h.back)
self.writePackedDataRecord(h, data, new_tpos)
new_pos = self._tfile.tell()
return new_tpos, pos
def writePackedDataRecord(self, h, data, new_tpos):
# Update the header to reflect current information, then write
# it to the output file.
if data is None:
data = ""
h.prev = 0
h.back = 0
h.plen = len(data)
h.tloc = new_tpos
pos = self._tfile.tell()
if h.version:
h.pnv = self.index.get(h.oid, 0)
h.vprev = self.vindex.get(h.version, 0)
self.vindex[h.version] = pos
self.index[h.oid] = pos
if h.version:
self.vindex[h.version] = pos
self._tfile.write(h.asString())
self._tfile.write(data)
if not data:
# Packed records never have backpointers (?).
# If there is no data, write a z64 backpointer.
# This is a George Bailey event.
self._tfile.write(z64)
def copyRest(self, ipos):
# After the pack time, all data records are copied.
# Copy one txn at a time, using copy() for data.
# Release the commit lock every 20 copies
self._lock_counter = 0
try:
while 1:
ipos = self.copyOne(ipos)
except CorruptedDataError, err:
# The last call to copyOne() will raise
# CorruptedDataError, because it will attempt to read past
# the end of the file. Double-check that the exception
# occurred for this reason.
self._file.seek(0, 2)
endpos = self._file.tell()
if endpos != err.pos:
raise
def copyOne(self, ipos):
# The call below will raise CorruptedDataError at EOF.
th = self._read_txn_header(ipos)
self._lock_counter += 1
if self._lock_counter % 20 == 0:
self._commit_lock_release()
pos = self._tfile.tell()
self._copier.setTxnPos(pos)
self._tfile.write(th.asString())
tend = ipos + th.tlen
ipos += th.headerlen()
while ipos < tend:
h = self._read_data_header(ipos)
ipos += h.recordlen()
prev_txn = None
if h.plen:
data = self._file.read(h.plen)
else:
data = self.fetchBackpointer(h.oid, h.back)
if h.back:
prev_txn = self.getTxnFromData(h.oid, h.back)
self._copier.copy(h.oid, h.tid, data, h.version,
prev_txn, pos, self._tfile.tell())
tlen = self._tfile.tell() - pos
assert tlen == th.tlen
self._tfile.write(p64(tlen))
ipos += 8
self.index.update(self.tindex)
self.tindex.clear()
self.vindex.update(self.tvindex)
self.tvindex.clear()
if self._lock_counter % 20 == 0:
self._commit_lock_acquire()
return ipos
...@@ -21,7 +21,7 @@ It is meant to illustrate the simplest possible storage. ...@@ -21,7 +21,7 @@ It is meant to illustrate the simplest possible storage.
The Mapping storage uses a single data structure to map object ids to data. The Mapping storage uses a single data structure to map object ids to data.
""" """
__version__='$Revision: 1.10 $'[11:-2] __version__='$Revision: 1.11 $'[11:-2]
from ZODB import utils from ZODB import utils
from ZODB import BaseStorage from ZODB import BaseStorage
...@@ -58,6 +58,16 @@ class MappingStorage(BaseStorage.BaseStorage): ...@@ -58,6 +58,16 @@ class MappingStorage(BaseStorage.BaseStorage):
finally: finally:
self._lock_release() self._lock_release()
def loadEx(self, oid, version):
self._lock_acquire()
try:
# Since this storage doesn't support versions, tid and
# serial will always be the same.
p = self._index[oid]
return p[8:], p[:8], "" # pickle, serial, tid
finally:
self._lock_release()
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
if transaction is not self._transaction: if transaction is not self._transaction:
raise POSException.StorageTransactionError(self, transaction) raise POSException.StorageTransactionError(self, transaction)
...@@ -75,11 +85,10 @@ class MappingStorage(BaseStorage.BaseStorage): ...@@ -75,11 +85,10 @@ class MappingStorage(BaseStorage.BaseStorage):
serials=(oserial, serial), serials=(oserial, serial),
data=data) data=data)
serial = self._serial self._tindex.append((oid, self._tid + data))
self._tindex.append((oid, serial+data))
finally: finally:
self._lock_release() self._lock_release()
return serial return self._tid
def _clear_temp(self): def _clear_temp(self):
self._tindex = [] self._tindex = []
...@@ -87,7 +96,7 @@ class MappingStorage(BaseStorage.BaseStorage): ...@@ -87,7 +96,7 @@ class MappingStorage(BaseStorage.BaseStorage):
def _finish(self, tid, user, desc, ext): def _finish(self, tid, user, desc, ext):
for oid, p in self._tindex: for oid, p in self._tindex:
self._index[oid] = p self._index[oid] = p
self._ltid = self._serial self._ltid = self._tid
def lastTransaction(self): def lastTransaction(self):
return self._ltid return self._ltid
...@@ -95,6 +104,8 @@ class MappingStorage(BaseStorage.BaseStorage): ...@@ -95,6 +104,8 @@ class MappingStorage(BaseStorage.BaseStorage):
def pack(self, t, referencesf): def pack(self, t, referencesf):
self._lock_acquire() self._lock_acquire()
try: try:
if not self._index:
return
# Build an index of *only* those objects reachable from the root. # Build an index of *only* those objects reachable from the root.
rootl = ['\0\0\0\0\0\0\0\0'] rootl = ['\0\0\0\0\0\0\0\0']
pindex = {} pindex = {}
......
...@@ -85,7 +85,6 @@ class BasicStorage: ...@@ -85,7 +85,6 @@ class BasicStorage:
eq(value, MinPO(11)) eq(value, MinPO(11))
eq(revid, newrevid) eq(revid, newrevid)
## def checkNonVersionStore(self, oid=None, revid=None, version=None):
def checkNonVersionStore(self): def checkNonVersionStore(self):
revid = ZERO revid = ZERO
newrevid = self._dostore(revid=None) newrevid = self._dostore(revid=None)
......
...@@ -20,7 +20,7 @@ import tempfile ...@@ -20,7 +20,7 @@ import tempfile
import unittest import unittest
import ZODB, ZODB.FileStorage import ZODB, ZODB.FileStorage
from StorageTestBase import StorageTestBase, removefs from StorageTestBase import StorageTestBase
class FileStorageCorruptTests(StorageTestBase): class FileStorageCorruptTests(StorageTestBase):
...@@ -30,7 +30,7 @@ class FileStorageCorruptTests(StorageTestBase): ...@@ -30,7 +30,7 @@ class FileStorageCorruptTests(StorageTestBase):
def tearDown(self): def tearDown(self):
self._storage.close() self._storage.close()
removefs(self.path) self._storage.cleanup()
def _do_stores(self): def _do_stores(self):
oids = [] oids = []
......
...@@ -36,40 +36,40 @@ class HistoryStorage: ...@@ -36,40 +36,40 @@ class HistoryStorage:
h = self._storage.history(oid, size=1) h = self._storage.history(oid, size=1)
eq(len(h), 1) eq(len(h), 1)
d = h[0] d = h[0]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
# Try to get 2 historical revisions # Try to get 2 historical revisions
h = self._storage.history(oid, size=2) h = self._storage.history(oid, size=2)
eq(len(h), 2) eq(len(h), 2)
d = h[0] d = h[0]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
d = h[1] d = h[1]
eq(d['serial'], revid2) eq(d['tid'], revid2)
eq(d['version'], '') eq(d['version'], '')
# Try to get all 3 historical revisions # Try to get all 3 historical revisions
h = self._storage.history(oid, size=3) h = self._storage.history(oid, size=3)
eq(len(h), 3) eq(len(h), 3)
d = h[0] d = h[0]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
d = h[1] d = h[1]
eq(d['serial'], revid2) eq(d['tid'], revid2)
eq(d['version'], '') eq(d['version'], '')
d = h[2] d = h[2]
eq(d['serial'], revid1) eq(d['tid'], revid1)
eq(d['version'], '') eq(d['version'], '')
# There should be no more than 3 revisions # There should be no more than 3 revisions
h = self._storage.history(oid, size=4) h = self._storage.history(oid, size=4)
eq(len(h), 3) eq(len(h), 3)
d = h[0] d = h[0]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
d = h[1] d = h[1]
eq(d['serial'], revid2) eq(d['tid'], revid2)
eq(d['version'], '') eq(d['version'], '')
d = h[2] d = h[2]
eq(d['serial'], revid1) eq(d['tid'], revid1)
eq(d['version'], '') eq(d['version'], '')
def checkVersionHistory(self): def checkVersionHistory(self):
...@@ -94,22 +94,22 @@ class HistoryStorage: ...@@ -94,22 +94,22 @@ class HistoryStorage:
h = self._storage.history(oid, version, 100) h = self._storage.history(oid, version, 100)
eq(len(h), 6) eq(len(h), 6)
d = h[0] d = h[0]
eq(d['serial'], revid6) eq(d['tid'], revid6)
eq(d['version'], version) eq(d['version'], version)
d = h[1] d = h[1]
eq(d['serial'], revid5) eq(d['tid'], revid5)
eq(d['version'], version) eq(d['version'], version)
d = h[2] d = h[2]
eq(d['serial'], revid4) eq(d['tid'], revid4)
eq(d['version'], version) eq(d['version'], version)
d = h[3] d = h[3]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
d = h[4] d = h[4]
eq(d['serial'], revid2) eq(d['tid'], revid2)
eq(d['version'], '') eq(d['version'], '')
d = h[5] d = h[5]
eq(d['serial'], revid1) eq(d['tid'], revid1)
eq(d['version'], '') eq(d['version'], '')
def checkHistoryAfterVersionCommit(self): def checkHistoryAfterVersionCommit(self):
...@@ -151,25 +151,25 @@ class HistoryStorage: ...@@ -151,25 +151,25 @@ class HistoryStorage:
h = self._storage.history(oid, version, 100) h = self._storage.history(oid, version, 100)
eq(len(h), 7) eq(len(h), 7)
d = h[0] d = h[0]
eq(d['serial'], revid7) eq(d['tid'], revid7)
eq(d['version'], '') eq(d['version'], '')
d = h[1] d = h[1]
eq(d['serial'], revid6) eq(d['tid'], revid6)
eq(d['version'], version) eq(d['version'], version)
d = h[2] d = h[2]
eq(d['serial'], revid5) eq(d['tid'], revid5)
eq(d['version'], version) eq(d['version'], version)
d = h[3] d = h[3]
eq(d['serial'], revid4) eq(d['tid'], revid4)
eq(d['version'], version) eq(d['version'], version)
d = h[4] d = h[4]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
d = h[5] d = h[5]
eq(d['serial'], revid2) eq(d['tid'], revid2)
eq(d['version'], '') eq(d['version'], '')
d = h[6] d = h[6]
eq(d['serial'], revid1) eq(d['tid'], revid1)
eq(d['version'], '') eq(d['version'], '')
def checkHistoryAfterVersionAbort(self): def checkHistoryAfterVersionAbort(self):
...@@ -211,23 +211,23 @@ class HistoryStorage: ...@@ -211,23 +211,23 @@ class HistoryStorage:
h = self._storage.history(oid, version, 100) h = self._storage.history(oid, version, 100)
eq(len(h), 7) eq(len(h), 7)
d = h[0] d = h[0]
eq(d['serial'], revid7) eq(d['tid'], revid7)
eq(d['version'], '') eq(d['version'], '')
d = h[1] d = h[1]
eq(d['serial'], revid6) eq(d['tid'], revid6)
eq(d['version'], version) eq(d['version'], version)
d = h[2] d = h[2]
eq(d['serial'], revid5) eq(d['tid'], revid5)
eq(d['version'], version) eq(d['version'], version)
d = h[3] d = h[3]
eq(d['serial'], revid4) eq(d['tid'], revid4)
eq(d['version'], version) eq(d['version'], version)
d = h[4] d = h[4]
eq(d['serial'], revid3) eq(d['tid'], revid3)
eq(d['version'], '') eq(d['version'], '')
d = h[5] d = h[5]
eq(d['serial'], revid2) eq(d['tid'], revid2)
eq(d['version'], '') eq(d['version'], '')
d = h[6] d = h[6]
eq(d['serial'], revid1) eq(d['tid'], revid1)
eq(d['version'], '') eq(d['version'], '')
...@@ -33,7 +33,7 @@ class IteratorCompare: ...@@ -33,7 +33,7 @@ class IteratorCompare:
eq(reciter.tid, revid) eq(reciter.tid, revid)
for rec in reciter: for rec in reciter:
eq(rec.oid, oid) eq(rec.oid, oid)
eq(rec.serial, revid) eq(rec.tid, revid)
eq(rec.version, '') eq(rec.version, '')
eq(zodb_unpickle(rec.data), MinPO(val)) eq(zodb_unpickle(rec.data), MinPO(val))
val = val + 1 val = val + 1
...@@ -147,6 +147,20 @@ class IteratorStorage(IteratorCompare): ...@@ -147,6 +147,20 @@ class IteratorStorage(IteratorCompare):
finally: finally:
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
def checkLoadEx(self):
oid = self._storage.new_oid()
self._dostore(oid, data=42)
data, tid, ver = self._storage.loadEx(oid, "")
self.assertEqual(zodb_unpickle(data), MinPO(42))
match = False
for txn in self._storage.iterator():
for rec in txn:
if rec.oid == oid and rec.tid == tid:
self.assertEqual(txn.tid, tid)
match = True
if not match:
self.fail("Could not find transaction with matching id")
class ExtendedIteratorStorage(IteratorCompare): class ExtendedIteratorStorage(IteratorCompare):
...@@ -202,7 +216,7 @@ class IteratorDeepCompare: ...@@ -202,7 +216,7 @@ class IteratorDeepCompare:
eq(txn1._extension, txn2._extension) eq(txn1._extension, txn2._extension)
for rec1, rec2 in zip(txn1, txn2): for rec1, rec2 in zip(txn1, txn2):
eq(rec1.oid, rec2.oid) eq(rec1.oid, rec2.oid)
eq(rec1.serial, rec2.serial) eq(rec1.tid, rec2.tid)
eq(rec1.version, rec2.version) eq(rec1.version, rec2.version)
eq(rec1.data, rec2.data) eq(rec1.data, rec2.data)
# Make sure there are no more records left in rec1 and rec2, # Make sure there are no more records left in rec1 and rec2,
......
...@@ -154,9 +154,12 @@ class StorageClientThread(TestThread): ...@@ -154,9 +154,12 @@ class StorageClientThread(TestThread):
class ExtStorageClientThread(StorageClientThread): class ExtStorageClientThread(StorageClientThread):
def runtest(self): def runtest(self):
# pick some other storage ops to execute # pick some other storage ops to execute, depending in part
ops = [getattr(self, meth) for meth in dir(ExtStorageClientThread) # on the features provided by the storage.
if meth.startswith('do_')] names = ["do_load", "do_modifiedInVersion"]
if self.storage.supportsUndo():
names += ["do_loadSerial", "do_undoLog", "do_iterator"]
ops = [getattr(self, meth) for meth in names]
assert ops, "Didn't find an storage ops in %s" % self.storage assert ops, "Didn't find an storage ops in %s" % self.storage
# do a store to guarantee there's at least one oid in self.oids # do a store to guarantee there's at least one oid in self.oids
self.dostore(0) self.dostore(0)
......
...@@ -121,9 +121,6 @@ class PackableStorageBase: ...@@ -121,9 +121,6 @@ class PackableStorageBase:
return u.load() return u.load()
return loads return loads
class PackableStorage(PackableStorageBase):
def _initroot(self): def _initroot(self):
try: try:
self._storage.load(ZERO, '') self._storage.load(ZERO, '')
...@@ -141,6 +138,8 @@ class PackableStorage(PackableStorageBase): ...@@ -141,6 +138,8 @@ class PackableStorage(PackableStorageBase):
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
class PackableStorage(PackableStorageBase):
def checkPackEmptyStorage(self): def checkPackEmptyStorage(self):
self._storage.pack(time.time(), referencesf) self._storage.pack(time.time(), referencesf)
...@@ -152,6 +151,63 @@ class PackableStorage(PackableStorageBase): ...@@ -152,6 +151,63 @@ class PackableStorage(PackableStorageBase):
self._initroot() self._initroot()
self._storage.pack(time.time() - 10000, referencesf) self._storage.pack(time.time() - 10000, referencesf)
def _PackWhileWriting(self, pack_now=0):
# A storage should allow some reading and writing during
# a pack. This test attempts to exercise locking code
# in the storage to test that it is safe. It generates
# a lot of revisions, so that pack takes a long time.
db = DB(self._storage)
conn = db.open()
root = conn.root()
for i in range(10):
root[i] = MinPO(i)
get_transaction().commit()
snooze()
packt = time.time()
choices = range(10)
for dummy in choices:
for i in choices:
root[i].value = MinPO(i)
get_transaction().commit()
threads = [ClientThread(db, choices) for i in range(4)]
for t in threads:
t.start()
if pack_now:
db.pack(time.time())
else:
db.pack(packt)
for t in threads:
t.join(30)
for t in threads:
t.join(1)
self.assert_(not t.isAlive())
# Iterate over the storage to make sure it's sane, but not every
# storage supports iterators.
if not hasattr(self._storage, "iterator"):
return
iter = self._storage.iterator()
for txn in iter:
for data in txn:
pass
iter.close()
def checkPackWhileWriting(self):
self._PackWhileWriting(pack_now=0)
def checkPackNowWhileWriting(self):
self._PackWhileWriting(pack_now=1)
class PackableUndoStorage(PackableStorageBase):
def checkPackAllRevisions(self): def checkPackAllRevisions(self):
self._initroot() self._initroot()
eq = self.assertEqual eq = self.assertEqual
...@@ -381,61 +437,6 @@ class PackableStorage(PackableStorageBase): ...@@ -381,61 +437,6 @@ class PackableStorage(PackableStorageBase):
eq(root['obj'].value, 7) eq(root['obj'].value, 7)
def _PackWhileWriting(self, pack_now=0):
# A storage should allow some reading and writing during
# a pack. This test attempts to exercise locking code
# in the storage to test that it is safe. It generates
# a lot of revisions, so that pack takes a long time.
db = DB(self._storage)
conn = db.open()
root = conn.root()
for i in range(10):
root[i] = MinPO(i)
get_transaction().commit()
snooze()
packt = time.time()
choices = range(10)
for dummy in choices:
for i in choices:
root[i].value = MinPO(i)
get_transaction().commit()
threads = [ClientThread(db, choices) for i in range(4)]
for t in threads:
t.start()
if pack_now:
db.pack(time.time())
else:
db.pack(packt)
for t in threads:
t.join(30)
for t in threads:
t.join(1)
self.assert_(not t.isAlive())
# Iterate over the storage to make sure it's sane, but not every
# storage supports iterators.
if not hasattr(self._storage, "iterator"):
return
iter = self._storage.iterator()
for txn in iter:
for data in txn:
pass
iter.close()
def checkPackWhileWriting(self):
self._PackWhileWriting(pack_now=0)
def checkPackNowWhileWriting(self):
self._PackWhileWriting(pack_now=1)
def checkPackUndoLog(self): def checkPackUndoLog(self):
self._initroot() self._initroot()
# Create a `persistent' object # Create a `persistent' object
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# FOR A PARTICULAR PURPOSE. # FOR A PARTICULAR PURPOSE.
# #
############################################################################## ##############################################################################
from ZODB.POSException import ReadOnlyError from ZODB.POSException import ReadOnlyError, Unsupported
from ZODB.Transaction import Transaction from ZODB.Transaction import Transaction
class ReadOnlyStorage: class ReadOnlyStorage:
...@@ -37,8 +37,12 @@ class ReadOnlyStorage: ...@@ -37,8 +37,12 @@ class ReadOnlyStorage:
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
self.assertEqual(revid, self.oids[oid]) self.assertEqual(revid, self.oids[oid])
self.assert_(not self._storage.modifiedInVersion(oid)) self.assert_(not self._storage.modifiedInVersion(oid))
_data = self._storage.loadSerial(oid, revid) # Storages without revisions may not have loadSerial().
self.assertEqual(data, _data) try:
_data = self._storage.loadSerial(oid, revid)
self.assertEqual(data, _data)
except Unsupported:
pass
def checkWriteMethods(self): def checkWriteMethods(self):
self._make_readonly() self._make_readonly()
......
...@@ -54,7 +54,7 @@ class RecoveryStorage(IteratorDeepCompare): ...@@ -54,7 +54,7 @@ class RecoveryStorage(IteratorDeepCompare):
# Now abort the version and the creation # Now abort the version and the creation
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.abortVersion('one', t) tid, oids = self._storage.abortVersion('one', t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
self.assertEqual(oids, [oid]) self.assertEqual(oids, [oid])
...@@ -80,9 +80,9 @@ class RecoveryStorage(IteratorDeepCompare): ...@@ -80,9 +80,9 @@ class RecoveryStorage(IteratorDeepCompare):
data=MinPO(92)) data=MinPO(92))
revid_c = self._dostore(oid, revid=revid_b, version=version, revid_c = self._dostore(oid, revid=revid_b, version=version,
data=MinPO(93)) data=MinPO(93))
self._undo(self._storage.undoInfo()[0]['id'], oid) self._undo(self._storage.undoInfo()[0]['id'], [oid])
self._commitVersion(version, '') self._commitVersion(version, '')
self._undo(self._storage.undoInfo()[0]['id'], oid) self._undo(self._storage.undoInfo()[0]['id'], [oid])
# now copy the records to a new storage # now copy the records to a new storage
self._dst.copyTransactionsFrom(self._storage) self._dst.copyTransactionsFrom(self._storage)
...@@ -95,7 +95,7 @@ class RecoveryStorage(IteratorDeepCompare): ...@@ -95,7 +95,7 @@ class RecoveryStorage(IteratorDeepCompare):
self._abortVersion(version) self._abortVersion(version)
self.assert_(self._storage.versionEmpty(version)) self.assert_(self._storage.versionEmpty(version))
self._undo(self._storage.undoInfo()[0]['id'], oid) self._undo(self._storage.undoInfo()[0]['id'], [oid])
self.assert_(not self._storage.versionEmpty(version)) self.assert_(not self._storage.versionEmpty(version))
# check the data is what we expect it to be # check the data is what we expect it to be
...@@ -109,7 +109,7 @@ class RecoveryStorage(IteratorDeepCompare): ...@@ -109,7 +109,7 @@ class RecoveryStorage(IteratorDeepCompare):
self._storage = self._dst self._storage = self._dst
self._abortVersion(version) self._abortVersion(version)
self.assert_(self._storage.versionEmpty(version)) self.assert_(self._storage.versionEmpty(version))
self._undo(self._storage.undoInfo()[0]['id'], oid) self._undo(self._storage.undoInfo()[0]['id'], [oid])
self.assert_(not self._storage.versionEmpty(version)) self.assert_(not self._storage.versionEmpty(version))
# check the data is what we expect it to be # check the data is what we expect it to be
...@@ -149,7 +149,7 @@ class RecoveryStorage(IteratorDeepCompare): ...@@ -149,7 +149,7 @@ class RecoveryStorage(IteratorDeepCompare):
final = list(it)[-1] final = list(it)[-1]
self._dst.tpc_begin(final, final.tid, final.status) self._dst.tpc_begin(final, final.tid, final.status)
for r in final: for r in final:
self._dst.restore(r.oid, r.serial, r.data, r.version, r.data_txn, self._dst.restore(r.oid, r.tid, r.data, r.version, r.data_txn,
final) final)
it.close() it.close()
self._dst.tpc_vote(final) self._dst.tpc_vote(final)
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
"""Check loadSerial() on storages that support historical revisions.""" """Check loadSerial() on storages that support historical revisions."""
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle, zodb_pickle from ZODB.tests.StorageTestBase import zodb_unpickle, zodb_pickle, snooze
from ZODB.utils import p64, u64
ZERO = '\0'*8 ZERO = '\0'*8
...@@ -31,3 +32,107 @@ class RevisionStorage: ...@@ -31,3 +32,107 @@ class RevisionStorage:
for revid, value in revisions.items(): for revid, value in revisions.items():
data = self._storage.loadSerial(oid, revid) data = self._storage.loadSerial(oid, revid)
self.assertEqual(zodb_unpickle(data), value) self.assertEqual(zodb_unpickle(data), value)
def checkLoadBefore(self):
# Store 10 revisions of one object and then make sure that we
# can get all the non-current revisions back.
oid = self._storage.new_oid()
revs = []
revid = None
for i in range(10):
# We need to ensure that successive timestamps are at least
# two apart, so that a timestamp exists that's unambiguously
# between successive timestamps. Each call to snooze()
# guarantees that the next timestamp will be at least one
# larger (and probably much more than that) than the previous
# one.
snooze()
snooze()
revid = self._dostore(oid, revid, data=MinPO(i))
revs.append(self._storage.loadEx(oid, ""))
prev = u64(revs[0][1])
for i in range(1, 10):
tid = revs[i][1]
cur = u64(tid)
middle = prev + (cur - prev) // 2
assert prev < middle < cur # else the snooze() trick failed
prev = cur
t = self._storage.loadBefore(oid, p64(middle))
self.assert_(t is not None)
data, start, end = t
self.assertEqual(revs[i-1][0], data)
self.assertEqual(tid, end)
def checkLoadBeforeEdges(self):
# Check the edges cases for a non-current load.
oid = self._storage.new_oid()
self.assertRaises(KeyError, self._storage.loadBefore,
oid, p64(0))
revid1 = self._dostore(oid, data=MinPO(1))
self.assertEqual(self._storage.loadBefore(oid, p64(0)), None)
self.assertEqual(self._storage.loadBefore(oid, revid1), None)
cur = p64(u64(revid1) + 1)
data, start, end = self._storage.loadBefore(oid, cur)
self.assertEqual(zodb_unpickle(data), MinPO(1))
self.assertEqual(start, revid1)
self.assertEqual(end, None)
revid2 = self._dostore(oid, revid=revid1, data=MinPO(2))
data, start, end = self._storage.loadBefore(oid, cur)
self.assertEqual(zodb_unpickle(data), MinPO(1))
self.assertEqual(start, revid1)
self.assertEqual(end, revid2)
def checkLoadBeforeOld(self):
# Look for a very old revision. With the BaseStorage implementation
# this should require multple history() calls.
oid = self._storage.new_oid()
revs = []
revid = None
for i in range(50):
revid = self._dostore(oid, revid, data=MinPO(i))
revs.append(revid)
data, start, end = self._storage.loadBefore(oid, revs[12])
self.assertEqual(zodb_unpickle(data), MinPO(11))
self.assertEqual(start, revs[11])
self.assertEqual(end, revs[12])
# XXX Is it okay to assume everyone testing against RevisionStorage
# implements undo?
def checkLoadBeforeUndo(self):
# Do several transactions then undo them.
oid = self._storage.new_oid()
revid = None
for i in range(5):
revid = self._dostore(oid, revid, data=MinPO(i))
revs = []
for i in range(4):
info = self._storage.undoInfo()
tid = info[0]["id"]
# Always undo the most recent txn, so the value will
# alternate between 3 and 4.
self._undo(tid, [oid], note="undo %d" % i)
revs.append(self._storage.loadEx(oid, ""))
prev_tid = None
for i, (data, tid, ver) in enumerate(revs):
t = self._storage.loadBefore(oid, p64(u64(tid) + 1))
self.assertEqual(data, t[0])
self.assertEqual(tid, t[1])
if prev_tid:
self.assert_(prev_tid < t[1])
prev_tid = t[1]
if i < 3:
self.assertEqual(revs[i+1][1], t[2])
else:
self.assertEqual(None, t[2])
# XXX There are other edge cases to handle, including pack.
...@@ -19,9 +19,6 @@ method _dostore() which performs a complete store transaction for a ...@@ -19,9 +19,6 @@ method _dostore() which performs a complete store transaction for a
single object revision. single object revision.
""" """
import errno
import os
import string
import sys import sys
import time import time
import types import types
...@@ -94,8 +91,7 @@ def zodb_unpickle(data): ...@@ -94,8 +91,7 @@ def zodb_unpickle(data):
try: try:
klass = ns[klassname] klass = ns[klassname]
except KeyError: except KeyError:
sys.stderr.write("can't find %s in %s" % (klassname, print >> sys.stderr, "can't find %s in %r" % (klassname, ns)
repr(ns)))
inst = klass() inst = klass()
else: else:
raise ValueError, "expected class info: %s" % repr(klass_info) raise ValueError, "expected class info: %s" % repr(klass_info)
...@@ -140,16 +136,6 @@ def import_helper(name): ...@@ -140,16 +136,6 @@ def import_helper(name):
__import__(name) __import__(name)
return sys.modules[name] return sys.modules[name]
def removefs(base):
"""Remove all files created by FileStorage with path base."""
for ext in '', '.old', '.tmp', '.lock', '.index', '.pack':
path = base + ext
try:
os.remove(path)
except os.error, err:
if err[0] != errno.ENOENT:
raise
class StorageTestBase(unittest.TestCase): class StorageTestBase(unittest.TestCase):
...@@ -217,25 +203,26 @@ class StorageTestBase(unittest.TestCase): ...@@ -217,25 +203,26 @@ class StorageTestBase(unittest.TestCase):
# The following methods depend on optional storage features. # The following methods depend on optional storage features.
def _undo(self, tid, oid=None): def _undo(self, tid, expected_oids=None, note=None):
# Undo a tid that affects a single object (oid). # Undo a tid that affects a single object (oid).
# XXX This is very specialized # XXX This is very specialized
t = Transaction() t = Transaction()
t.note("undo") t.note(note or "undo")
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t) tid, oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
if oid is not None: if expected_oids is not None:
self.assertEqual(len(oids), 1) self.assertEqual(len(oids), len(expected_oids), repr(oids))
self.assertEqual(oids[0], oid) for oid in expected_oids:
self.assert_(oid in oids)
return self._storage.lastTransaction() return self._storage.lastTransaction()
def _commitVersion(self, src, dst): def _commitVersion(self, src, dst):
t = Transaction() t = Transaction()
t.note("commit %r to %r" % (src, dst)) t.note("commit %r to %r" % (src, dst))
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.commitVersion(src, dst, t) tid, oids = self._storage.commitVersion(src, dst, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
return oids return oids
...@@ -244,7 +231,7 @@ class StorageTestBase(unittest.TestCase): ...@@ -244,7 +231,7 @@ class StorageTestBase(unittest.TestCase):
t = Transaction() t = Transaction()
t.note("abort %r" % ver) t.note("abort %r" % ver)
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.abortVersion(ver, t) tid, oids = self._storage.abortVersion(ver, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
return oids return oids
...@@ -115,36 +115,27 @@ class TransactionalUndoStorage: ...@@ -115,36 +115,27 @@ class TransactionalUndoStorage:
revid = self._dostore(oid, revid=revid, data=MinPO(25)) revid = self._dostore(oid, revid=revid, data=MinPO(25))
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id']
# Now start an undo transaction # Now start an undo transaction
oids = self.undo(tid, "undo1") self._undo(info[0]["id"], [oid], note="undo1")
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(24)) eq(zodb_unpickle(data), MinPO(24))
# Do another one # Do another one
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[2]['id'] self._undo(info[2]["id"], [oid], note="undo2")
oids = self.undo(tid, "undo2")
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(23)) eq(zodb_unpickle(data), MinPO(23))
# Try to undo the first record # Try to undo the first record
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[4]['id'] self._undo(info[4]["id"], [oid], note="undo3")
oids = self.undo(tid, "undo3")
eq(len(oids), 1)
eq(oids[0], oid)
# This should fail since we've undone the object's creation # This should fail since we've undone the object's creation
self.assertRaises(KeyError, self.assertRaises(KeyError,
self._storage.load, oid, '') self._storage.load, oid, '')
# And now let's try to redo the object's creation # And now let's try to redo the object's creation
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]["id"], [oid])
oids = self.undo(tid, "undo4")
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(23)) eq(zodb_unpickle(data), MinPO(23))
self._iterate() self._iterate()
...@@ -173,27 +164,14 @@ class TransactionalUndoStorage: ...@@ -173,27 +164,14 @@ class TransactionalUndoStorage:
revid = self._dostore(oid, revid=revid, data=MinPO(12)) revid = self._dostore(oid, revid=revid, data=MinPO(12))
# Undo the last transaction # Undo the last transaction
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]['id'], [oid])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(11)) eq(zodb_unpickle(data), MinPO(11))
# Now from here, we can either redo the last undo, or undo the object # Now from here, we can either redo the last undo, or undo the object
# creation. Let's undo the object creation. # creation. Let's undo the object creation.
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[2]['id'] self._undo(info[2]['id'], [oid])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
self.assertRaises(KeyError, self._storage.load, oid, '') self.assertRaises(KeyError, self._storage.load, oid, '')
self._iterate() self._iterate()
...@@ -204,27 +182,13 @@ class TransactionalUndoStorage: ...@@ -204,27 +182,13 @@ class TransactionalUndoStorage:
revid = self._dostore(oid, revid=revid, data=MinPO(12)) revid = self._dostore(oid, revid=revid, data=MinPO(12))
# Undo the last transaction # Undo the last transaction
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]['id'], [oid])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(11)) eq(zodb_unpickle(data), MinPO(11))
# Now from here, we can either redo the last undo, or undo the object # Now from here, we can either redo the last undo, or undo the object
# creation. Let's redo the last undo # creation. Let's redo the last undo
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]['id'], [oid])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1)
eq(oids[0], oid)
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(12)) eq(zodb_unpickle(data), MinPO(12))
self._iterate() self._iterate()
...@@ -266,17 +230,10 @@ class TransactionalUndoStorage: ...@@ -266,17 +230,10 @@ class TransactionalUndoStorage:
eq(zodb_unpickle(data), MinPO(32)) eq(zodb_unpickle(data), MinPO(32))
data, revid2 = self._storage.load(oid2, '') data, revid2 = self._storage.load(oid2, '')
eq(zodb_unpickle(data), MinPO(52)) eq(zodb_unpickle(data), MinPO(52))
# Now attempt to undo the transaction containing two objects # Now attempt to undo the transaction containing two objects
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]['id'], [oid1, oid2])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 2)
self.failUnless(oid1 in oids)
self.failUnless(oid2 in oids)
data, revid1 = self._storage.load(oid1, '') data, revid1 = self._storage.load(oid1, '')
eq(zodb_unpickle(data), MinPO(31)) eq(zodb_unpickle(data), MinPO(31))
data, revid2 = self._storage.load(oid2, '') data, revid2 = self._storage.load(oid2, '')
...@@ -322,13 +279,11 @@ class TransactionalUndoStorage: ...@@ -322,13 +279,11 @@ class TransactionalUndoStorage:
tid1 = info[1]['id'] tid1 = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t) tid, oids = self._storage.transactionalUndo(tid, t)
oids1 = self._storage.transactionalUndo(tid1, t) tid, oids1 = self._storage.transactionalUndo(tid1, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
# We get the finalization stuff called an extra time: # We get the finalization stuff called an extra time:
## self._storage.tpc_vote(t)
## self._storage.tpc_finish(t)
eq(len(oids), 2) eq(len(oids), 2)
eq(len(oids1), 2) eq(len(oids1), 2)
unless(oid1 in oids) unless(oid1 in oids)
...@@ -337,17 +292,10 @@ class TransactionalUndoStorage: ...@@ -337,17 +292,10 @@ class TransactionalUndoStorage:
eq(zodb_unpickle(data), MinPO(30)) eq(zodb_unpickle(data), MinPO(30))
data, revid2 = self._storage.load(oid2, '') data, revid2 = self._storage.load(oid2, '')
eq(zodb_unpickle(data), MinPO(50)) eq(zodb_unpickle(data), MinPO(50))
# Now try to undo the one we just did to undo, whew # Now try to undo the one we just did to undo, whew
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]['id'], [oid1, oid2])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 2)
unless(oid1 in oids)
unless(oid2 in oids)
data, revid1 = self._storage.load(oid1, '') data, revid1 = self._storage.load(oid1, '')
eq(zodb_unpickle(data), MinPO(32)) eq(zodb_unpickle(data), MinPO(32))
data, revid2 = self._storage.load(oid2, '') data, revid2 = self._storage.load(oid2, '')
...@@ -379,15 +327,7 @@ class TransactionalUndoStorage: ...@@ -379,15 +327,7 @@ class TransactionalUndoStorage:
eq(revid1, revid2) eq(revid1, revid2)
# Now attempt to undo the transaction containing two objects # Now attempt to undo the transaction containing two objects
info = self._storage.undoInfo() info = self._storage.undoInfo()
tid = info[0]['id'] self._undo(info[0]["id"], [oid1, oid2])
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 2)
self.failUnless(oid1 in oids)
self.failUnless(oid2 in oids)
data, revid1 = self._storage.load(oid1, '') data, revid1 = self._storage.load(oid1, '')
eq(zodb_unpickle(data), MinPO(31)) eq(zodb_unpickle(data), MinPO(31))
data, revid2 = self._storage.load(oid2, '') data, revid2 = self._storage.load(oid2, '')
...@@ -413,7 +353,7 @@ class TransactionalUndoStorage: ...@@ -413,7 +353,7 @@ class TransactionalUndoStorage:
tid = info[1]['id'] tid = info[1]['id']
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t) tid, oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
eq(len(oids), 1) eq(len(oids), 1)
...@@ -506,7 +446,7 @@ class TransactionalUndoStorage: ...@@ -506,7 +446,7 @@ class TransactionalUndoStorage:
# And now attempt to undo the last transaction # And now attempt to undo the last transaction
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t) tid, oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
eq(len(oids), 1) eq(len(oids), 1)
...@@ -736,7 +676,7 @@ class TransactionalUndoStorage: ...@@ -736,7 +676,7 @@ class TransactionalUndoStorage:
tid = p64(i + 1) tid = p64(i + 1)
eq(txn.tid, tid) eq(txn.tid, tid)
L1 = [(rec.oid, rec.serial, rec.data_txn) for rec in txn] L1 = [(rec.oid, rec.tid, rec.data_txn) for rec in txn]
L2 = [(oid, revid, None) for _tid, oid, revid in orig L2 = [(oid, revid, None) for _tid, oid, revid in orig
if _tid == tid] if _tid == tid]
......
...@@ -40,15 +40,6 @@ class TransactionalUndoVersionStorage: ...@@ -40,15 +40,6 @@ class TransactionalUndoVersionStorage:
pass # not expected pass # not expected
return self._dostore(*args, **kwargs) return self._dostore(*args, **kwargs)
def _undo(self, tid, oid):
t = Transaction()
self._storage.tpc_begin(t)
oids = self._storage.transactionalUndo(tid, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
self.assertEqual(len(oids), 1)
self.assertEqual(oids[0], oid)
def checkUndoInVersion(self): def checkUndoInVersion(self):
eq = self.assertEqual eq = self.assertEqual
unless = self.failUnless unless = self.failUnless
...@@ -68,21 +59,17 @@ class TransactionalUndoVersionStorage: ...@@ -68,21 +59,17 @@ class TransactionalUndoVersionStorage:
version=version) version=version)
info = self._storage.undoInfo() info = self._storage.undoInfo()
self._undo(info[0]['id'], oid) self._undo(info[0]['id'], [oid])
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(revid, revid_a) ## eq(revid, revid_a)
eq(zodb_unpickle(data), MinPO(91)) eq(zodb_unpickle(data), MinPO(91))
data, revid = self._storage.load(oid, version) data, revid = self._storage.load(oid, version)
unless(revid > revid_b and revid > revid_c) unless(revid > revid_b and revid > revid_c)
eq(zodb_unpickle(data), MinPO(92)) eq(zodb_unpickle(data), MinPO(92))
# Now commit the version... # Now commit the version...
t = Transaction() oids = self._commitVersion(version, "")
self._storage.tpc_begin(t)
oids = self._storage.commitVersion(version, '', t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
eq(len(oids), 1) eq(len(oids), 1)
eq(oids[0], oid) eq(oids[0], oid)
...@@ -90,7 +77,7 @@ class TransactionalUndoVersionStorage: ...@@ -90,7 +77,7 @@ class TransactionalUndoVersionStorage:
# ...and undo the commit # ...and undo the commit
info = self._storage.undoInfo() info = self._storage.undoInfo()
self._undo(info[0]['id'], oid) self._undo(info[0]['id'], [oid])
check_objects(91, 92) check_objects(91, 92)
...@@ -102,7 +89,7 @@ class TransactionalUndoVersionStorage: ...@@ -102,7 +89,7 @@ class TransactionalUndoVersionStorage:
# Now undo the abort # Now undo the abort
info=self._storage.undoInfo() info=self._storage.undoInfo()
self._undo(info[0]['id'], oid) self._undo(info[0]['id'], [oid])
check_objects(91, 92) check_objects(91, 92)
...@@ -143,16 +130,24 @@ class TransactionalUndoVersionStorage: ...@@ -143,16 +130,24 @@ class TransactionalUndoVersionStorage:
self._storage.pack(pt, referencesf) self._storage.pack(pt, referencesf)
t = Transaction() self._undo(t_id, note="undo commit version")
t.description = 'undo commit version'
self._storage.tpc_begin(t)
self._storage.transactionalUndo(t_id, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
self.assertEqual(load_value(oid1), 0) self.assertEqual(load_value(oid1), 0)
self.assertEqual(load_value(oid1, version), 2) self.assertEqual(load_value(oid1, version), 2)
data, tid, ver = self._storage.loadEx(oid1, "")
# After undoing the version commit, the non-version data
# once again becomes the non-version data from 'create1'.
self.assertEqual(tid, self._storage.lastTransaction())
self.assertEqual(ver, "")
# The current version data comes from an undo record, which
# means that it gets data via the backpointer but tid from the
# current txn.
data, tid, ver = self._storage.loadEx(oid1, version)
self.assertEqual(ver, version)
self.assertEqual(tid, self._storage.lastTransaction())
def checkUndoAbortVersion(self): def checkUndoAbortVersion(self):
def load_value(oid, version=''): def load_value(oid, version=''):
data, revid = self._storage.load(oid, version) data, revid = self._storage.load(oid, version)
...@@ -175,12 +170,7 @@ class TransactionalUndoVersionStorage: ...@@ -175,12 +170,7 @@ class TransactionalUndoVersionStorage:
version=version, description='version2') version=version, description='version2')
self._x_dostore(description='create2') self._x_dostore(description='create2')
t = Transaction() self._abortVersion(version)
t.description = 'abort version'
self._storage.tpc_begin(t)
self._storage.abortVersion(version, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
info = self._storage.undoInfo() info = self._storage.undoInfo()
t_id = info[0]['id'] t_id = info[0]['id']
...@@ -189,12 +179,7 @@ class TransactionalUndoVersionStorage: ...@@ -189,12 +179,7 @@ class TransactionalUndoVersionStorage:
# after abort, we should see non-version data # after abort, we should see non-version data
self.assertEqual(load_value(oid1, version), 0) self.assertEqual(load_value(oid1, version), 0)
t = Transaction() self._undo(t_id, note="undo abort version")
t.description = 'undo abort version'
self._storage.tpc_begin(t)
self._storage.transactionalUndo(t_id, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
self.assertEqual(load_value(oid1), 0) self.assertEqual(load_value(oid1), 0)
# t undo will re-create the version # t undo will re-create the version
...@@ -205,12 +190,7 @@ class TransactionalUndoVersionStorage: ...@@ -205,12 +190,7 @@ class TransactionalUndoVersionStorage:
self._storage.pack(pt, referencesf) self._storage.pack(pt, referencesf)
t = Transaction() self._undo(t_id, note="undo undo")
t.description = 'undo undo'
self._storage.tpc_begin(t)
self._storage.transactionalUndo(t_id, t)
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
# undo of undo will put as back where we started # undo of undo will put as back where we started
self.assertEqual(load_value(oid1), 0) self.assertEqual(load_value(oid1), 0)
......
...@@ -16,10 +16,6 @@ ...@@ -16,10 +16,6 @@
Any storage that supports versions should be able to pass all these tests. Any storage that supports versions should be able to pass all these tests.
""" """
# XXX we should clean this code up to get rid of the #JF# comments.
# They were introduced when Jim reviewed the original version of the
# code. Barry and Jeremy didn't understand versions then.
import time import time
from ZODB import POSException from ZODB import POSException
...@@ -48,26 +44,33 @@ class VersionStorage: ...@@ -48,26 +44,33 @@ class VersionStorage:
revid1 = self._dostore(oid, data=MinPO(12)) revid1 = self._dostore(oid, data=MinPO(12))
revid2 = self._dostore(oid, revid=revid1, data=MinPO(13), revid2 = self._dostore(oid, revid=revid1, data=MinPO(13),
version="version") version="version")
data, tid, ver = self._storage.loadEx(oid, "version")
self.assertEqual(revid2, tid)
self.assertEqual(zodb_unpickle(data), MinPO(13))
oids = self._abortVersion("version") oids = self._abortVersion("version")
self.assertEqual([oid], oids) self.assertEqual([oid], oids)
data, revid3 = self._storage.load(oid, "") data, revid3 = self._storage.load(oid, "")
# use repr() to avoid getting binary data in a traceback on error # use repr() to avoid getting binary data in a traceback on error
self.assertEqual(`revid1`, `revid3`) self.assertNotEqual(revid1, revid3)
self.assertNotEqual(`revid2`, `revid3`) self.assertNotEqual(revid2, revid3)
data, tid, ver = self._storage.loadEx(oid, "")
self.assertEqual(revid3, tid)
self.assertEqual(zodb_unpickle(data), MinPO(12))
self.assertEqual(tid, self._storage.lastTransaction())
def checkVersionedStoreAndLoad(self): def checkVersionedStoreAndLoad(self):
eq = self.assertEqual eq = self.assertEqual
# Store a couple of non-version revisions of the object # Store a couple of non-version revisions of the object
oid = self._storage.new_oid() oid = self._storage.new_oid()
revid = self._dostore(oid, data=MinPO(11)) revid = self._dostore(oid, data=MinPO(11))
revid = self._dostore(oid, revid=revid, data=MinPO(12)) revid1 = self._dostore(oid, revid=revid, data=MinPO(12))
# And now store some new revisions in a version # And now store some new revisions in a version
version = 'test-version' version = 'test-version'
revid = self._dostore(oid, revid=revid, data=MinPO(13), revid = self._dostore(oid, revid=revid1, data=MinPO(13),
version=version) version=version)
revid = self._dostore(oid, revid=revid, data=MinPO(14), revid = self._dostore(oid, revid=revid, data=MinPO(14),
version=version) version=version)
revid = self._dostore(oid, revid=revid, data=MinPO(15), revid2 = self._dostore(oid, revid=revid, data=MinPO(15),
version=version) version=version)
# Now read back the object in both the non-version and version and # Now read back the object in both the non-version and version and
# make sure the values jive. # make sure the values jive.
...@@ -78,6 +81,20 @@ class VersionStorage: ...@@ -78,6 +81,20 @@ class VersionStorage:
if hasattr(self._storage, 'getSerial'): if hasattr(self._storage, 'getSerial'):
s = self._storage.getSerial(oid) s = self._storage.getSerial(oid)
eq(s, max(revid, vrevid)) eq(s, max(revid, vrevid))
data, tid, ver = self._storage.loadEx(oid, version)
eq(zodb_unpickle(data), MinPO(15))
eq(tid, revid2)
data, tid, ver = self._storage.loadEx(oid, "other version")
eq(zodb_unpickle(data), MinPO(12))
eq(tid, revid2)
# loadSerial returns non-version data
try:
data = self._storage.loadSerial(oid, revid)
eq(zodb_unpickle(data), MinPO(12))
data = self._storage.loadSerial(oid, revid2)
eq(zodb_unpickle(data), MinPO(12))
except POSException.Unsupported:
pass
def checkVersionedLoadErrors(self): def checkVersionedLoadErrors(self):
oid = self._storage.new_oid() oid = self._storage.new_oid()
...@@ -89,11 +106,6 @@ class VersionStorage: ...@@ -89,11 +106,6 @@ class VersionStorage:
self.assertRaises(KeyError, self.assertRaises(KeyError,
self._storage.load, self._storage.load,
self._storage.new_oid(), '') self._storage.new_oid(), '')
# Try to load a bogus version string
#JF# Nope, fall back to non-version
#JF# self.assertRaises(KeyError,
#JF# self._storage.load,
#JF# oid, 'bogus')
data, revid = self._storage.load(oid, 'bogus') data, revid = self._storage.load(oid, 'bogus')
self.assertEqual(zodb_unpickle(data), MinPO(11)) self.assertEqual(zodb_unpickle(data), MinPO(11))
...@@ -112,9 +124,6 @@ class VersionStorage: ...@@ -112,9 +124,6 @@ class VersionStorage:
def checkVersionEmpty(self): def checkVersionEmpty(self):
# Before we store anything, these versions ought to be empty # Before we store anything, these versions ought to be empty
version = 'test-version' version = 'test-version'
#JF# The empty string is not a valid version. I think that this should
#JF# be an error. Let's punt for now.
#JF# assert self._storage.versionEmpty('')
self.failUnless(self._storage.versionEmpty(version)) self.failUnless(self._storage.versionEmpty(version))
# Now store some objects # Now store some objects
oid = self._storage.new_oid() oid = self._storage.new_oid()
...@@ -125,10 +134,6 @@ class VersionStorage: ...@@ -125,10 +134,6 @@ class VersionStorage:
revid = self._dostore(oid, revid=revid, data=MinPO(14), revid = self._dostore(oid, revid=revid, data=MinPO(14),
version=version) version=version)
# The blank version should not be empty # The blank version should not be empty
#JF# The empty string is not a valid version. I think that this should
#JF# be an error. Let's punt for now.
#JF# assert not self._storage.versionEmpty('')
# Neither should 'test-version' # Neither should 'test-version'
self.failUnless(not self._storage.versionEmpty(version)) self.failUnless(not self._storage.versionEmpty(version))
# But this non-existant version should be empty # But this non-existant version should be empty
...@@ -190,6 +195,22 @@ class VersionStorage: ...@@ -190,6 +195,22 @@ class VersionStorage:
data, revid = self._storage.load(oid, '') data, revid = self._storage.load(oid, '')
eq(zodb_unpickle(data), MinPO(51)) eq(zodb_unpickle(data), MinPO(51))
def checkAbortVersionNonCurrent(self):
# Make sure the non-current serial number is correctly
# after a version is aborted.
oid, version = self._setup_version()
self._abortVersion(version)
data, tid, ver = self._storage.loadEx(oid, "")
# write a new revision of oid so that the aborted-version txn
# is not current
self._dostore(oid, revid=tid, data=MinPO(17))
ltid = self._storage.lastTransaction()
ncdata, ncstart, end = self._storage.loadBefore(oid, ltid)
self.assertEqual(data, ncdata)
self.assertEqual(tid, ncstart)
def checkAbortVersionErrors(self): def checkAbortVersionErrors(self):
eq = self.assertEqual eq = self.assertEqual
oid, version = self._setup_version() oid, version = self._setup_version()
...@@ -197,13 +218,6 @@ class VersionStorage: ...@@ -197,13 +218,6 @@ class VersionStorage:
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
#JF# The spec is silent on what happens if you abort or commit
#JF# a non-existent version. FileStorage consideres this a noop.
#JF# We can change the spec, but until we do ....
#JF# self.assertRaises(POSException.VersionError,
#JF# self._storage.abortVersion,
#JF# 'bogus', t)
# And try to abort the empty version # And try to abort the empty version
if (hasattr(self._storage, 'supportsTransactionalUndo') if (hasattr(self._storage, 'supportsTransactionalUndo')
and self._storage.supportsTransactionalUndo()): and self._storage.supportsTransactionalUndo()):
...@@ -213,7 +227,7 @@ class VersionStorage: ...@@ -213,7 +227,7 @@ class VersionStorage:
'', t) '', t)
# But now we really try to abort the version # But now we really try to abort the version
oids = self._storage.abortVersion(version, t) tid, oids = self._storage.abortVersion(version, t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
eq(len(oids), 1) eq(len(oids), 1)
...@@ -241,17 +255,17 @@ class VersionStorage: ...@@ -241,17 +255,17 @@ class VersionStorage:
def checkNewSerialOnCommitVersionToVersion(self): def checkNewSerialOnCommitVersionToVersion(self):
oid, version = self._setup_version() oid, version = self._setup_version()
data, vserial = self._storage.load(oid, version) data, vtid = self._storage.load(oid, version)
data, nserial = self._storage.load(oid, '') data, ntid = self._storage.load(oid, '')
version2 = 'test version 2' version2 = 'test version 2'
self._commitVersion(version, version2) self._commitVersion(version, version2)
data, serial = self._storage.load(oid, version2) data, tid = self._storage.load(oid, version2)
self.failUnless(serial != vserial and serial != nserial, self.failUnless(tid != vtid and tid != ntid,
"New serial, %r, should be different from the old " "New tid, %r, should be different from the old "
"version, %r, and non-version, %r, serials." "version, %r, and non-version, %r, tids."
% (serial, vserial, nserial)) % (tid, vtid, ntid))
def checkModifyAfterAbortVersion(self): def checkModifyAfterAbortVersion(self):
eq = self.assertEqual eq = self.assertEqual
...@@ -332,13 +346,8 @@ class VersionStorage: ...@@ -332,13 +346,8 @@ class VersionStorage:
data, revid = self._storage.load(oid1, '') data, revid = self._storage.load(oid1, '')
eq(zodb_unpickle(data), MinPO(51)) eq(zodb_unpickle(data), MinPO(51))
#JF# Ditto
#JF# self.assertRaises(POSException.VersionError,
#JF# self._storage.load, oid1, version1)
data, revid = self._storage.load(oid1, '') data, revid = self._storage.load(oid1, '')
eq(zodb_unpickle(data), MinPO(51)) eq(zodb_unpickle(data), MinPO(51))
#JF# self.assertRaises(POSException.VersionError,
#JF# self._storage.load, oid1, version2)
data, revid = self._storage.load(oid1, '') data, revid = self._storage.load(oid1, '')
eq(zodb_unpickle(data), MinPO(51)) eq(zodb_unpickle(data), MinPO(51))
...@@ -359,7 +368,6 @@ class VersionStorage: ...@@ -359,7 +368,6 @@ class VersionStorage:
data, revid = self._storage.load(oid2, version2) data, revid = self._storage.load(oid2, version2)
eq(zodb_unpickle(data), MinPO(54)) eq(zodb_unpickle(data), MinPO(54))
#JF# To do a test like you want, you have to add the data in a version
oid = self._storage.new_oid() oid = self._storage.new_oid()
revid = self._dostore(oid, revid=revid, data=MinPO(54), version='one') revid = self._dostore(oid, revid=revid, data=MinPO(54), version='one')
self.assertRaises(KeyError, self.assertRaises(KeyError,
...@@ -375,7 +383,7 @@ class VersionStorage: ...@@ -375,7 +383,7 @@ class VersionStorage:
# Now abort the version and the creation # Now abort the version and the creation
t = Transaction() t = Transaction()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
oids = self._storage.abortVersion('one', t) tid, oids = self._storage.abortVersion('one', t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
self._storage.tpc_finish(t) self._storage.tpc_finish(t)
self.assertEqual(oids, [oid]) self.assertEqual(oids, [oid])
......
...@@ -27,8 +27,13 @@ class ConfigTestBase(unittest.TestCase): ...@@ -27,8 +27,13 @@ class ConfigTestBase(unittest.TestCase):
def _opendb(self, s): def _opendb(self, s):
return ZODB.config.databaseFromString(s) return ZODB.config.databaseFromString(s)
def tearDown(self):
if getattr(self, "storage", None) is not None:
self.storage.cleanup()
def _test(self, s): def _test(self, s):
db = self._opendb(s) db = self._opendb(s)
self.storage = db._storage
# Do something with the database to make sure it works # Do something with the database to make sure it works
cn = db.open() cn = db.open()
rt = cn.root() rt = cn.root()
...@@ -56,7 +61,6 @@ class ZODBConfigTest(ConfigTestBase): ...@@ -56,7 +61,6 @@ class ZODBConfigTest(ConfigTestBase):
""") """)
def test_file_config1(self): def test_file_config1(self):
import ZODB.FileStorage
path = tempfile.mktemp() path = tempfile.mktemp()
self._test( self._test(
""" """
...@@ -66,10 +70,8 @@ class ZODBConfigTest(ConfigTestBase): ...@@ -66,10 +70,8 @@ class ZODBConfigTest(ConfigTestBase):
</filestorage> </filestorage>
</zodb> </zodb>
""" % path) """ % path)
ZODB.FileStorage.cleanup(path)
def test_file_config2(self): def test_file_config2(self):
import ZODB.FileStorage
path = tempfile.mktemp() path = tempfile.mktemp()
cfg = """ cfg = """
<zodb> <zodb>
...@@ -81,7 +83,6 @@ class ZODBConfigTest(ConfigTestBase): ...@@ -81,7 +83,6 @@ class ZODBConfigTest(ConfigTestBase):
</zodb> </zodb>
""" % path """ % path
self.assertRaises(ReadOnlyError, self._test, cfg) self.assertRaises(ReadOnlyError, self._test, cfg)
ZODB.FileStorage.cleanup(path)
def test_zeo_config(self): def test_zeo_config(self):
# We're looking for a port that doesn't exist so a connection attempt # We're looking for a port that doesn't exist so a connection attempt
...@@ -119,9 +120,6 @@ class BDBConfigTest(ConfigTestBase): ...@@ -119,9 +120,6 @@ class BDBConfigTest(ConfigTestBase):
if e.errno <> errno.EEXIST: if e.errno <> errno.EEXIST:
raise raise
def tearDown(self):
shutil.rmtree(self._path)
def test_bdbfull_simple(self): def test_bdbfull_simple(self):
cfg = """ cfg = """
<zodb> <zodb>
......
...@@ -35,6 +35,10 @@ class DemoStorageTests(StorageTestBase.StorageTestBase, ...@@ -35,6 +35,10 @@ class DemoStorageTests(StorageTestBase.StorageTestBase,
# have this limit, so we inhibit this test here. # have this limit, so we inhibit this test here.
pass pass
def checkAbortVersionNonCurrent(self):
# XXX Need to implement a real loadBefore for DemoStorage?
pass
def test_suite(): def test_suite():
suite = unittest.makeSuite(DemoStorageTests, 'check') suite = unittest.makeSuite(DemoStorageTests, 'check')
......
...@@ -26,16 +26,30 @@ from ZODB.tests import StorageTestBase, BasicStorage, \ ...@@ -26,16 +26,30 @@ from ZODB.tests import StorageTestBase, BasicStorage, \
Synchronization, ConflictResolution, HistoryStorage, \ Synchronization, ConflictResolution, HistoryStorage, \
IteratorStorage, Corruption, RevisionStorage, PersistentStorage, \ IteratorStorage, Corruption, RevisionStorage, PersistentStorage, \
MTStorage, ReadOnlyStorage, RecoveryStorage MTStorage, ReadOnlyStorage, RecoveryStorage
from ZODB.tests.StorageTestBase import MinPO, zodb_unpickle from ZODB.tests.StorageTestBase import MinPO, zodb_unpickle, zodb_pickle
class BaseFileStorageTests(StorageTestBase.StorageTestBase):
def open(self, **kwargs):
self._storage = ZODB.FileStorage.FileStorage('FileStorageTests.fs',
**kwargs)
def setUp(self):
self.open(create=1)
def tearDown(self):
self._storage.close()
self._storage.cleanup()
class FileStorageTests( class FileStorageTests(
StorageTestBase.StorageTestBase, BaseFileStorageTests,
BasicStorage.BasicStorage, BasicStorage.BasicStorage,
TransactionalUndoStorage.TransactionalUndoStorage, TransactionalUndoStorage.TransactionalUndoStorage,
RevisionStorage.RevisionStorage, RevisionStorage.RevisionStorage,
VersionStorage.VersionStorage, VersionStorage.VersionStorage,
TransactionalUndoVersionStorage.TransactionalUndoVersionStorage, TransactionalUndoVersionStorage.TransactionalUndoVersionStorage,
PackableStorage.PackableStorage, PackableStorage.PackableStorage,
PackableStorage.PackableUndoStorage,
Synchronization.SynchronizedStorage, Synchronization.SynchronizedStorage,
ConflictResolution.ConflictResolvingStorage, ConflictResolution.ConflictResolvingStorage,
ConflictResolution.ConflictResolvingTransUndoStorage, ConflictResolution.ConflictResolvingTransUndoStorage,
...@@ -47,17 +61,6 @@ class FileStorageTests( ...@@ -47,17 +61,6 @@ class FileStorageTests(
ReadOnlyStorage.ReadOnlyStorage ReadOnlyStorage.ReadOnlyStorage
): ):
def open(self, **kwargs):
self._storage = ZODB.FileStorage.FileStorage('FileStorageTests.fs',
**kwargs)
def setUp(self):
self.open(create=1)
def tearDown(self):
self._storage.close()
StorageTestBase.removefs("FileStorageTests.fs")
def checkLongMetadata(self): def checkLongMetadata(self):
s = "X" * 75000 s = "X" * 75000
try: try:
...@@ -175,28 +178,43 @@ class FileStorageRecoveryTest( ...@@ -175,28 +178,43 @@ class FileStorageRecoveryTest(
): ):
def setUp(self): def setUp(self):
StorageTestBase.removefs("Source.fs") self._storage = ZODB.FileStorage.FileStorage("Source.fs", create=True)
StorageTestBase.removefs("Dest.fs") self._dst = ZODB.FileStorage.FileStorage("Dest.fs", create=True)
self._storage = ZODB.FileStorage.FileStorage('Source.fs')
self._dst = ZODB.FileStorage.FileStorage('Dest.fs')
def tearDown(self): def tearDown(self):
self._storage.close() self._storage.close()
self._dst.close() self._dst.close()
StorageTestBase.removefs("Source.fs") self._storage.cleanup()
StorageTestBase.removefs("Dest.fs") self._dst.cleanup()
def new_dest(self): def new_dest(self):
StorageTestBase.removefs('Dest.fs')
return ZODB.FileStorage.FileStorage('Dest.fs') return ZODB.FileStorage.FileStorage('Dest.fs')
class SlowFileStorageTest(BaseFileStorageTests):
level = 2
def check10Kstores(self):
# The _get_cached_serial() method has a special case
# every 8000 calls. Make sure it gets minimal coverage.
oids = [[self._storage.new_oid(), None] for i in range(100)]
for i in range(100):
t = Transaction()
self._storage.tpc_begin(t)
for j in range(100):
o = MinPO(j)
oid, revid = oids[j]
serial = self._storage.store(oid, revid, zodb_pickle(o), "", t)
oids[j][1] = serial
self._storage.tpc_vote(t)
self._storage.tpc_finish(t)
def test_suite(): def test_suite():
suite = unittest.makeSuite(FileStorageTests, 'check') suite = unittest.TestSuite()
suite2 = unittest.makeSuite(Corruption.FileStorageCorruptTests, 'check') for klass in [FileStorageTests, Corruption.FileStorageCorruptTests,
suite3 = unittest.makeSuite(FileStorageRecoveryTest, 'check') FileStorageRecoveryTest, SlowFileStorageTest]:
suite.addTest(suite2) suite.addTest(unittest.makeSuite(klass, "check"))
suite.addTest(suite3)
return suite return suite
if __name__=='__main__': if __name__=='__main__':
......
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
import ZODB.MappingStorage import ZODB.MappingStorage
import os, unittest import os, unittest
from ZODB.tests import StorageTestBase, BasicStorage, Synchronization from ZODB.tests import StorageTestBase
from ZODB.tests \
import BasicStorage, MTStorage, Synchronization, PackableStorage
class MappingStorageTests(StorageTestBase.StorageTestBase, class MappingStorageTests(StorageTestBase.StorageTestBase,
BasicStorage.BasicStorage, BasicStorage.BasicStorage,
Synchronization.SynchronizedStorage, MTStorage.MTStorage,
): PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
):
def setUp(self): def setUp(self):
self._storage = ZODB.MappingStorage.MappingStorage() self._storage = ZODB.MappingStorage.MappingStorage()
......
...@@ -24,12 +24,9 @@ import StringIO ...@@ -24,12 +24,9 @@ import StringIO
import ZODB import ZODB
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
from ZODB.fsrecover import recover from ZODB.fsrecover import recover
from ZODB.tests.StorageTestBase import removefs
from persistent.mapping import PersistentMapping from persistent.mapping import PersistentMapping
from ZODB.fsdump import Dumper
class RecoverTest(unittest.TestCase): class RecoverTest(unittest.TestCase):
level = 2 level = 2
...@@ -47,8 +44,10 @@ class RecoverTest(unittest.TestCase): ...@@ -47,8 +44,10 @@ class RecoverTest(unittest.TestCase):
self.storage.close() self.storage.close()
if self.recovered is not None: if self.recovered is not None:
self.recovered.close() self.recovered.close()
removefs(self.path) self.storage.cleanup()
removefs(self.dest) temp = FileStorage(self.dest)
temp.close()
temp.cleanup()
def populate(self): def populate(self):
db = ZODB.DB(self.storage) db = ZODB.DB(self.storage)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
""" """
Revision information: Revision information:
$Id: testTransaction.py,v 1.16 2003/11/28 16:44:54 jim Exp $ $Id: testTransaction.py,v 1.17 2003/12/24 16:01:58 jeremy Exp $
""" """
""" """
......
...@@ -16,7 +16,6 @@ import unittest ...@@ -16,7 +16,6 @@ import unittest
import ZODB import ZODB
import ZODB.FileStorage import ZODB.FileStorage
from ZODB.POSException import ReadConflictError, ConflictError from ZODB.POSException import ReadConflictError, ConflictError
from ZODB.tests.StorageTestBase import removefs
from persistent import Persistent from persistent import Persistent
from persistent.mapping import PersistentMapping from persistent.mapping import PersistentMapping
...@@ -53,7 +52,7 @@ class ZODBTests(unittest.TestCase): ...@@ -53,7 +52,7 @@ class ZODBTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
self._db.close() self._db.close()
removefs("ZODBTests.fs") self._storage.cleanup()
def checkExportImport(self, abort_it=0, dup_name='test_duplicate'): def checkExportImport(self, abort_it=0, dup_name='test_duplicate'):
self.populate() self.populate()
......
...@@ -36,5 +36,4 @@ class Prefix: ...@@ -36,5 +36,4 @@ class Prefix:
def __cmp__(self, o): def __cmp__(self, o):
l, v = self.value l, v = self.value
rval = cmp(o[:l], v) return cmp(o[:l], v)
return rval
#!python #!python
"""Print a text summary of the contents of a FileStorage.""" """Print a text summary of the contents of a FileStorage."""
from ZODB.fsdump import fsdump from ZODB.FileStorage.fsdump import fsdump
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment