Commit 48c4bba4 authored by Shane Hathaway's avatar Shane Hathaway

Changed object import so it happens in a subtransaction. This cleans up

the undo log.
parent eb0cfff2
...@@ -84,8 +84,8 @@ ...@@ -84,8 +84,8 @@
############################################################################## ##############################################################################
"""Database connection support """Database connection support
$Id: Connection.py,v 1.49 2001/04/02 14:54:54 chrism Exp $""" $Id: Connection.py,v 1.50 2001/04/14 23:16:44 shane Exp $"""
__version__='$Revision: 1.49 $'[11:-2] __version__='$Revision: 1.50 $'[11:-2]
from cPickleCache import PickleCache from cPickleCache import PickleCache
from POSException import ConflictError, ExportError from POSException import ConflictError, ExportError
...@@ -260,9 +260,20 @@ class Connection(ExportImport.ExportImport): ...@@ -260,9 +260,20 @@ class Connection(ExportImport.ExportImport):
# Return the connection to the pool. # Return the connection to the pool.
db._closeConnection(self) db._closeConnection(self)
__onCommitActions=()
def onCommitAction(self, method_name, *args, **kw):
self.__onCommitActions = self.__onCommitActions + (
(method_name, args, kw),)
get_transaction().register(self)
def commit(self, object, transaction, _type=type, _st=type('')): def commit(self, object, transaction, _type=type, _st=type('')):
if object is self: if object is self:
return # we registered ourself # We registered ourself. Execute a commit action, if any.
if self.__onCommitActions:
method_name, args, kw = self.__onCommitActions[0]
self.__onCommitActions = self.__onCommitActions[1:]
apply(getattr(self, method_name), (transaction,) + args, kw)
return
oid=object._p_oid oid=object._p_oid
invalid=self._invalid invalid=self._invalid
if oid is None or object._p_jar is not self: if oid is None or object._p_jar is not self:
...@@ -413,6 +424,7 @@ class Connection(ExportImport.ExportImport): ...@@ -413,6 +424,7 @@ class Connection(ExportImport.ExportImport):
def commit_sub(self, t, def commit_sub(self, t,
_type=type, _st=type(''), _None=None): _type=type, _st=type(''), _None=None):
"""Commit all work done in subtransactions"""
tmp=self._tmp tmp=self._tmp
if tmp is _None: return if tmp is _None: return
src=self._storage src=self._storage
...@@ -597,6 +609,7 @@ class Connection(ExportImport.ExportImport): ...@@ -597,6 +609,7 @@ class Connection(ExportImport.ExportImport):
raise raise
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
self.__onCommitActions = ()
self._storage.tpc_abort(transaction) self._storage.tpc_abort(transaction)
cache=self._cache cache=self._cache
cache.invalidate(self._invalidated) cache.invalidate(self._invalidated)
...@@ -622,6 +635,7 @@ class Connection(ExportImport.ExportImport): ...@@ -622,6 +635,7 @@ class Connection(ExportImport.ExportImport):
def tpc_vote(self, transaction, def tpc_vote(self, transaction,
_type=type, _st=type('')): _type=type, _st=type('')):
self.__onCommitActions = ()
try: vote=self._storage.tpc_vote try: vote=self._storage.tpc_vote
except: return except: return
s=vote(transaction) s=vote(transaction)
......
...@@ -142,22 +142,35 @@ class ExportImport: ...@@ -142,22 +142,35 @@ class ExportImport:
return customImporters[magic](self, file, clue) return customImporters[magic](self, file, clue)
raise POSException.ExportError, 'Invalid export header' raise POSException.ExportError, 'Invalid export header'
t=get_transaction().sub() t = get_transaction()
t.note('import into %s from %s' % (self.db().getName(), file_name))
if clue: t.note(clue) if clue: t.note(clue)
storage=self._storage return_oid_list = []
new_oid=storage.new_oid self.onCommitAction('_importDuringCommit', file, return_oid_list)
oids={} t.commit(1)
wrote_oid=oids.has_key # Return the root imported object.
new_oid=storage.new_oid if return_oid_list:
store=storage.store return self[return_oid_list[0]]
else:
return None
def _importDuringCommit(self, transaction, file, return_oid_list):
'''
Invoked by the transaction manager mid commit.
Appends one item, the OID of the first object created,
to return_oid_list.
'''
oids = {}
storage = self._storage
new_oid = storage.new_oid
store = storage.store
read = file.read
def persistent_load(ooid, def persistent_load(ooid,
Ghost=Ghost, StringType=StringType, Ghost=Ghost, StringType=StringType,
atoi=string.atoi, TupleType=type(()), atoi=string.atoi, TupleType=type(()),
oids=oids, wrote_oid=wrote_oid, new_oid=new_oid): oids=oids, wrote_oid=oids.has_key,
new_oid=storage.new_oid):
"Remap a persistent id to a new ID and create a ghost for it." "Remap a persistent id to a new ID and create a ghost for it."
...@@ -174,50 +187,41 @@ class ExportImport: ...@@ -174,50 +187,41 @@ class ExportImport:
Ghost.oid=oid Ghost.oid=oid
return Ghost return Ghost
version=self._version version = self._version
return_oid=None
while 1:
storage.tpc_begin(t) h=read(16)
try: if h==export_end_marker: break
while 1: if len(h) != 16:
h=read(16) raise POSException.ExportError, 'Truncated export file'
if h==export_end_marker: break l=u64(h[8:16])
if len(h) != 16: p=read(l)
raise POSException.ExportError, 'Truncated export file' if len(p) != l:
l=u64(h[8:16]) raise POSException.ExportError, 'Truncated export file'
p=read(l)
if len(p) != l: ooid=h[:8]
raise POSException.ExportError, 'Truncated export file' if oids:
oid=oids[ooid]
ooid=h[:8] if type(oid) is TupleType: oid=oid[0]
if oids: else:
oid=oids[ooid] oids[ooid] = oid = storage.new_oid()
if type(oid) is TupleType: oid=oid[0] return_oid_list.append(oid)
else:
oids[ooid]=return_oid=oid=new_oid() pfile=StringIO(p)
unpickler=Unpickler(pfile)
pfile=StringIO(p) unpickler.persistent_load=persistent_load
unpickler=Unpickler(pfile)
unpickler.persistent_load=persistent_load newp=StringIO()
pickler=Pickler(newp,1)
newp=StringIO() pickler.persistent_id=persistent_id
pickler=Pickler(newp,1)
pickler.persistent_id=persistent_id pickler.dump(unpickler.load())
pickler.dump(unpickler.load())
pickler.dump(unpickler.load()) p=newp.getvalue()
pickler.dump(unpickler.load()) plen=len(p)
p=newp.getvalue()
plen=len(p) store(oid, None, p, version, transaction)
store(oid, None, p, version, t)
except:
storage.tpc_abort(t)
raise
else:
storage.tpc_vote(t)
storage.tpc_finish(t)
if return_oid is not None: return self[return_oid]
StringType=type('') StringType=type('')
...@@ -233,6 +237,6 @@ export_end_marker='\377'*16 ...@@ -233,6 +237,6 @@ export_end_marker='\377'*16
class Ghost: pass class Ghost: pass
def persistent_id(object, Ghost=Ghost): def persistent_id(object, Ghost=Ghost):
if hasattr(object, '__class__') and object.__class__ is Ghost: if getattr(object, '__class__', None) is Ghost:
return object.oid return object.oid
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment