Commit 9aed0252 authored by Guillaume Hervier's avatar Guillaume Hervier

Implement threading for diffs checking on restore

parent b0371fcc
......@@ -149,6 +149,7 @@ def init_connection(remote_cmd):
stdin, stdout = os.popen2(remote_cmd)
conn_number = len(Globals.connections)
conn = connection.PipeConnection(stdout, stdin, conn_number)
conn.Client()
check_connection_version(conn, remote_cmd)
Log("Registering connection %d" % conn_number, 7)
......
......@@ -22,7 +22,8 @@
from __future__ import generators
import types, os, tempfile, cPickle, shutil, traceback, \
socket, sys, gzip
socket, sys, gzip, threading
from pool import Pool
# The following EA and ACL modules may be used if available
try: import xattr
except ImportError: pass
......@@ -115,6 +116,8 @@ class LowLevelPipeConnection(Connection):
"""inpipe is a file-type open for reading, outpipe for writing"""
self.inpipe = inpipe
self.outpipe = outpipe
self.write_lock = threading.RLock()
self.read_lock = threading.RLock()
def __str__(self):
"""Return string version
......@@ -128,6 +131,8 @@ class LowLevelPipeConnection(Connection):
def _put(self, obj, req_num):
"""Put an object into the pipe (will send raw if string)"""
self.write_lock.acquire()
log.Log.conn("sending", obj, req_num)
if type(obj) is types.StringType: self._putbuf(obj, req_num)
elif isinstance(obj, connection.Connection):self._putconn(obj, req_num)
......@@ -140,6 +145,8 @@ class LowLevelPipeConnection(Connection):
elif hasattr(obj, "next"): self._putiter(obj, req_num)
else: self._putobj(obj, req_num)
self.write_lock.release()
def _putobj(self, obj, req_num):
"""Send a generic python obj down the outpipe"""
self._write("o", cPickle.dumps(obj, 1), req_num)
......@@ -229,7 +236,11 @@ class LowLevelPipeConnection(Connection):
def _get(self):
"""Read an object from the pipe and return (req_num, value)"""
self.read_lock.acquire()
header_string = self.inpipe.read(9)
if len(header_string) == 0:
raise ConnectionQuit('EOF')
if not len(header_string) == 9:
raise ConnectionReadError("Truncated header string (problem "
"probably originated remotely)")
......@@ -251,6 +262,8 @@ class LowLevelPipeConnection(Connection):
assert format_string == "c", header_string
result = Globals.connection_dict[int(data)]
log.Log.conn("received", result, req_num)
self.read_lock.release()
return (req_num, result)
def _getrorpath(self, raw_rorpath_buf):
......@@ -276,6 +289,53 @@ class LowLevelPipeConnection(Connection):
self.inpipe.close()
class RequestNumberRegistry(object):
def __init__(self):
self._lock = threading.RLock()
self._next = 0
self._entries = {}
def get(self):
with self._lock:
if self._next >= 256:
return None
req_num = self._next
self.insert(req_num)
return req_num
def insert(self, req_num):
with self._lock:
if req_num in self._entries:
# Vacant slot
self._next = self._entries[req_num]
else:
self._next += 1
def remove(self, req_num):
with self._lock:
self._entries[req_num] = self._next
self._next = req_num
class AsyncRequest(object):
def __init__(self, req_num):
self.req_num = req_num
self.value = None
self.completed = threading.Event()
def set(self, value):
self.value = value
self.completed.set()
def get(self):
while not self.completed.is_set():
self.completed.wait()
return self.value
class PipeConnection(LowLevelPipeConnection):
"""Provide server and client functions for a Pipe Connection
......@@ -287,6 +347,17 @@ class PipeConnection(LowLevelPipeConnection):
client makes the first request, and the server listens first.
"""
DISCARDED_RESULTS_FUNCTIONS = [
'log.Log.log_to_file',
'log.Log.close_logfile_allconn',
'rpath.setdata_local',
'Globals.set',
]
RUN_ON_MAIN_THREAD = [
'robust.install_signal_handlers',
]
def __init__(self, inpipe, outpipe, conn_number = 0):
"""Init PipeConnection
......@@ -298,45 +369,46 @@ class PipeConnection(LowLevelPipeConnection):
"""
LowLevelPipeConnection.__init__(self, inpipe, outpipe)
self.conn_number = conn_number
self.unused_request_numbers = {}
for i in range(256): self.unused_request_numbers[i] = None
self.request_numbers = RequestNumberRegistry()
self.requests = {}
self.pool = Pool(processes=4,
max_taskqueue_size=16)
self._read_thread = None
def __str__(self): return "PipeConnection %d" % self.conn_number
def get_response(self, desired_req_num):
"""Read from pipe, responding to requests until req_num.
Sometimes after a request is sent, the other side will make
another request before responding to the original one. In
that case, respond to the request. But return once the right
response is given.
"""
while 1:
try: req_num, object = self._get()
def read_messages(self):
while True:
try:
req_num, obj = self._get()
except ConnectionQuit:
self._put("quitting", self.get_new_req_num())
self._close()
return
if req_num == desired_req_num: return object
else:
assert isinstance(object, ConnectionRequest)
self.answer_request(object, req_num)
def answer_request(self, request, req_num):
"""Put the object requested by request down the pipe"""
del self.unused_request_numbers[req_num]
argument_list = []
for i in range(request.num_args):
arg_req_num, arg = self._get()
assert arg_req_num == req_num
argument_list.append(arg)
break
if isinstance(obj, ConnectionRequest):
args = []
for _ in range(obj.num_args):
arg_req_num, arg = self._get()
assert arg_req_num == req_num
args.append(arg)
if Globals.server and obj.function_string in self.RUN_ON_MAIN_THREAD:
self.answer_request(obj, args, req_num)
else:
self.pool.apply(self.answer_request, obj, args, req_num)
elif req_num in self.requests:
req = self.requests.pop(req_num)
req.set(obj)
self.request_numbers.remove(req_num)
def answer_request(self, request, args, req_num):
try:
Security.vet_request(request, argument_list)
result = apply(eval(request.function_string), argument_list)
Security.vet_request(request, args)
result = apply(eval(request.function_string), args)
except: result = self.extract_exception()
self._put(result, req_num)
self.unused_request_numbers[req_num] = None
if request.function_string not in self.DISCARDED_RESULTS_FUNCTIONS:
self._put(result, req_num)
def extract_exception(self):
"""Return active exception"""
......@@ -348,12 +420,24 @@ class PipeConnection(LowLevelPipeConnection):
"".join(traceback.format_tb(sys.exc_info()[2]))), 5)
return sys.exc_info()[1]
def Client(self):
self._read_thread = read_thread = threading.Thread(target=self.read_messages)
read_thread.daemon = True
read_thread.start()
def Server(self):
"""Start server's read eval return loop"""
Globals.server = 1
Globals.connections.append(self)
log.Log("Starting server", 6)
self.get_response(-1)
# self.get_response(-1)
self.read_messages()
def new_request(self):
req_num = self.get_new_req_num()
req = AsyncRequest(req_num)
self.requests[req_num] = req
return req
def reval(self, function_string, *args):
"""Execute command on remote side
......@@ -363,11 +447,20 @@ class PipeConnection(LowLevelPipeConnection):
function.
"""
req_num = self.get_new_req_num()
self._put(ConnectionRequest(function_string, len(args)), req_num)
for arg in args: self._put(arg, req_num)
result = self.get_response(req_num)
self.unused_request_numbers[req_num] = None
req = self.new_request()
self.write_lock.acquire()
self._put(ConnectionRequest(function_string, len(args)), req.req_num)
for arg in args: self._put(arg, req.req_num)
self.write_lock.release()
if function_string in self.DISCARDED_RESULTS_FUNCTIONS:
result = None
del self.requests[req.req_num]
self.request_numbers.remove(req.req_num)
else:
result = req.get()
if isinstance(result, Exception): raise result
elif isinstance(result, SystemExit): raise result
elif isinstance(result, KeyboardInterrupt): raise result
......@@ -375,18 +468,19 @@ class PipeConnection(LowLevelPipeConnection):
def get_new_req_num(self):
"""Allot a new request number and return it"""
if not self.unused_request_numbers:
req_num = self.request_numbers.get()
if req_num is None:
raise ConnectionError("Exhaused possible connection numbers")
req_num = self.unused_request_numbers.keys()[0]
del self.unused_request_numbers[req_num]
return req_num
def quit(self):
"""Close the associated pipes and tell server side to quit"""
assert not Globals.server
self._putquit()
self._get()
self._close()
if self._read_thread is not None:
self._read_thread.join()
self.pool.stop()
self.pool.join()
def __getattr__(self, name):
"""Intercept attributes to allow for . invocation"""
......
......@@ -127,7 +127,7 @@ class IterVirtualFile(UnwrapFile):
return_val = self.buffer[:real_len]
self.buffer = self.buffer[real_len:]
return return_val
def addtobuffer(self):
"""Read a chunk from the file and add it to the buffer"""
assert self.iwf.currently_in_file
......@@ -335,6 +335,8 @@ class MiscIterToFile(FileWrappingIter):
elif currentobj is iterfile.MiscIterFlushRepeat:
self.add_misc(currentobj)
return None
elif isinstance(currentobj, rpath.RPath):
self.addrpath(currentobj)
elif isinstance(currentobj, rpath.RORPath):
self.addrorp(currentobj)
else: self.add_misc(currentobj)
......@@ -358,7 +360,19 @@ class MiscIterToFile(FileWrappingIter):
self.array_buf.fromstring("r")
self.array_buf.fromstring(C.long2str(long(len(pickle))))
self.array_buf.fromstring(pickle)
def addrpath(self, rp):
if rp.file:
data = (rp.conn.conn_number, rp.base, rp.index, rp.data, 1)
self.next_in_line = rp.file
else:
data = (rp.conn.conn_number, rp.base, rp.index, rp.data, 0)
self.rorps_in_buffer += 1
pickle = cPickle.dumps(data, 1)
self.array_buf.fromstring("R")
self.array_buf.fromstring(C.long2str(long(len(pickle))))
self.array_buf.fromstring(pickle)
def addfinal(self):
"""Signal the end of the iterator to the other end"""
self.array_buf.fromstring("z")
......@@ -383,9 +397,19 @@ class FileToMiscIter(IterWrappingFile):
while not type: type, data = self._get()
if type == "z": raise StopIteration
elif type == "r": return self.get_rorp(data)
elif type == "R": return self.get_rp(data)
elif type == "o": return data
else: raise IterFileException("Bad file type %s" % (type,))
def get_rp(self, pickled_tuple):
conn_number, base, index, data_dict, num_files = pickled_tuple
rp = rpath.RPath(Globals.connection_dict[conn_number],
base, index, data_dict)
if num_files:
assert num_files == 1, "Only one file accepted right now"
rp.setfile(self.get_file())
return rp
def get_rorp(self, pickled_tuple):
"""Return rorp that data represents"""
index, data_dict, num_files = pickled_tuple
......@@ -419,7 +443,7 @@ class FileToMiscIter(IterWrappingFile):
type, length = self.buf[0], C.str2long(self.buf[1:8])
data = self.buf[8:8+length]
self.buf = self.buf[8+length:]
if type in "oerh": return type, cPickle.loads(data)
if type in "oerRh": return type, cPickle.loads(data)
else: return type, data
......
......@@ -135,6 +135,7 @@ class Logger:
if verbosity <= 2 or Globals.server: termfp = sys.stderr
else: termfp = sys.stdout
termfp.write(self.format(message, self.term_verbosity))
termfp.flush()
def conn(self, direction, result, req_num):
"""Log some data on the connection
......
# vim: set nolist noet ts=4:
# Copyright 2002, 2003, 2004, 2005 Ben Escoto
#
# This file is part of rdiff-backup.
#
# rdiff-backup is free software; you can redistribute it and/or modify
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# rdiff-backup is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with rdiff-backup; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
# USA
from collections import namedtuple
from multiprocessing import cpu_count
import Queue
import itertools
import threading
Job = namedtuple('Job', ['func', 'iterable', 'outqueue'])
Task = namedtuple('Task', ['func', 'args', 'index', 'outqueue'])
Result = namedtuple('Result', ['index', 'value'])
RUNNING = 0
STOPPED = 1
def worker(taskqueue):
while True:
task = taskqueue.get(True)
if task is None:
taskqueue.task_done()
break
if task.func is None:
# It means this task was the last of an iterable job
result = None
else:
value = task.func(*task.args)
result = Result(task.index, value)
task.outqueue.put(result, block=True)
taskqueue.task_done()
class Pool(object):
def __init__(self, processes=None,
max_taskqueue_size=0, max_jobqueue_size=0):
if processes is None:
processes = cpu_count()
self.processes = processes
self.state = STOPPED
# Init queues
self.taskqueue = Queue.Queue(maxsize=max_taskqueue_size)
self.jobqueue = Queue.Queue(maxsize=max_jobqueue_size)
# Init workers
self.workers = []
self.start_workers()
# Init task handler thread
self._job_handler_thread = self._start_handler_thread(self._job_handler,
self.jobqueue,
self.taskqueue)
def start_workers(self):
while len(self.workers) < self.processes:
w = self._start_handler_thread(worker, self.taskqueue)
self.workers.append(w)
for w in self.workers:
if not w.is_alive():
w.start()
def _start_handler_thread(self, func, *args):
thread = threading.Thread(target=func, args=args)
thread.daemon = True
thread.start()
return thread
def create_job(self, func, iterable, max_outqueue_size=0):
outqueue = Queue.Queue(maxsize=max_outqueue_size)
job = Job(func, iterable, outqueue)
self.jobqueue.put(job)
return job
def imap(self, func, iterable, max_outqueue_size=0):
iterable = itertools.imap(None, iterable)
job = self.create_job(func, iterable,
max_outqueue_size=max_outqueue_size)
return IMapIterator(job.outqueue)
def apply(self, func, *args):
job = self.create_job(func, [args])
return AsyncResult(job.outqueue)
def stop(self):
self.jobqueue.put(None, block=True)
def join(self, timeout=None):
self.stop()
self._job_handler_thread.join(timeout=timeout)
for w in self.workers:
w.join(timeout=timeout)
def _job_handler(self, jobqueue, taskqueue):
while True:
job = jobqueue.get(True)
if job is None:
for w in self.workers:
taskqueue.put(None)
break
for (index, args) in enumerate(job.iterable):
task = Task(job.func, args, index, job.outqueue)
taskqueue.put(task, block=True)
taskqueue.put(Task(None, None, None, job.outqueue), block=True)
jobqueue.task_done()
class IMapIterator(object):
def __init__(self, outqueue):
self.outqueue = outqueue
self.results = {}
self.index = 0
def __iter__(self):
return self
def next(self):
while True:
if self.index in self.results:
result = self.results.pop(self.index)
else:
result = self.outqueue.get(True)
if result is None:
raise StopIteration()
if result.index != self.index:
self.results[result.index] = result
continue
self.index += 1
return result.value
class AsyncResult(object):
def __init__(self, outqueue):
self.outqueue = outqueue
self.completed = False
self.value = None
def wait(self):
if self.completed:
return
self.value = self.outqueue.get(True)
self.completed = True
def get(self):
self.wait()
return self.value
......@@ -21,7 +21,8 @@
"""Read increment files and restore to original"""
from __future__ import generators
import tempfile, os, cStringIO
from pool import Pool as ThreadPool
import tempfile, os, cStringIO, itertools
import static, rorpiter, FilenameMapping, connection
class RestoreError(Exception): pass
......@@ -31,15 +32,43 @@ def Restore(mirror_rp, inc_rpath, target, restore_to_time):
MirrorS = mirror_rp.conn.restore.MirrorStruct
TargetS = target.conn.restore.TargetStruct
pool = ThreadPool(max_taskqueue_size=8)
MirrorS.set_mirror_and_rest_times(restore_to_time)
MirrorS.initialize_rf_cache(mirror_rp, inc_rpath)
# we run this locally to retrieve RPath instead of RORPath objects
# target_iter = TargetS.get_initial_iter(target)
target_iter = selection.Select(target).set_iter()
diff_iter = MirrorS.get_diffs(target_iter)
target_iter = TargetS.get_initial_iter(target)
# target_iter = selection.Select(target).set_iter()
mir_iter = MirrorS.subtract_indicies(MirrorS.mirror_base.index,
MirrorS.get_mirror_rorp_iter())
collated = rorpiter.Collate2Iters(mir_iter, target_iter)
diff_iter = pool.imap(get_diff, collated, max_outqueue_size=8)
diff_iter = itertools.ifilter(lambda diff: diff is not None, diff_iter)
TargetS.patch(target, diff_iter)
pool.stop()
pool.join()
MirrorS.close_rf_cache()
def get_diff(args):
mir_rorp, target_rorp = args
if Globals.preserve_hardlinks and mir_rorp:
Hardlink.add_rorp(mir_rorp, target_rorp)
diff = None
if not (target_rorp and mir_rorp and target_rorp == mir_rorp and
(not Globals.preserve_hardlinks or
Hardlink.rorp_eq(mir_rorp, target_rorp))):
diff = MirrorStruct.get_diff(mir_rorp, target_rorp)
if Globals.preserve_hardlinks and mir_rorp:
Hardlink.del_rorp(mir_rorp)
return diff
def get_inclist(inc_rpath):
"""Returns increments with given base"""
dirname, basename = inc_rpath.dirsplit()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment