Commit a47596a7 authored by Jim Fulton's avatar Jim Fulton

Collapsed the serializer reader and writer class hierarchies.

parent d5ce0b90
...@@ -37,7 +37,7 @@ from ZODB.POSException \ ...@@ -37,7 +37,7 @@ from ZODB.POSException \
import ConflictError, ReadConflictError, InvalidObjectReference, \ import ConflictError, ReadConflictError, InvalidObjectReference, \
ConnectionStateError ConnectionStateError
from ZODB.TmpStore import TmpStore from ZODB.TmpStore import TmpStore
from ZODB.serialize import ObjectWriter, ConnectionObjectReader, myhasattr from ZODB.serialize import ObjectWriter, ObjectReader, myhasattr
from ZODB.utils import u64, oid_repr, z64, positive_id, \ from ZODB.utils import u64, oid_repr, z64, positive_id, \
DEPRECATED_ARGUMENT, deprecated36 DEPRECATED_ARGUMENT, deprecated36
...@@ -791,8 +791,7 @@ class Connection(ExportImport, object): ...@@ -791,8 +791,7 @@ class Connection(ExportImport, object):
self._flush_invalidations() self._flush_invalidations()
if self._synch: if self._synch:
self._txn_mgr.registerSynch(self) self._txn_mgr.registerSynch(self)
self._reader = ConnectionObjectReader(self, self._cache, self._reader = ObjectReader(self, self._cache, self._db.classFactory)
self._db.classFactory)
# Multi-database support # Multi-database support
self.connections = {self._db.database_name: self} self.connections = {self._db.database_name: self}
......
...@@ -49,7 +49,7 @@ persistent reference (see below) is used. ...@@ -49,7 +49,7 @@ persistent reference (see below) is used.
It's unclear what "usually" means in the last paragraph. There are two It's unclear what "usually" means in the last paragraph. There are two
useful places to concentrate confusion about exactly which formats exist: useful places to concentrate confusion about exactly which formats exist:
- BaseObjectReader.getClassName() below returns a dotted "module.class" - ObjectReader.getClassName() below returns a dotted "module.class"
string, via actually loading a pickle. This requires that the string, via actually loading a pickle. This requires that the
implementation of application objects be available. implementation of application objects be available.
...@@ -114,7 +114,7 @@ def myhasattr(obj, name, _marker=object()): ...@@ -114,7 +114,7 @@ def myhasattr(obj, name, _marker=object()):
return getattr(obj, name, _marker) is not _marker return getattr(obj, name, _marker) is not _marker
class BaseObjectWriter: class ObjectWriter:
"""Serializes objects for storage in the database. """Serializes objects for storage in the database.
The ObjectWriter creates object pickles in the ZODB format. It The ObjectWriter creates object pickles in the ZODB format. It
...@@ -122,12 +122,16 @@ class BaseObjectWriter: ...@@ -122,12 +122,16 @@ class BaseObjectWriter:
object. object.
""" """
def __init__(self, jar=None): _jar = None
def __init__(self, obj=None):
self._file = cStringIO.StringIO() self._file = cStringIO.StringIO()
self._p = cPickle.Pickler(self._file, 1) self._p = cPickle.Pickler(self._file, 1)
self._stack = []
self._p.persistent_id = self.persistent_id self._p.persistent_id = self.persistent_id
if jar is not None: self._stack = []
if obj is not None:
self._stack.append(obj)
jar = obj._p_jar
assert myhasattr(jar, "new_oid") assert myhasattr(jar, "new_oid")
self._jar = jar self._jar = jar
...@@ -139,7 +143,9 @@ class BaseObjectWriter: ...@@ -139,7 +143,9 @@ class BaseObjectWriter:
... def new_oid(self): ... def new_oid(self):
... return 42 ... return 42
>>> jar = DummyJar() >>> jar = DummyJar()
>>> writer = BaseObjectWriter(jar) >>> class O:
... _p_jar = jar
>>> writer = ObjectWriter(O)
Normally, object references include the oid and a cached Normally, object references include the oid and a cached
reference to the class. Having the class available allows reference to the class. Having the class available allows
...@@ -304,12 +310,6 @@ class BaseObjectWriter: ...@@ -304,12 +310,6 @@ class BaseObjectWriter:
self._file.truncate() self._file.truncate()
return self._file.getvalue() return self._file.getvalue()
class ObjectWriter(BaseObjectWriter):
def __init__(self, obj):
BaseObjectWriter.__init__(self, obj._p_jar)
self._stack.append(obj)
def __iter__(self): def __iter__(self):
return NewObjectIterator(self._stack) return NewObjectIterator(self._stack)
...@@ -331,22 +331,80 @@ class NewObjectIterator: ...@@ -331,22 +331,80 @@ class NewObjectIterator:
else: else:
raise StopIteration raise StopIteration
class BaseObjectReader: class ObjectReader:
def _persistent_load(self, oid): def __init__(self, conn=None, cache=None, factory=None):
# subclasses must define _persistent_load(). self._conn = conn
raise NotImplementedError self._cache = cache
self._factory = factory
def _get_class(self, module, name): def _get_class(self, module, name):
# subclasses must define _get_class() return self._factory(self._conn, module, name)
raise NotImplementedError
def _get_unpickler(self, pickle): def _get_unpickler(self, pickle):
file = cStringIO.StringIO(pickle) file = cStringIO.StringIO(pickle)
unpickler = cPickle.Unpickler(file) unpickler = cPickle.Unpickler(file)
unpickler.persistent_load = self._persistent_load unpickler.persistent_load = self._persistent_load
factory = self._factory
conn = self._conn
def find_global(modulename, name):
return factory(conn, modulename, name)
unpickler.find_global = find_global
return unpickler return unpickler
def _persistent_load(self, oid):
if isinstance(oid, tuple):
# Quick instance reference. We know all we need to know
# to create the instance w/o hitting the db, so go for it!
oid, klass = oid
obj = self._cache.get(oid, None)
if obj is not None:
return obj
if isinstance(klass, tuple):
klass = self._get_class(*klass)
if issubclass(klass, Broken):
# We got a broken class. We might need to make it
# PersistentBroken
if not issubclass(klass, broken.PersistentBroken):
klass = broken.persistentBroken(klass)
try:
obj = klass.__new__(klass)
except TypeError:
# Couldn't create the instance. Maybe there's more
# current data in the object's actual record!
return self._conn.get(oid)
# TODO: should be done by connection
obj._p_oid = oid
obj._p_jar = self._conn
# When an object is created, it is put in the UPTODATE
# state. We must explicitly deactivate it to turn it into
# a ghost.
obj._p_changed = None
self._cache[oid] = obj
return obj
elif isinstance(oid, list):
# see weakref.py
[oid] = oid
obj = WeakRef.__new__(WeakRef)
obj.oid = oid
obj.dm = self._conn
return obj
obj = self._cache.get(oid, None)
if obj is not None:
return obj
return self._conn.get(oid)
def _new_object(self, klass, args): def _new_object(self, klass, args):
if not args and not myhasattr(klass, "__getnewargs__"): if not args and not myhasattr(klass, "__getnewargs__"):
obj = klass.__new__(klass) obj = klass.__new__(klass)
...@@ -407,97 +465,6 @@ class BaseObjectReader: ...@@ -407,97 +465,6 @@ class BaseObjectReader:
state = self.getState(pickle) state = self.getState(pickle)
obj.__setstate__(state) obj.__setstate__(state)
class ExternalReference(object):
pass
class SimpleObjectReader(BaseObjectReader):
"""Can be used to inspect a single object pickle.
It returns an ExternalReference() object for other persistent
objects. It can't instantiate the object.
"""
ext_ref = ExternalReference()
def _persistent_load(self, oid):
return self.ext_ref
def _get_class(self, module, name):
return None
class ConnectionObjectReader(BaseObjectReader):
def __init__(self, conn, cache, factory):
self._conn = conn
self._cache = cache
self._factory = factory
def _get_class(self, module, name):
return self._factory(self._conn, module, name)
def _get_unpickler(self, pickle):
unpickler = BaseObjectReader._get_unpickler(self, pickle)
factory = self._factory
conn = self._conn
def find_global(modulename, name):
return factory(conn, modulename, name)
unpickler.find_global = find_global
return unpickler
def _persistent_load(self, oid):
if isinstance(oid, tuple):
# Quick instance reference. We know all we need to know
# to create the instance w/o hitting the db, so go for it!
oid, klass = oid
obj = self._cache.get(oid, None)
if obj is not None:
return obj
if isinstance(klass, tuple):
klass = self._get_class(*klass)
if issubclass(klass, Broken):
# We got a broken class. We might need to make it
# PersistentBroken
if not issubclass(klass, broken.PersistentBroken):
klass = broken.persistentBroken(klass)
try:
obj = klass.__new__(klass)
except TypeError:
# Couldn't create the instance. Maybe there's more
# current data in the object's actual record!
return self._conn.get(oid)
# TODO: should be done by connection
obj._p_oid = oid
obj._p_jar = self._conn
# When an object is created, it is put in the UPTODATE
# state. We must explicitly deactivate it to turn it into
# a ghost.
obj._p_changed = None
self._cache[oid] = obj
return obj
elif isinstance(oid, list):
# see weakref.py
[oid] = oid
obj = WeakRef.__new__(WeakRef)
obj.oid = oid
obj.dm = self._conn
return obj
obj = self._cache.get(oid, None)
if obj is not None:
return obj
return self._conn.get(oid)
def referencesf(p, rootl=None): def referencesf(p, rootl=None):
if rootl is None: if rootl is None:
......
...@@ -39,6 +39,9 @@ def make_pickle(ob): ...@@ -39,6 +39,9 @@ def make_pickle(ob):
return sio.getvalue() return sio.getvalue()
def test_factory(conn, module_name, name):
return globals()[name]
class SerializerTestCase(unittest.TestCase): class SerializerTestCase(unittest.TestCase):
# old format: (module, name), None # old format: (module, name), None
...@@ -58,7 +61,7 @@ class SerializerTestCase(unittest.TestCase): ...@@ -58,7 +61,7 @@ class SerializerTestCase(unittest.TestCase):
(ClassWithNewargs, (1,))) (ClassWithNewargs, (1,)))
def test_getClassName(self): def test_getClassName(self):
r = serialize.BaseObjectReader() r = serialize.ObjectReader(factory=test_factory)
eq = self.assertEqual eq = self.assertEqual
eq(r.getClassName(self.old_style_with_newargs), eq(r.getClassName(self.old_style_with_newargs),
__name__ + ".ClassWithNewargs") __name__ + ".ClassWithNewargs")
...@@ -73,14 +76,14 @@ class SerializerTestCase(unittest.TestCase): ...@@ -73,14 +76,14 @@ class SerializerTestCase(unittest.TestCase):
# Use a TestObjectReader since we need _get_class() to be # Use a TestObjectReader since we need _get_class() to be
# implemented; otherwise this is just a BaseObjectReader. # implemented; otherwise this is just a BaseObjectReader.
class TestObjectReader(serialize.BaseObjectReader): class TestObjectReader(serialize.ObjectReader):
# A production object reader would optimize this, but we # A production object reader would optimize this, but we
# don't need to in a test # don't need to in a test
def _get_class(self, module, name): def _get_class(self, module, name):
__import__(module) __import__(module)
return getattr(sys.modules[module], name) return getattr(sys.modules[module], name)
r = TestObjectReader() r = TestObjectReader(factory=test_factory)
g = r.getGhost(self.old_style_with_newargs) g = r.getGhost(self.old_style_with_newargs)
self.assert_(isinstance(g, ClassWithNewargs)) self.assert_(isinstance(g, ClassWithNewargs))
self.assertEqual(g, 1) self.assertEqual(g, 1)
......
...@@ -46,11 +46,11 @@ class TestUtils(unittest.TestCase): ...@@ -46,11 +46,11 @@ class TestUtils(unittest.TestCase):
self.assertEquals(U64("\000\000\000\001\000\000\000\000"), 1L<<32) self.assertEquals(U64("\000\000\000\001\000\000\000\000"), 1L<<32)
def checkPersistentIdHandlesDescriptor(self): def checkPersistentIdHandlesDescriptor(self):
from ZODB.serialize import BaseObjectWriter from ZODB.serialize import ObjectWriter
class P(Persistent): class P(Persistent):
pass pass
writer = BaseObjectWriter(None) writer = ObjectWriter(None)
self.assertEqual(writer.persistent_id(P), None) self.assertEqual(writer.persistent_id(P), None)
# It's hard to know where to put this test. We're checking that the # It's hard to know where to put this test. We're checking that the
...@@ -59,13 +59,13 @@ class TestUtils(unittest.TestCase): ...@@ -59,13 +59,13 @@ class TestUtils(unittest.TestCase):
# the pickle (and so also trying to import application module and # the pickle (and so also trying to import application module and
# class objects, which isn't a good idea on a ZEO server when avoidable). # class objects, which isn't a good idea on a ZEO server when avoidable).
def checkConflictErrorDoesntImport(self): def checkConflictErrorDoesntImport(self):
from ZODB.serialize import BaseObjectWriter from ZODB.serialize import ObjectWriter
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
import cPickle as pickle import cPickle as pickle
obj = MinPO() obj = MinPO()
data = BaseObjectWriter().serialize(obj) data = ObjectWriter().serialize(obj)
# The pickle contains a GLOBAL ('c') opcode resolving to MinPO's # The pickle contains a GLOBAL ('c') opcode resolving to MinPO's
# module and class. # module and class.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment