Commit a47596a7 authored by Jim Fulton's avatar Jim Fulton

Collapsed the serializer reader and writer class hierarchies.

parent d5ce0b90
......@@ -37,7 +37,7 @@ from ZODB.POSException \
import ConflictError, ReadConflictError, InvalidObjectReference, \
ConnectionStateError
from ZODB.TmpStore import TmpStore
from ZODB.serialize import ObjectWriter, ConnectionObjectReader, myhasattr
from ZODB.serialize import ObjectWriter, ObjectReader, myhasattr
from ZODB.utils import u64, oid_repr, z64, positive_id, \
DEPRECATED_ARGUMENT, deprecated36
......@@ -791,8 +791,7 @@ class Connection(ExportImport, object):
self._flush_invalidations()
if self._synch:
self._txn_mgr.registerSynch(self)
self._reader = ConnectionObjectReader(self, self._cache,
self._db.classFactory)
self._reader = ObjectReader(self, self._cache, self._db.classFactory)
# Multi-database support
self.connections = {self._db.database_name: self}
......
......@@ -49,7 +49,7 @@ persistent reference (see below) is used.
It's unclear what "usually" means in the last paragraph. There are two
useful places to concentrate confusion about exactly which formats exist:
- BaseObjectReader.getClassName() below returns a dotted "module.class"
- ObjectReader.getClassName() below returns a dotted "module.class"
string, via actually loading a pickle. This requires that the
implementation of application objects be available.
......@@ -114,7 +114,7 @@ def myhasattr(obj, name, _marker=object()):
return getattr(obj, name, _marker) is not _marker
class BaseObjectWriter:
class ObjectWriter:
"""Serializes objects for storage in the database.
The ObjectWriter creates object pickles in the ZODB format. It
......@@ -122,14 +122,18 @@ class BaseObjectWriter:
object.
"""
def __init__(self, jar=None):
_jar = None
def __init__(self, obj=None):
self._file = cStringIO.StringIO()
self._p = cPickle.Pickler(self._file, 1)
self._stack = []
self._p.persistent_id = self.persistent_id
if jar is not None:
self._stack = []
if obj is not None:
self._stack.append(obj)
jar = obj._p_jar
assert myhasattr(jar, "new_oid")
self._jar = jar
self._jar = jar
def persistent_id(self, obj):
"""Return the persistent id for obj.
......@@ -139,7 +143,9 @@ class BaseObjectWriter:
... def new_oid(self):
... return 42
>>> jar = DummyJar()
>>> writer = BaseObjectWriter(jar)
>>> class O:
... _p_jar = jar
>>> writer = ObjectWriter(O)
Normally, object references include the oid and a cached
reference to the class. Having the class available allows
......@@ -304,12 +310,6 @@ class BaseObjectWriter:
self._file.truncate()
return self._file.getvalue()
class ObjectWriter(BaseObjectWriter):
def __init__(self, obj):
BaseObjectWriter.__init__(self, obj._p_jar)
self._stack.append(obj)
def __iter__(self):
return NewObjectIterator(self._stack)
......@@ -331,22 +331,80 @@ class NewObjectIterator:
else:
raise StopIteration
class BaseObjectReader:
class ObjectReader:
def _persistent_load(self, oid):
# subclasses must define _persistent_load().
raise NotImplementedError
def __init__(self, conn=None, cache=None, factory=None):
self._conn = conn
self._cache = cache
self._factory = factory
def _get_class(self, module, name):
# subclasses must define _get_class()
raise NotImplementedError
return self._factory(self._conn, module, name)
def _get_unpickler(self, pickle):
file = cStringIO.StringIO(pickle)
unpickler = cPickle.Unpickler(file)
unpickler.persistent_load = self._persistent_load
factory = self._factory
conn = self._conn
def find_global(modulename, name):
return factory(conn, modulename, name)
unpickler.find_global = find_global
return unpickler
def _persistent_load(self, oid):
if isinstance(oid, tuple):
# Quick instance reference. We know all we need to know
# to create the instance w/o hitting the db, so go for it!
oid, klass = oid
obj = self._cache.get(oid, None)
if obj is not None:
return obj
if isinstance(klass, tuple):
klass = self._get_class(*klass)
if issubclass(klass, Broken):
# We got a broken class. We might need to make it
# PersistentBroken
if not issubclass(klass, broken.PersistentBroken):
klass = broken.persistentBroken(klass)
try:
obj = klass.__new__(klass)
except TypeError:
# Couldn't create the instance. Maybe there's more
# current data in the object's actual record!
return self._conn.get(oid)
# TODO: should be done by connection
obj._p_oid = oid
obj._p_jar = self._conn
# When an object is created, it is put in the UPTODATE
# state. We must explicitly deactivate it to turn it into
# a ghost.
obj._p_changed = None
self._cache[oid] = obj
return obj
elif isinstance(oid, list):
# see weakref.py
[oid] = oid
obj = WeakRef.__new__(WeakRef)
obj.oid = oid
obj.dm = self._conn
return obj
obj = self._cache.get(oid, None)
if obj is not None:
return obj
return self._conn.get(oid)
def _new_object(self, klass, args):
if not args and not myhasattr(klass, "__getnewargs__"):
obj = klass.__new__(klass)
......@@ -407,97 +465,6 @@ class BaseObjectReader:
state = self.getState(pickle)
obj.__setstate__(state)
class ExternalReference(object):
pass
class SimpleObjectReader(BaseObjectReader):
"""Can be used to inspect a single object pickle.
It returns an ExternalReference() object for other persistent
objects. It can't instantiate the object.
"""
ext_ref = ExternalReference()
def _persistent_load(self, oid):
return self.ext_ref
def _get_class(self, module, name):
return None
class ConnectionObjectReader(BaseObjectReader):
def __init__(self, conn, cache, factory):
self._conn = conn
self._cache = cache
self._factory = factory
def _get_class(self, module, name):
return self._factory(self._conn, module, name)
def _get_unpickler(self, pickle):
unpickler = BaseObjectReader._get_unpickler(self, pickle)
factory = self._factory
conn = self._conn
def find_global(modulename, name):
return factory(conn, modulename, name)
unpickler.find_global = find_global
return unpickler
def _persistent_load(self, oid):
if isinstance(oid, tuple):
# Quick instance reference. We know all we need to know
# to create the instance w/o hitting the db, so go for it!
oid, klass = oid
obj = self._cache.get(oid, None)
if obj is not None:
return obj
if isinstance(klass, tuple):
klass = self._get_class(*klass)
if issubclass(klass, Broken):
# We got a broken class. We might need to make it
# PersistentBroken
if not issubclass(klass, broken.PersistentBroken):
klass = broken.persistentBroken(klass)
try:
obj = klass.__new__(klass)
except TypeError:
# Couldn't create the instance. Maybe there's more
# current data in the object's actual record!
return self._conn.get(oid)
# TODO: should be done by connection
obj._p_oid = oid
obj._p_jar = self._conn
# When an object is created, it is put in the UPTODATE
# state. We must explicitly deactivate it to turn it into
# a ghost.
obj._p_changed = None
self._cache[oid] = obj
return obj
elif isinstance(oid, list):
# see weakref.py
[oid] = oid
obj = WeakRef.__new__(WeakRef)
obj.oid = oid
obj.dm = self._conn
return obj
obj = self._cache.get(oid, None)
if obj is not None:
return obj
return self._conn.get(oid)
def referencesf(p, rootl=None):
if rootl is None:
......
......@@ -39,6 +39,9 @@ def make_pickle(ob):
return sio.getvalue()
def test_factory(conn, module_name, name):
return globals()[name]
class SerializerTestCase(unittest.TestCase):
# old format: (module, name), None
......@@ -58,7 +61,7 @@ class SerializerTestCase(unittest.TestCase):
(ClassWithNewargs, (1,)))
def test_getClassName(self):
r = serialize.BaseObjectReader()
r = serialize.ObjectReader(factory=test_factory)
eq = self.assertEqual
eq(r.getClassName(self.old_style_with_newargs),
__name__ + ".ClassWithNewargs")
......@@ -73,14 +76,14 @@ class SerializerTestCase(unittest.TestCase):
# Use a TestObjectReader since we need _get_class() to be
# implemented; otherwise this is just a BaseObjectReader.
class TestObjectReader(serialize.BaseObjectReader):
class TestObjectReader(serialize.ObjectReader):
# A production object reader would optimize this, but we
# don't need to in a test
def _get_class(self, module, name):
__import__(module)
return getattr(sys.modules[module], name)
r = TestObjectReader()
r = TestObjectReader(factory=test_factory)
g = r.getGhost(self.old_style_with_newargs)
self.assert_(isinstance(g, ClassWithNewargs))
self.assertEqual(g, 1)
......
......@@ -46,11 +46,11 @@ class TestUtils(unittest.TestCase):
self.assertEquals(U64("\000\000\000\001\000\000\000\000"), 1L<<32)
def checkPersistentIdHandlesDescriptor(self):
from ZODB.serialize import BaseObjectWriter
from ZODB.serialize import ObjectWriter
class P(Persistent):
pass
writer = BaseObjectWriter(None)
writer = ObjectWriter(None)
self.assertEqual(writer.persistent_id(P), None)
# It's hard to know where to put this test. We're checking that the
......@@ -59,13 +59,13 @@ class TestUtils(unittest.TestCase):
# the pickle (and so also trying to import application module and
# class objects, which isn't a good idea on a ZEO server when avoidable).
def checkConflictErrorDoesntImport(self):
from ZODB.serialize import BaseObjectWriter
from ZODB.serialize import ObjectWriter
from ZODB.POSException import ConflictError
from ZODB.tests.MinPO import MinPO
import cPickle as pickle
obj = MinPO()
data = BaseObjectWriter().serialize(obj)
data = ObjectWriter().serialize(obj)
# The pickle contains a GLOBAL ('c') opcode resolving to MinPO's
# module and class.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment