Commit 245a8580 authored by Jim Fulton's avatar Jim Fulton

Refactored the zrpc implementation to:

- Most server methods now return data to clients more quickly by writing to
  client sockets immediately, rather than waiting for the asyncore
  select loop to get around to it.

- More clearly define client and server responsibilities. Machinery
  needed for just clients or just servers has been moved to the
  corresponding connection subclasses.

- Degeneralized "flags" argument to many methods. There's just one
  async flag.
parent 693d13fd
...@@ -1340,10 +1340,10 @@ class ClientStub: ...@@ -1340,10 +1340,10 @@ class ClientStub:
self.rpc.callAsyncNoPoll('invalidateTransaction', tid, args) self.rpc.callAsyncNoPoll('invalidateTransaction', tid, args)
def serialnos(self, arg): def serialnos(self, arg):
self.rpc.callAsync('serialnos', arg) self.rpc.callAsyncNoPoll('serialnos', arg)
def info(self, arg): def info(self, arg):
self.rpc.callAsync('info', arg) self.rpc.callAsyncNoPoll('info', arg)
def storeBlob(self, oid, serial, blobfilename): def storeBlob(self, oid, serial, blobfilename):
......
...@@ -56,3 +56,5 @@ class Connection: ...@@ -56,3 +56,5 @@ class Connection:
def callAsync(self, meth, *args): def callAsync(self, meth, *args):
print self.name, 'callAsync', meth, repr(args) print self.name, 'callAsync', meth, repr(args)
callAsyncNoPoll = callAsync
...@@ -78,9 +78,10 @@ will conflict. It will be blocked at the vote call. ...@@ -78,9 +78,10 @@ will conflict. It will be blocked at the vote call.
>>> zs2.storeBlobEnd(oid, serial, data, '1') >>> zs2.storeBlobEnd(oid, serial, data, '1')
>>> delay = zs2.vote('1') >>> delay = zs2.vote('1')
>>> def send_reply(id, reply): >>> class Sender:
... def send_reply(self, id, reply):
... print 'reply', id, reply ... print 'reply', id, reply
>>> delay.set_sender(1, send_reply, None) >>> delay.set_sender(1, Sender())
>>> logger = logging.getLogger('ZEO') >>> logger = logging.getLogger('ZEO')
>>> handler = logging.StreamHandler(sys.stdout) >>> handler = logging.StreamHandler(sys.stdout)
......
...@@ -30,7 +30,6 @@ from ZODB.loglevels import BLATHER, TRACE ...@@ -30,7 +30,6 @@ from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException import ZODB.POSException
REPLY = ".reply" # message name used for replies REPLY = ".reply" # message name used for replies
ASYNC = 1
exception_type_type = type(Exception) exception_type_type = type(Exception)
...@@ -180,34 +179,33 @@ class Delay: ...@@ -180,34 +179,33 @@ class Delay:
the mainloop from sending a response. the mainloop from sending a response.
""" """
def set_sender(self, msgid, send_reply, return_error): def set_sender(self, msgid, conn):
self.msgid = msgid self.msgid = msgid
self.send_reply = send_reply self.conn = conn
self.return_error = return_error
def reply(self, obj): def reply(self, obj):
self.send_reply(self.msgid, obj) self.conn.send_reply(self.msgid, obj)
def error(self, exc_info): def error(self, exc_info):
log("Error raised in delayed method", logging.ERROR, exc_info=True) log("Error raised in delayed method", logging.ERROR, exc_info=True)
self.return_error(self.msgid, 0, *exc_info[:2]) self.conn.return_error(self.msgid, *exc_info[:2])
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
self.ready = threading.Event() self.ready = threading.Event()
def set_sender(self, msgid, send_reply, return_error): def set_sender(self, *args):
Delay.set_sender(self, msgid, send_reply, return_error) Delay.set_sender(self, *args)
self.ready.set() self.ready.set()
def reply(self, obj): def reply(self, obj):
self.ready.wait() self.ready.wait()
Delay.reply(self, obj) self.conn.call_from_thread(self.conn.send_reply, self.msgid, obj)
def error(self, exc_info): def error(self, exc_info):
self.ready.wait() self.ready.wait()
Delay.error(self, exc_info) self.conn.call_from_thread(Delay.error, self, exc_info)
# PROTOCOL NEGOTIATION # PROTOCOL NEGOTIATION
# #
...@@ -304,9 +302,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -304,9 +302,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
client for that particular call. client for that particular call.
The protocol also supports asynchronous calls. The client does The protocol also supports asynchronous calls. The client does
not wait for a return value for an asynchronous call. The only not wait for a return value for an asynchronous call.
defined flag is ASYNC. If a method call message has the ASYNC
flag set, the server will raise an exception.
If a method call raises an Exception, the exception is propagated If a method call raises an Exception, the exception is propagated
back to the client via the REPLY message. The client side will back to the client via the REPLY message. The client side will
...@@ -428,15 +424,6 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -428,15 +424,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# The singleton dict is a socket map containing only this object. # The singleton dict is a socket map containing only this object.
self._singleton = {self._fileno: self} self._singleton = {self._fileno: self}
# msgid_lock guards access to msgid
self.msgid = 0
self.msgid_lock = threading.Lock()
# replies_cond is used to block when a synchronous call is
# waiting for a response
self.replies_cond = threading.Condition()
self.replies = {}
# waiting_for_reply is used internally to indicate whether # waiting_for_reply is used internally to indicate whether
# a call is in progress. setting a session key is deferred # a call is in progress. setting a session key is deferred
# until after the call returns. # until after the call returns.
...@@ -488,9 +475,6 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -488,9 +475,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.closed = True self.closed = True
self.__super_close() self.__super_close()
self.trigger.pull_trigger() self.trigger.pull_trigger()
self.replies_cond.acquire()
self.replies_cond.notifyAll()
self.replies_cond.release()
def register_object(self, obj): def register_object(self, obj):
"""Register obj as the true object to invoke methods on.""" """Register obj as the true object to invoke methods on."""
...@@ -537,29 +521,19 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -537,29 +521,19 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# will raise an exception. The exception will ultimately # will raise an exception. The exception will ultimately
# result in asycnore calling handle_error(), which will # result in asycnore calling handle_error(), which will
# close the connection. # close the connection.
msgid, flags, name, args = self.marshal.decode(message) msgid, async, name, args = self.marshal.decode(message)
if debug_zrpc: if debug_zrpc:
self.log("recv msg: %s, %s, %s, %s" % (msgid, flags, name, self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name,
short_repr(args)), short_repr(args)),
level=TRACE) level=TRACE)
if name == REPLY: if name == REPLY:
self.handle_reply(msgid, flags, args) assert not async
self.handle_reply(msgid, args)
else: else:
self.handle_request(msgid, flags, name, args) self.handle_request(msgid, async, name, args)
def handle_reply(self, msgid, flags, args): def handle_request(self, msgid, async, name, args):
if debug_zrpc:
self.log("recv reply: %s, %s, %s"
% (msgid, flags, short_repr(args)), level=TRACE)
self.replies_cond.acquire()
try:
self.replies[msgid] = flags, args
self.replies_cond.notifyAll()
finally:
self.replies_cond.release()
def handle_request(self, msgid, flags, name, args):
obj = self.obj obj = self.obj
if name.startswith('_') or not hasattr(obj, name): if name.startswith('_') or not hasattr(obj, name):
...@@ -590,9 +564,14 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -590,9 +564,14 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.log("%s() raised exception: %s" % (name, msg), self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True) logging.ERROR, exc_info=True)
error = sys.exc_info()[:2] error = sys.exc_info()[:2]
return self.return_error(msgid, flags, *error) if async:
self.log("Asynchronous call raised exception: %s" % self,
level=logging.ERROR, exc_info=True)
else:
self.return_error(msgid, *error)
return
if flags & ASYNC: if async:
if ret is not None: if ret is not None:
raise ZRPCError("async method %s returned value %s" % raise ZRPCError("async method %s returned value %s" %
(name, short_repr(ret))) (name, short_repr(ret)))
...@@ -601,43 +580,19 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -601,43 +580,19 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.log("%s returns %s" % (name, short_repr(ret)), self.log("%s returns %s" % (name, short_repr(ret)),
logging.DEBUG) logging.DEBUG)
if isinstance(ret, Delay): if isinstance(ret, Delay):
ret.set_sender(msgid, self.send_reply, self.return_error) ret.set_sender(msgid, self)
else: else:
self.send_reply(msgid, ret) self.send_reply(msgid, ret, not self.delay_sesskey)
if self.delay_sesskey: if self.delay_sesskey:
self.__super_setSessionKey(self.delay_sesskey) self.__super_setSessionKey(self.delay_sesskey)
self.delay_sesskey = None self.delay_sesskey = None
def handle_error(self): def return_error(self, msgid, err_type, err_value):
if sys.exc_info()[0] == SystemExit: # Note that, ideally, this should be defined soley for
raise sys.exc_info() # servers, but a test arranges to get it called on
self.log("Error caught in asyncore", # a client. Too much trouble to fix it now. :/
level=logging.ERROR, exc_info=True)
self.close()
def send_reply(self, msgid, ret):
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.marshal.encode(msgid, 0, REPLY, ret)
except: # see above
try:
r = short_repr(ret)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
self.poll()
def return_error(self, msgid, flags, err_type, err_value):
if flags & ASYNC:
self.log("Asynchronous call raised exception: %s" % self,
level=logging.ERROR, exc_info=True)
return
if not isinstance(err_value, Exception): if not isinstance(err_value, Exception):
err_value = err_type, err_value err_value = err_type, err_value
...@@ -657,79 +612,37 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -657,79 +612,37 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.message_output(msg) self.message_output(msg)
self.poll() self.poll()
def handle_error(self):
if sys.exc_info()[0] == SystemExit:
raise sys.exc_info()
self.log("Error caught in asyncore",
level=logging.ERROR, exc_info=True)
self.close()
def setSessionKey(self, key): def setSessionKey(self, key):
if self.waiting_for_reply: if self.waiting_for_reply:
self.delay_sesskey = key self.delay_sesskey = key
else: else:
self.__super_setSessionKey(key) self.__super_setSessionKey(key)
# The next two public methods (call and callAsync) are used by def send_call(self, method, args, async=False):
# clients to invoke methods on remote objects
def __new_msgid(self):
self.msgid_lock.acquire()
try:
msgid = self.msgid
self.msgid = self.msgid + 1
return msgid
finally:
self.msgid_lock.release()
def __call_message(self, method, args, flags):
# compute a message and return it
msgid = self.__new_msgid()
if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
level=TRACE)
return self.marshal.encode(msgid, flags, method, args)
def send_call(self, method, args, flags):
# send a message and return its msgid # send a message and return its msgid
msgid = self.__new_msgid() if async:
msgid = 0
else:
msgid = self._new_msgid()
if debug_zrpc: if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method), self.log("send msg: %d, %d, %s, ..." % (msgid, async, method),
level=TRACE) level=TRACE)
buf = self.marshal.encode(msgid, flags, method, args) buf = self.marshal.encode(msgid, async, method, args)
self.message_output(buf) self.message_output(buf)
return msgid return msgid
def call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args, 0)
r_flags, r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
# For testing purposes, it is useful to begin a synchronous call
# but not block waiting for its response.
def _deferred_call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args, 0)
self.trigger.pull_trigger()
return msgid
def _deferred_wait(self, msgid):
r_flags, r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def callAsync(self, method, *args): def callAsync(self, method, *args):
if self.closed: if self.closed:
raise DisconnectedError() raise DisconnectedError()
self.send_call(method, args, ASYNC) self.send_call(method, args, 1)
self.poll() self.poll()
def callAsyncNoPoll(self, method, *args): def callAsyncNoPoll(self, method, *args):
...@@ -738,7 +651,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -738,7 +651,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# allowing any client to sneak in a load request. # allowing any client to sneak in a load request.
if self.closed: if self.closed:
raise DisconnectedError() raise DisconnectedError()
self.send_call(method, args, ASYNC) self.send_call(method, args, 1)
def callAsyncIterator(self, iterator): def callAsyncIterator(self, iterator):
"""Queue a sequence of calls using an iterator """Queue a sequence of calls using an iterator
...@@ -746,46 +659,11 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -746,46 +659,11 @@ class Connection(smac.SizedMessageAsyncConnection, object):
The calls will not be interleaved with other calls from the same The calls will not be interleaved with other calls from the same
client. client.
""" """
self.message_output(self.__outputIterator(iterator)) self.message_output(self.marshal.encode(0, 1, method, args)
for method, args in iterator)
def __outputIterator(self, iterator):
for method, args in iterator:
yield self.__call_message(method, args, ASYNC)
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if debug_zrpc:
self.log("wait(%d)" % msgid, level=TRACE)
self.trigger.pull_trigger()
# Delay used when we call asyncore.poll() directly.
# Start with a 1 msec delay, double until 1 sec.
delay = 0.001
self.replies_cond.acquire()
try:
while 1:
if self.closed:
raise DisconnectedError()
reply = self.replies.get(msgid)
if reply is not None:
del self.replies[msgid]
if debug_zrpc:
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=TRACE)
return reply
self.replies_cond.wait()
finally:
self.replies_cond.release()
def flush(self): def handle_reply(self, msgid, ret):
"""Invoke poll() until the output buffer is empty.""" assert msgid == -1 and ret is None
if debug_zrpc:
self.log("flush")
while self.writable():
self.poll()
def poll(self): def poll(self):
"""Invoke asyncore mainloop to get pending message out.""" """Invoke asyncore mainloop to get pending message out."""
...@@ -794,7 +672,6 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -794,7 +672,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.trigger.pull_trigger() self.trigger.pull_trigger()
class ManagedServerConnection(Connection): class ManagedServerConnection(Connection):
"""Server-side Connection subclass.""" """Server-side Connection subclass."""
...@@ -803,6 +680,7 @@ class ManagedServerConnection(Connection): ...@@ -803,6 +680,7 @@ class ManagedServerConnection(Connection):
# Servers use a shared server trigger that uses the asyncore socket map # Servers use a shared server trigger that uses the asyncore socket map
trigger = trigger() trigger = trigger()
call_from_thread = trigger.pull_trigger
def __init__(self, sock, addr, obj, mgr): def __init__(self, sock, addr, obj, mgr):
self.mgr = mgr self.mgr = mgr
...@@ -821,13 +699,33 @@ class ManagedServerConnection(Connection): ...@@ -821,13 +699,33 @@ class ManagedServerConnection(Connection):
self.obj.notifyDisconnected() self.obj.notifyDisconnected()
Connection.close(self) Connection.close(self)
def send_reply(self, msgid, ret, immediately=True):
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.marshal.encode(msgid, 0, REPLY, ret)
except: # see above
try:
r = short_repr(ret)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
if immediately:
self.poll()
poll = smac.SizedMessageAsyncConnection.handle_write
class ManagedClientConnection(Connection): class ManagedClientConnection(Connection):
"""Client-side Connection subclass.""" """Client-side Connection subclass."""
__super_init = Connection.__init__ __super_init = Connection.__init__
__super_close = Connection.close
base_message_output = Connection.message_output base_message_output = Connection.message_output
trigger = client_trigger trigger = client_trigger
call_from_thread = trigger.pull_trigger
def __init__(self, sock, addr, mgr): def __init__(self, sock, addr, mgr):
self.mgr = mgr self.mgr = mgr
...@@ -846,9 +744,24 @@ class ManagedClientConnection(Connection): ...@@ -846,9 +744,24 @@ class ManagedClientConnection(Connection):
self.queue_output = True self.queue_output = True
self.queued_messages = [] self.queued_messages = []
# msgid_lock guards access to msgid
self.msgid = 0
self.msgid_lock = threading.Lock()
# replies_cond is used to block when a synchronous call is
# waiting for a response
self.replies_cond = threading.Condition()
self.replies = {}
self.__super_init(sock, addr, None, tag='C', map=client_map) self.__super_init(sock, addr, None, tag='C', map=client_map)
client_trigger.pull_trigger() client_trigger.pull_trigger()
def close(self):
Connection.close(self)
self.replies_cond.acquire()
self.replies_cond.notifyAll()
self.replies_cond.release()
# Our message_ouput() queues messages until recv_handshake() gets the # Our message_ouput() queues messages until recv_handshake() gets the
# protocol handshake from the server. # protocol handshake from the server.
def message_output(self, message): def message_output(self, message):
...@@ -890,3 +803,88 @@ class ManagedClientConnection(Connection): ...@@ -890,3 +803,88 @@ class ManagedClientConnection(Connection):
self.queue_output = False self.queue_output = False
finally: finally:
self.output_lock.release() self.output_lock.release()
def _new_msgid(self):
self.msgid_lock.acquire()
try:
msgid = self.msgid
self.msgid = self.msgid + 1
return msgid
finally:
self.msgid_lock.release()
def call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if debug_zrpc:
self.log("wait(%d)" % msgid, level=TRACE)
self.trigger.pull_trigger()
# Delay used when we call asyncore.poll() directly.
# Start with a 1 msec delay, double until 1 sec.
delay = 0.001
self.replies_cond.acquire()
try:
while 1:
if self.closed:
raise DisconnectedError()
reply = self.replies.get(msgid, self)
if reply is not self:
del self.replies[msgid]
if debug_zrpc:
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=TRACE)
return reply
self.replies_cond.wait()
finally:
self.replies_cond.release()
# For testing purposes, it is useful to begin a synchronous call
# but not block waiting for its response.
def _deferred_call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
self.trigger.pull_trigger()
return msgid
def _deferred_wait(self, msgid):
r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def handle_reply(self, msgid, args):
if debug_zrpc:
self.log("recv reply: %s, %s"
% (msgid, short_repr(args)), level=TRACE)
self.replies_cond.acquire()
try:
self.replies[msgid] = args
self.replies_cond.notifyAll()
finally:
self.replies_cond.release()
def send_reply(self, msgid, ret):
# Whimper. Used to send heartbeat
assert msgid == -1 and ret is None
self.message_output('(J\xff\xff\xff\xffK\x00U\x06.replyNt.')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment