Commit a9ebb196 authored by Jim Fulton's avatar Jim Fulton

Provide shorter code path for loads, which are most common operation.

Simplified and optimized marshalling code.
parent a5ad6b94
...@@ -22,7 +22,7 @@ import logging ...@@ -22,7 +22,7 @@ import logging
import ZEO.ServerStub import ZEO.ServerStub
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZEO.zrpc.marshal import Marshaller from ZEO.zrpc.marshal import encode
from ZEO.tests import forker from ZEO.tests import forker
from ZODB.DB import DB from ZODB.DB import DB
...@@ -473,7 +473,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -473,7 +473,7 @@ class ConnectionTests(CommonSetupTearDown):
class Hack: class Hack:
pass pass
msg = Marshaller().encode(1, 0, "foo", (Hack(),)) msg = encode(1, 0, "foo", (Hack(),))
self._bad_message(msg) self._bad_message(msg)
del Hack del Hack
......
...@@ -15,12 +15,12 @@ import asyncore ...@@ -15,12 +15,12 @@ import asyncore
import sys import sys
import threading import threading
import logging import logging
import ZEO.zrpc.marshal
import ZEO.zrpc.trigger import ZEO.zrpc.trigger
from ZEO.zrpc import smac from ZEO.zrpc import smac
from ZEO.zrpc.error import ZRPCError, DisconnectedError from ZEO.zrpc.error import ZRPCError, DisconnectedError
from ZEO.zrpc.marshal import Marshaller, ServerMarshaller
from ZEO.zrpc.log import short_repr, log from ZEO.zrpc.log import short_repr, log
from ZODB.loglevels import BLATHER, TRACE from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException import ZODB.POSException
...@@ -282,7 +282,10 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -282,7 +282,10 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# our peer. # our peer.
def __init__(self, sock, addr, obj, tag, map=None): def __init__(self, sock, addr, obj, tag, map=None):
self.obj = None self.obj = None
self.marshal = Marshaller() self.decode = ZEO.zrpc.marshal.decode
self.encode = ZEO.zrpc.marshal.encode
self.fast_encode = ZEO.zrpc.marshal.fast_encode
self.closed = False self.closed = False
self.peer_protocol_version = None # set in recv_handshake() self.peer_protocol_version = None # set in recv_handshake()
...@@ -408,13 +411,34 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -408,13 +411,34 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# will raise an exception. The exception will ultimately # will raise an exception. The exception will ultimately
# result in asycnore calling handle_error(), which will # result in asycnore calling handle_error(), which will
# close the connection. # close the connection.
msgid, async, name, args = self.marshal.decode(message) msgid, async, name, args = self.decode(message)
if debug_zrpc: if debug_zrpc:
self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name, self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name,
short_repr(args)), short_repr(args)),
level=TRACE) level=TRACE)
if name == REPLY:
if name == 'loadEx':
# Special case and inline the heck out of load case:
try:
ret = self.obj.loadEx(*args)
except (SystemExit, KeyboardInterrupt):
raise
except Exception, msg:
if not isinstance(msg, self.unlogged_exception_types):
self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True)
self.return_error(msgid, *sys.exc_info()[:2])
else:
try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll()
except:
# Fall back to normal version for better error handling
self.send_reply(msgid, ret)
elif name == REPLY:
assert not async assert not async
self.handle_reply(msgid, args) self.handle_reply(msgid, args)
else: else:
...@@ -488,14 +512,14 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -488,14 +512,14 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# it's acceptable -- we really do want to catch every exception # it's acceptable -- we really do want to catch every exception
# cPickle may raise. # cPickle may raise.
try: try:
msg = self.marshal.encode(msgid, 0, REPLY, (err_type, err_value)) msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above except: # see above
try: try:
r = short_repr(err_value) r = short_repr(err_value)
except: except:
r = "<unreprable>" r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r) err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err)) msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg) self.message_output(msg)
self.poll() self.poll()
...@@ -522,7 +546,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -522,7 +546,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
if debug_zrpc: if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, async, method), self.log("send msg: %d, %d, %s, ..." % (msgid, async, method),
level=TRACE) level=TRACE)
buf = self.marshal.encode(msgid, async, method, args) buf = self.encode(msgid, async, method, args)
self.message_output(buf) self.message_output(buf)
return msgid return msgid
...@@ -555,7 +579,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -555,7 +579,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
The calls will not be interleaved with other calls from the same The calls will not be interleaved with other calls from the same
client. client.
""" """
self.message_output(self.marshal.encode(0, 1, method, args) self.message_output(self.encode(0, 1, method, args)
for method, args in iterator) for method, args in iterator)
def handle_reply(self, msgid, ret): def handle_reply(self, msgid, ret):
...@@ -568,6 +592,8 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -568,6 +592,8 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.trigger.pull_trigger() self.trigger.pull_trigger()
# import cProfile, time
class ManagedServerConnection(Connection): class ManagedServerConnection(Connection):
"""Server-side Connection subclass.""" """Server-side Connection subclass."""
...@@ -578,7 +604,9 @@ class ManagedServerConnection(Connection): ...@@ -578,7 +604,9 @@ class ManagedServerConnection(Connection):
self.mgr = mgr self.mgr = mgr
map = {} map = {}
Connection.__init__(self, sock, addr, obj, 'S', map=map) Connection.__init__(self, sock, addr, obj, 'S', map=map)
self.marshal = ServerMarshaller()
self.decode = ZEO.zrpc.marshal.server_decode
self.trigger = ZEO.zrpc.trigger.trigger(map) self.trigger = ZEO.zrpc.trigger.trigger(map)
self.call_from_thread = self.trigger.pull_trigger self.call_from_thread = self.trigger.pull_trigger
...@@ -586,6 +614,15 @@ class ManagedServerConnection(Connection): ...@@ -586,6 +614,15 @@ class ManagedServerConnection(Connection):
t.setDaemon(True) t.setDaemon(True)
t.start() t.start()
# self.profile = cProfile.Profile()
# def message_input(self, message):
# self.profile.enable()
# try:
# Connection.message_input(self, message)
# finally:
# self.profile.disable()
def handshake(self): def handshake(self):
# Send the server's preferred protocol to the client. # Send the server's preferred protocol to the client.
self.message_output(self.current_protocol) self.message_output(self.current_protocol)
...@@ -597,6 +634,7 @@ class ManagedServerConnection(Connection): ...@@ -597,6 +634,7 @@ class ManagedServerConnection(Connection):
def close(self): def close(self):
self.obj.notifyDisconnected() self.obj.notifyDisconnected()
Connection.close(self) Connection.close(self)
# self.profile.dump_stats(str(time.time())+'.stats')
def send_reply(self, msgid, ret, immediately=True): def send_reply(self, msgid, ret, immediately=True):
# encode() can pass on a wide variety of exceptions from cPickle. # encode() can pass on a wide variety of exceptions from cPickle.
...@@ -604,14 +642,14 @@ class ManagedServerConnection(Connection): ...@@ -604,14 +642,14 @@ class ManagedServerConnection(Connection):
# it's acceptable -- we really do want to catch every exception # it's acceptable -- we really do want to catch every exception
# cPickle may raise. # cPickle may raise.
try: try:
msg = self.marshal.encode(msgid, 0, REPLY, ret) msg = self.encode(msgid, 0, REPLY, ret)
except: # see above except: # see above
try: try:
r = short_repr(ret) r = short_repr(ret)
except: except:
r = "<unreprable>" r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r) err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err)) msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg) self.message_output(msg)
if immediately: if immediately:
self.poll() self.poll()
......
...@@ -11,60 +11,65 @@ ...@@ -11,60 +11,65 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
import cPickle from cPickle import Unpickler, Pickler
from cStringIO import StringIO from cStringIO import StringIO
import logging import logging
from ZEO.zrpc.error import ZRPCError from ZEO.zrpc.error import ZRPCError
from ZEO.zrpc.log import log, short_repr from ZEO.zrpc.log import log, short_repr
class Marshaller: def encode(*args): # args: (msgid, flags, name, args)
"""Marshal requests and replies to second across network""" # (We used to have a global pickler, but that's not thread-safe. :-( )
def encode(self, msgid, flags, name, args): # It's not thread safe if, in the couse of pickling, we call the
"""Returns an encoded message""" # Python interpeter, which releases the GIL.
# (We used to have a global pickler, but that's not thread-safe. :-( )
# Note that args may contain very large binary pickles already; for # Note that args may contain very large binary pickles already; for
# this reason, it's important to use proto 1 (or higher) pickles here # this reason, it's important to use proto 1 (or higher) pickles here
# too. For a long time, this used proto 0 pickles, and that can # too. For a long time, this used proto 0 pickles, and that can
# bloat our pickle to 4x the size (due to high-bit and control bytes # bloat our pickle to 4x the size (due to high-bit and control bytes
# being represented by \xij escapes in proto 0). # being represented by \xij escapes in proto 0).
# Undocumented: cPickle.Pickler accepts a lone protocol argument; # Undocumented: cPickle.Pickler accepts a lone protocol argument;
# pickle.py does not. # pickle.py does not.
pickler = cPickle.Pickler(1) pickler = Pickler(1)
pickler.fast = 1 pickler.fast = 1
return pickler.dump(args, 1)
# Undocumented: pickler.dump(), for a cPickle.Pickler, takes
# an optional boolean argument. When true, it returns the pickle;
# when false or unspecified, it returns the pickler object itself. @apply
# pickle.py does none of this. def fast_encode():
return pickler.dump((msgid, flags, name, args), 1) # Only use in cases where you *know* the data contains only basic
# Python objects
def decode(self, msg): pickler = Pickler(1)
"""Decodes msg and returns its parts""" pickler.fast = 1
unpickler = cPickle.Unpickler(StringIO(msg)) dump = pickler.dump
unpickler.find_global = find_global def fast_encode(*args):
return dump(args, 1)
try: return fast_encode
return unpickler.load() # msgid, flags, name, args
except: def decode(msg):
log("can't decode message: %s" % short_repr(msg), """Decodes msg and returns its parts"""
level=logging.ERROR) unpickler = Unpickler(StringIO(msg))
raise unpickler.find_global = find_global
class ServerMarshaller(Marshaller): try:
return unpickler.load() # msgid, flags, name, args
def decode(self, msg): except:
"""Decodes msg and returns its parts""" log("can't decode message: %s" % short_repr(msg),
unpickler = cPickle.Unpickler(StringIO(msg)) level=logging.ERROR)
unpickler.find_global = server_find_global raise
try: def server_decode(msg):
return unpickler.load() # msgid, flags, name, args """Decodes msg and returns its parts"""
except: unpickler = Unpickler(StringIO(msg))
log("can't decode message: %s" % short_repr(msg), unpickler.find_global = server_find_global
level=logging.ERROR)
raise try:
return unpickler.load() # msgid, flags, name, args
except:
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
_globals = globals() _globals = globals()
_silly = ('__doc__',) _silly = ('__doc__',)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment