#
# Copyright (C) 2011-2019  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# XXX: Consider using ClusterStates.STOPPING to stop clusters

import os, random, select, socket, sys, tempfile
import thread, threading, time, traceback, weakref
from collections import deque
from contextlib import contextmanager
from itertools import count
from functools import partial, wraps
from zlib import decompress
import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
from neo.admin.handler import MasterEventHandler
from neo.client import Storage
from neo.lib import logging
from neo.lib.connection import BaseConnection, \
    ClientConnection, Connection, ConnectionClosed, ListeningConnection
from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import uuid_str, \
    ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64
from neo.master.recovery import RecoveryManager
from .. import (getTempDirectory, setupMySQLdb,
    ImporterConfigParser, NeoTestBase, Patch,
    ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER)

BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
TIC_LOOP = xrange(1000)


class LockLock(object):
    """Double lock used as synchronisation point between 2 threads

    Used to wait until a slave thread has reached a specific location, and to
    keep it suspended there. It resumes on __exit__.
    """

    def __init__(self):
        self._l = threading.Lock(), threading.Lock()

    def __call__(self):
        """Define synchronisation point for both threads"""
        if self._owner == thread.get_ident():
            self._l[0].acquire()
        else:
            self._l[0].release()
            self._l[1].acquire()

    def __enter__(self):
        self._owner = thread.get_ident()
        for l in self._l:
            l.acquire(0)
        return self

    def __exit__(self, t, v, tb):
        try:
            self._l[1].release()
        except thread.error:
            pass
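
# Illustrative usage sketch (hypothetical slave_func, not from the original
# code): the test thread owns the LockLock and parks a slave thread on it.
#
#   ll = LockLock()
#   with ll:                     # the test thread becomes the owner
#       t = threading.Thread(target=slave_func)   # slave_func calls ll()
#       t.start()
#       ll()                     # blocks until the slave reached its ll()
#       ...                      # the slave is suspended at that point
#   # leaving the 'with' block resumes the slave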


class FairLock(deque):
    """Same as a threading.Lock except that waiting threads are queued, so that
    the first one waiting for the lock is the first to get it. This is useful
    when several concurrent threads fight for the same resource in a loop:
    the owner could leave too little time for the others to get a chance to
    acquire it, blocking them for a long time with bad luck.
    """

    def __enter__(self, _allocate_lock=threading.Lock):
        me = _allocate_lock()
        me.acquire()
        self.append(me)
        other = self[0]
        while me is not other:
            with other:
                other = self[0]

    def __exit__(self, t, v, tb):
        self.popleft().release()
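
# Illustrative sketch (hypothetical worker, not from the original code):
# contending threads acquire a FairLock in FIFO order, unlike threading.Lock
# which gives no fairness guarantee.
#
#   fair = FairLock()
#   def worker():
#       while True:          # several threads fighting for the same resource
#           with fair:
#               ...          # the longest-waiting thread is served first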


class Serialized(object):
    """
    "Threaded" tests run all nodes in the same process as the test itself,
    and threads are scheduled by this class, which mainly provides 2 features:
    - more determinism, by minimizing the number of active threads, and
      switching them in a round-robin;
    - a tic() method to wait only the necessary time for the cluster to be idle.

    The basic concept is that each thread has a lock that always gets acquired
    by itself. The following pattern is used to yield the processor to the next
    thread:
        release(); acquire()
    It should be noted that this is not atomic, i.e. all other threads
    sometimes complete before a thread tries to acquire its lock: in order that
    the previous thread does not fail by releasing an un-acquired lock,
    we actually use Semaphores instead of Locks.

    The epoll object of each node is hooked so that thread switching happens
    before polling for network activity. An extra epoll object is used to
    detect which node has a readable epoll object.
    """
    check_timeout = False
    _disabled = False

    @classmethod
    def init(cls):
        if cls._disabled:
            return
        cls._busy = set()
        cls._busy_cond = threading.Condition(threading.Lock())
        cls._epoll = select.epoll()
        cls._pdb = None
        cls._sched_lock = threading.Semaphore(0)
        cls._tic_lock = FairLock()
        cls._fd_dict = {}

    @classmethod
    def idle(cls, owner):
        with cls._busy_cond:
            cls._busy.discard(owner)
            cls._busy_cond.notify_all()

    @classmethod
    def stop(cls):
        if cls._disabled:
            return
        assert not cls._fd_dict, ("file descriptor leak (%r)\nThis may happen"
            " when a test fails, in which case you can see the real exception"
            " by disabling this one." % cls._fd_dict)
        del(cls._busy, cls._busy_cond, cls._epoll, cls._fd_dict,
            cls._pdb, cls._sched_lock, cls._tic_lock)

    @classmethod
    def _sort_key(cls, fd_event):
        return -cls._fd_dict[fd_event[0]]._last

    @classmethod
    @contextmanager
    def until(cls, patched=None, **patch):
        if cls._disabled:
            if patched is None:
                yield int
            else:
                l = threading.Lock()
                l.acquire()
                (name, patch), = patch.iteritems()
                def release():
                    p.revert()
                    l.release()
                with Patch(patched, **{name: lambda *args, **kw:
                        patch(release, *args, **kw)}) as p:
                    yield l.acquire
        else:
            yield cls.tic
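
    # Illustrative use (NEOCluster.start below has real examples): the yielded
    # callable waits until the patched method has called release(); when
    # scheduling is serialized, it is simply Serialized.tic.
    #
    #   with Serialized.until(RecoveryManager, dispatch=dispatch) as tic2:
    #       ...       # start some nodes
    #       tic2()    # wait until dispatch() called release()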

    @classmethod
    @contextmanager
    def pdb(cls):
        try:
            cls._pdb = sys._getframe(2).f_trace.im_self
            cls._pdb.set_continue()
        except AttributeError:
            pass
        yield
        p = cls._pdb
        if p is not None:
            cls._pdb = None
            t = threading.currentThread()
            p.stdout.write(getattr(t, 'node_name', t.name))
            p.set_trace(sys._getframe(3))

    @classmethod
    def tic(cls, step=-1, check_timeout=(), quiet=False,
            # BUG: We overuse epoll as a way to know if there are pending
            #      network messages. Sometimes, and this is more visible with
            #      a single-core CPU, other threads are still busy and haven't
            #      sent anything yet on the network. This causes tic() to
            #      return prematurely. Passing a non-zero value is a hack.
            #      We also increase SocketConnector.SOMAXCONN in tests so that
            #      a connection attempt is never delayed inside the kernel.
            timeout=0):
        if cls._disabled:
            if timeout:
                time.sleep(timeout)
            return
        # If you're in a pdb here, 'n' switches to another thread
        # (the following lines are not supposed to be debugged into)
        with cls._tic_lock, cls.pdb():
            if not quiet:
                f = sys._getframe(1)
                try:
                    logging.info('tic (%s:%u) ...',
                        f.f_code.co_filename, f.f_lineno)
                finally:
                    del f
            if cls._busy:
                with cls._busy_cond:
                    while cls._busy:
                        cls._busy_cond.wait()
            for app in check_timeout:
                app.em.epoll.check_timeout = True
                app.em.wakeup()
                del app
            while step:
                event_list = cls._epoll.poll(timeout)
                if not event_list:
                    break
                step -= 1
                event_list.sort(key=cls._sort_key)
                next_lock = cls._sched_lock
                for fd, event in event_list:
                    self = cls._fd_dict[fd]
                    self._release_next = next_lock.release
                    next_lock = self._lock
                del self
                next_lock.release()
                cls._sched_lock.acquire()

    def __init__(self, app, busy=True):
        if self._disabled:
            return
        self._epoll = app.em.epoll
        app.em.epoll = self
        # XXX: It may have been initialized before the SimpleQueue is patched.
        thread_container = getattr(app, '_thread_container', None)
        thread_container is None or thread_container.__init__()
        if busy:
            self._busy.add(self) # block tic until app waits for polling

    def __getattr__(self, attr):
        if attr in ('close', 'modify', 'register', 'unregister'):
            return getattr(self._epoll, attr)
        return self.__getattribute__(attr)

    def poll(self, timeout):
        if self.check_timeout:
            assert timeout >= 0, (self, timeout)
            del self.check_timeout
        elif timeout:
            with self.pdb(): # same as in tic()
                release = self._release_next
                self._release_next = None
                release()
                self._lock.acquire()
                self._last = time.time()
        return self._epoll.poll(timeout)

    def _release_next(self):
        self._last = time.time()
        self._lock = threading.Semaphore(0)
        fd = self._epoll.fileno()
        cls = self.__class__
        cls._fd_dict[fd] = self
        cls._epoll.register(fd)
        cls.idle(self)

    def exit(self):
        fd = self._epoll.fileno()
        cls = self.__class__
        if cls._fd_dict.pop(fd, None) is None:
            cls.idle(self)
        else:
            cls._epoll.unregister(fd)
            self._release_next()

class TestSerialized(Serialized):

    def __init__(*args):
        Serialized.__init__(busy=False, *args)

    def poll(self, timeout):
        if timeout:
            for x in TIC_LOOP:
                r = self._epoll.poll(0)
                if r:
                    return r
                Serialized.tic(step=1, timeout=.001)
            ConnectionFilter.log()
            raise Exception("tic is looping forever")
        return self._epoll.poll(timeout)


class Node(object):

    def getConnectionList(self, *peers):
        addr = lambda c: c and (c.addr if c.is_server else c.getAddress())
        addr_set = {addr(c.connector) for peer in peers
            for c in peer.em.connection_dict.itervalues()
            if isinstance(c, Connection)}
        addr_set.discard(None)
        return (c for c in self.em.connection_dict.itervalues()
            if isinstance(c, Connection) and addr(c.connector) in addr_set)

    def filterConnection(self, *peers):
        return ConnectionFilter(self.getConnectionList(*peers))

class ServerNode(Node):

    _server_class_dict = {}

    class __metaclass__(type):
        def __init__(cls, name, bases, d):
            if Node not in bases and threading.Thread not in cls.__mro__:
                cls.__bases__ = bases + (threading.Thread,)
                cls.node_type = getattr(NodeTypes, name[:-11].upper())
                cls._node_list = []
                cls._virtual_ip = socket.inet_ntop(ADDRESS_TYPE,
                    LOCAL_IP[:-1] + chr(2 + len(cls._server_class_dict)))
                cls._server_class_dict[cls._virtual_ip] = cls

    @staticmethod
    def resetPorts():
        for cls in ServerNode._server_class_dict.itervalues():
            del cls._node_list[:]

    @classmethod
    def newAddress(cls):
        address = cls._virtual_ip, len(cls._node_list)
        cls._node_list.append(None)
        return address

    @classmethod
    def resolv(cls, address):
        try:
            cls = cls._server_class_dict[address[0]]
        except KeyError:
            return address
        return cls._node_list[address[1]].getListeningAddress()

    def __init__(self, cluster=None, address=None, **kw):
        if not address:
            address = self.newAddress()
        if cluster is None:
            master_nodes = ()
            name = kw.get('name', 'test')
        else:
            master_nodes = cluster.master_nodes
            name = kw.get('name', cluster.name)
        port = address[1]
        if address is not BIND:
            self._node_list[port] = weakref.proxy(self)
        self._init_args = init_args = kw.copy()
        init_args['cluster'] = cluster
        init_args['address'] = address
        threading.Thread.__init__(self)
        self.daemon = True
        self.node_name = '%s_%u' % (self.node_type, port)
        kw.update(cluster=name, bind=address,
            masters=master_nodes and parseMasterList(master_nodes))
        super(ServerNode, self).__init__(kw)

    def getVirtualAddress(self):
        return self._init_args['address']

    def resetNode(self, **kw):
        assert not self.is_alive()
        init_args = self._init_args
        init_args['reset'] = False
        assert set(kw).issubset(init_args), (kw, init_args)
        init_args.update(kw)
        self.close()
        self.__init__(**init_args)

    def start(self):
        Serialized(self)
        threading.Thread.start(self)

    def run(self):
        try:
            super(ServerNode, self).run()
        finally:
            self._afterRun()
            logging.debug('stopping %r', self)
            if isinstance(self.em.epoll, Serialized):
                self.em.epoll.exit()

    def _afterRun(self):
        try:
            self.listening_conn.close()
            self.listening_conn = None
        except AttributeError:
            pass

    def getListeningAddress(self):
        try:
            return self.listening_conn.getAddress()
        except AttributeError:
            raise ConnectorException

    def stop(self):
        self.em.wakeup(thread.exit)

class AdminApplication(ServerNode, neo.admin.app.Application):
    pass

class MasterApplication(ServerNode, neo.master.app.Application):
    pass

class StorageApplication(ServerNode, neo.storage.app.Application):

    dm = type('', (), {'close': lambda self: None})()

    def _afterRun(self):
        super(StorageApplication, self)._afterRun()
        try:
            self.dm.close()
            del self.dm
        except StandardError: # AttributeError & ProgrammingError
            pass
        if self.master_conn:
            self.master_conn.close()

    def getAdapter(self):
        return self._init_args['adapter']

    def getDataLockInfo(self):
        dm = self.dm
        index = tuple(dm.query("SELECT id, hash, compression FROM data"))
        assert set(dm._uncommitted_data).issubset(x[0] for x in index)
        get = dm._uncommitted_data.get
        return {(str(h), c & 0x7f): get(i, 0) for i, h, c in index}

    def sqlCount(self, table):
        (r,), = self.dm.query("SELECT COUNT(*) FROM " + table)
        return r

class ClientApplication(Node, neo.client.app.Application):

    max_reconnection_to_master = 10

    def __init__(self, master_nodes, name, **kw):
        super(ClientApplication, self).__init__(master_nodes, name, **kw)
        self.poll_thread.node_name = name
        # Smaller cache to speed up tests that check behaviour when it's too
        # small. See also NEOCluster.cache_size
        self._cache.max_size //= 1024

    def _run(self):
        try:
            super(ClientApplication, self)._run()
        finally:
            if isinstance(self.em.epoll, Serialized):
                self.em.epoll.exit()

    def start(self):
        isinstance(self.em.epoll, Serialized) or Serialized(self)
        super(ClientApplication, self).start()

    def getConnectionList(self, *peers):
        for peer in peers:
            if isinstance(peer, MasterApplication):
                conn = self._getMasterConnection()
            else:
                assert isinstance(peer, StorageApplication)
                conn = self.getStorageConnection(self.nm.getByUUID(peer.uuid))
            yield conn

    def extraCellSortKey(self, key):
        return Patch(self, getCellSortKey=lambda orig, cell:
            (orig(cell, lambda: key(cell)), random.random()))

    def closeAllStorageConnections(self):
        for node in self.nm.getStorageList():
            conn = node._connection # XXX
            if conn is not None:
                conn.setReconnectionNoDelay()
                conn.close()

class NeoCTL(neo.neoctl.app.NeoCTL):

    def __init__(self, *args, **kw):
        super(NeoCTL, self).__init__(*args, **kw)
        TestSerialized(self)


class LoggerThreadName(str):

    def __new__(cls, default='TEST'):
        return str.__new__(cls, default)

    def __getattribute__(self, attr):
        return getattr(str(self), attr)

    def __hash__(self):
        return id(self)

    def __str__(self):
        try:
            return threading.currentThread().node_name
        except AttributeError:
            return str.__str__(self)


class ConnectionFilter(object):

    filtered_count = 0
    filter_list = []
    filter_queue = weakref.WeakKeyDictionary() # XXX: see the end of __new__
    lock = threading.RLock()
    _addPacket = Connection._addPacket

    @contextmanager
    def __new__(cls, conn_list=()):
        self = object.__new__(cls)
        self.filter_dict = {}
        self.conn_list = frozenset(conn_list)
        if not cls.filter_list:
            def _addPacket(conn, packet):
                with cls.lock:
                    try:
                        queue = cls.filter_queue[conn]
                    except KeyError:
                        for self in cls.filter_list:
                            if self._test(conn, packet):
                                self.filtered_count += 1
                                break
                        else:
                            return cls._addPacket(conn, packet)
                        cls.filter_queue[conn] = queue = deque()
                    p = packet.__class__
                    logging.debug("queued %s#0x%04x for %s",
                                  p.__name__, packet.getId(), conn)
                    p = packet.__new__(p)
                    p.__dict__.update(packet.__dict__)
                    queue.append(p)
            Connection._addPacket = _addPacket
        try:
            cls.filter_list.append(self)
            yield self
        finally:
            del cls.filter_list[-1:]
            if not cls.filter_list:
                Connection._addPacket = cls._addPacket.im_func
            # Retry even in case of exception, at least to avoid leaks in
            # filter_queue. Sometimes, WeakKeyDictionary only does the job
            # after an explicit call to gc.collect().
            with cls.lock:
                cls._retry()

    def _test(self, conn, packet):
        if not self.conn_list or conn in self.conn_list:
            for filter in self.filter_dict:
                if filter(conn, packet):
                    return True
        return False

    @classmethod
    def retry(cls):
        with cls.lock:
            cls._retry()

    @classmethod
    def _retry(cls):
        for conn, queue in cls.filter_queue.items():
            while queue:
                packet = queue.popleft()
                for self in cls.filter_list:
                    if self._test(conn, packet):
                        queue.appendleft(packet)
                        break
                else:
                    if conn.isClosed():
                        queue.clear()
                    else:
                        # Use the thread that created the packet to reinject it,
                        # to avoid a race condition on Connector.queued.
                        conn.em.wakeup(lambda conn=conn, packet=packet:
                            conn.isClosed() or cls._addPacket(conn, packet))
                    continue
                break
            else:
                del cls.filter_queue[conn]

    @classmethod
    def log(cls):
        try:
            if cls.filter_queue:
                logging.info('%s:', cls.__name__)
                for conn, queue in cls.filter_queue.iteritems():
                    app = NEOThreadedTest.getConnectionApp(conn)
                    logging.info('  %s %s:', uuid_str(app.uuid), conn)
                    for p in queue:
                        logging.info('    #0x%04x %s',
                                     p.getId(), p.__class__.__name__)
        except Exception:
            logging.exception('')

    def add(self, filter, *patches):
        with self.lock:
            self.filter_dict[filter] = patches
            for p in patches:
                p.apply()

    def remove(self, *filters):
        with self.lock:
            for filter in filters:
                for p in self.filter_dict.pop(filter):
                    p.revert()
            self._retry()

    def discard(self, *filters):
        try:
            self.remove(*filters)
        except KeyError:
            pass

    def __contains__(self, filter):
        return filter in self.filter_dict

    def byPacket(self, packet_type, *args):
        patches = []
        other = []
        for x in args:
            (patches if isinstance(x, Patch) else other).append(x)
        def delay(conn, packet):
            return isinstance(packet, packet_type) and False not in (
                callback(conn) for callback in other)
        self.add(delay, *patches)
        return delay

    def __getattr__(self, attr):
        if attr.startswith('delay'):
            return partial(self.byPacket, getattr(Packets, attr[5:]))
        return self.__getattribute__(attr)
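
    # Illustrative usage (packet type given as an example): attributes starting
    # with 'delay' resolve to byPacket() calls on the corresponding packet type.
    #
    #   with cluster.master.filterConnection(cluster.storage) as f:
    #       f.delayNotifyPartitionChanges()
    #       ...   # matching packets are queued instead of being sent
    #   # queued packets are reinjected when the filter goes away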

class NEOCluster(object):

    SSL = None

    def __init__(orig, self): # temporary definition for SimpleQueue patch
        orig(self)
        if Serialized._disabled:
            return
        lock = self._lock
        def _lock(blocking=True):
            if blocking:
                logging.info('<SimpleQueue>._lock.acquire()')
                for i in TIC_LOOP:
                    if lock(False):
                        return True
                    Serialized.tic(step=1, quiet=True, timeout=.001)
                ConnectionFilter.log()
                raise Exception("tic is looping forever")
            return lock(False)
        self._lock = _lock
    _patches = (
        Patch(BaseConnection, getTimeout=lambda orig, self: None),
        Patch(SimpleQueue, __init__=__init__),
        Patch(SocketConnector, CONNECT_LIMIT=0),
        Patch(SocketConnector, SOMAXCONN=128), # see Serialized.tic comment
        Patch(SocketConnector, _bind=lambda orig, self, addr: orig(self, BIND)),
        Patch(SocketConnector, _connect = lambda orig, self, addr:
            orig(self, ServerNode.resolv(addr))))
    _patch_count = 0
    _resource_dict = weakref.WeakValueDictionary()

    def _allocate(self, resource, new):
        result = resource, new()
        while result in self._resource_dict:
            result = resource, new()
        self._resource_dict[result] = self
        return result[1]

    @staticmethod
    def _patch():
        cls = NEOCluster
        cls._patch_count += 1
        if cls._patch_count > 1:
            return
        for patch in cls._patches:
            patch.apply()
        Serialized.init()

    @staticmethod
    def _unpatch():
        cls = NEOCluster
        assert cls._patch_count > 0
        cls._patch_count -= 1
        if cls._patch_count:
            return
        for patch in cls._patches:
            patch.revert()
        Serialized.stop()

    started = False

    def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
                       adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
                       storage_count=None, db_list=None, clear_databases=True,
                       db_user=DB_USER, db_password='', compress=True,
                       importer=None, autostart=None, dedup=False, name=None):
        self.name = name or 'neo_%s' % self._allocate('name',
            lambda: random.randint(0, 100))
        self.compress = compress
        self.num_partitions = partitions
        master_list = [MasterApplication.newAddress()
                       for _ in xrange(master_count)]
        self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
        kw = dict(replicas=replicas, adapter=adapter,
            partitions=partitions, reset=clear_databases, dedup=dedup)
        kw['cluster'] = weak_self = weakref.proxy(self)
        kw['ssl'] = self.SSL
        if upstream is not None:
            self.upstream = weakref.proxy(upstream)
            kw.update(upstream_cluster=upstream.name,
                upstream_masters=parseMasterList(upstream.master_nodes))
        self.master_list = [MasterApplication(autostart=autostart,
                                              address=x, **kw)
                            for x in master_list]
        if db_list is None:
            if storage_count is None:
                storage_count = replicas + 1
            index = count().next
            db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
                       for _ in xrange(storage_count)]
        if adapter == 'MySQL':
            setupMySQLdb(db_list, db_user, db_password, clear_databases)
            db = '%s:%s@%%s%s' % (db_user, db_password, DB_SOCKET)
        elif adapter == 'SQLite':
            db = os.path.join(getTempDirectory(), '%s.sqlite')
        else:
            assert False, adapter
        if importer:
            cfg = ImporterConfigParser(adapter, **importer)
            cfg.set("neo", "database", db % tuple(db_list))
            db = os.path.join(getTempDirectory(), '%s.conf')
            with open(db % tuple(db_list), "w") as f:
                cfg.write(f)
            kw["adapter"] = "Importer"
        kw['wait'] = 0
        self.storage_list = [StorageApplication(database=db % x, **kw)
                             for x in db_list]
        self.admin_list = [AdminApplication(**kw)]

    def __repr__(self):
        return "<%s(%s) at 0x%x>" % (self.__class__.__name__,
                                     self.name, id(self))

    # A few shortcuts that work when there's only 1 master/storage/admin
    @property
    def master(self):
        master, = self.master_list
        return master
    @property
    def storage(self):
        storage, = self.storage_list
        return storage
    @property
    def admin(self):
        admin, = self.admin_list
        return admin
    ###

    # More handy shortcuts for tests
    @property
    def backup_tid(self):
        return self.neoctl.getRecovery()[1]

    @property
    def last_tid(self):
        return self.primary_master.getLastTransaction()

    @property
    def primary_master(self):
        master, = [master for master in self.master_list if master.primary]
        return master

    @property
    def cache_size(self):
        return self.client._cache.max_size
    ###

    def __enter__(self):
        return self

    def __exit__(self, t, v, tb):
        self.stop(None)
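
    # Illustrative use (what the with_cluster decorator at the end of this
    # module automates):
    #
    #   with NEOCluster(partitions=2, storage_count=2) as cluster:
    #       cluster.start()
    #       t, c = cluster.getTransaction()
    #       ...   # the cluster is stopped automatically on exit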

    def start(self, storage_list=None, master_list=None, recovering=False):
        self.started = True
        self._patch()
        self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
        if master_list is None:
            master_list = self.master_list
        if storage_list is None:
            storage_list = self.storage_list
        def answerPartitionTable(release, orig, *args):
            orig(*args)
            release()
        def dispatch(release, orig, handler, *args):
            orig(handler, *args)
            node_list = handler.app.nm.getStorageList(only_identified=True)
            if len(node_list) == len(storage_list) and not any(
                    node.getConnection().isPending() for node in node_list):
                release()
        expected_state = (ClusterStates.RECOVERING,) if recovering else (
            ClusterStates.RUNNING, ClusterStates.BACKINGUP)
        def notifyClusterInformation(release, orig, handler, conn, state):
            orig(handler, conn, state)
            if state in expected_state:
                release()
        with Serialized.until(MasterEventHandler,
                answerPartitionTable=answerPartitionTable) as tic1, \
             Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \
             Serialized.until(MasterEventHandler,
                notifyClusterInformation=notifyClusterInformation) as tic3:
            for node in master_list:
                node.start()
            for node in self.admin_list:
                node.start()
            tic1()
            for node in storage_list:
                node.start()
            tic2()
            if not recovering:
                self.startCluster()
                tic3()
        self.checkStarted(expected_state, storage_list)

    def checkStarted(self, expected_state, storage_list=None):
        if isinstance(expected_state, Enum.Item):
            expected_state = expected_state,
        state = self.neoctl.getClusterState()
        assert state in expected_state, state
        expected_state = (NodeStates.PENDING
            if state == ClusterStates.RECOVERING
            else NodeStates.RUNNING)
        for node in self.storage_list if storage_list is None else storage_list:
            state = self.getNodeState(node)
            assert state == expected_state, (repr(node), state)

    def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw):
        if self.started:
            del self.started
            logging.debug("stopping %s", self)
            client = self.__dict__.get("client")
            client is None or self.__dict__.pop("db", client).close()
            node_list = self.admin_list + self.storage_list + self.master_list
            for node in node_list:
                node.stop()
            try:
                node_list.append(client.poll_thread)
            except AttributeError: # client is None or thread is already stopped
                pass
            self.join(node_list)
            self.neoctl.close()
            del self.neoctl
            logging.debug("stopped %s", self)
            self._unpatch()
        if clear_database is None:
            try:
                for node_type in 'admin', 'storage', 'master':
                    for node in getattr(self, node_type + '_list'):
                        node.close()
            except:
                __print_exc()
                raise
        else:
            for node_type in 'master', 'storage', 'admin':
                reset_kw = kw.copy()
                if node_type == 'storage':
                    reset_kw['reset'] = clear_database
                for node in getattr(self, node_type + '_list'):
                    node.resetNode(**reset_kw)

    def _newClient(self):
        return ClientApplication(name=self.name, master_nodes=self.master_nodes,
                                 compress=self.compress, ssl=self.SSL)

    @contextmanager
    def newClient(self, with_db=False):
        x = self._newClient()
        try:
            t = x.poll_thread
            closed = []
            if with_db:
                x = ZODB.DB(storage=self.getZODBStorage(client=x))
            else:
                # XXX: Do nothing in finally if the caller already closed it.
                x.close = lambda: closed.append(x.__class__.close(x))
            yield x
        finally:
            closed or x.close()
            self.join((t,))

    @cached_property
    def client(self):
        client = self._newClient()
        # Make sure client won't be reused after it was closed.
        def close():
            client = self.client
            del self.client, client.close
            client.close()
        client.close = close
        return client

    @cached_property
    def db(self):
        return ZODB.DB(storage=self.getZODBStorage())

    def startCluster(self):
        try:
            self.neoctl.startCluster()
        except RuntimeError:
            Serialized.tic()
            if self.neoctl.getClusterState() not in (
                      ClusterStates.BACKINGUP,
                      ClusterStates.RUNNING,
                      ClusterStates.VERIFYING,
                  ):
                raise

    def enableStorageList(self, storage_list):
        self.neoctl.enableStorageList([x.uuid for x in storage_list])
        Serialized.tic()
        for node in storage_list:
            state = self.getNodeState(node)
            assert state == NodeStates.RUNNING, state

    def join(self, thread_list, timeout=5):
        timeout += time.time()
        while thread_list:
            # Map with repr before the threads become unprintable.
            assert time.time() < timeout, map(repr, thread_list)
            Serialized.tic(timeout=.001)
            thread_list = [t for t in thread_list if t.is_alive()]

    def getNodeState(self, node):
        uuid = node.uuid
        for node in self.neoctl.getNodeList(node.node_type):
            if node[2] == uuid:
                return node[3]

    def getOutdatedCells(self):
        # Ask the admin instead of the primary master to check that it is
        # notified of every change.
        return [(i, cell.getUUID())
            for i, row in enumerate(self.admin.pt.partition_list)
            for cell in row
            if not cell.isReadable()]

    def getZODBStorage(self, **kw):
        kw['_app'] = kw.pop('client', self.client)
        return Storage.Storage(None, self.name, **kw)

    def importZODB(self, dummy_zodb=None, random=random):
        if dummy_zodb is None:
            from ..stat_zodb import PROD1
            dummy_zodb = PROD1(random)
        as_storage = dummy_zodb.as_storage
        return lambda count: self.getZODBStorage().copyTransactionsFrom(
            as_storage(count))

    def populate(self, transaction_list, tid=lambda i: p64(i+1),
                                         oid=lambda i: p64(i+1)):
        storage = self.getZODBStorage()
        tid_dict = {}
        for i, oid_list in enumerate(transaction_list):
            txn = transaction.Transaction()
            storage.tpc_begin(txn, tid(i))
            for o in oid_list:
                storage.store(oid(o), tid_dict.get(o), repr((i, o)), '', txn)
            storage.tpc_vote(txn)
            i = storage.tpc_finish(txn)
            for o in oid_list:
                tid_dict[o] = i

    def getTransaction(self, db=None):
        txn = transaction.TransactionManager()
        return txn, (self.db if db is None else db).open(txn)

    def moduloTID(self, partition):
        """Force generation of TIDs that will be stored in given partition"""
        partition = p64(partition)
        master = self.primary_master
        return Patch(master.tm, _nextTID=lambda orig, *args:
            orig(*args) if args else orig(partition, master.pt.getPartitions()))

    def sortStorageList(self):
        """Sort storages so that storage_list[i] has partition i for all i"""
        pt = [{x.getUUID() for x in x}
            for x in self.primary_master.pt.partition_list]
        r = []
        x = [iter(pt[0])]
        try:
            while 1:
                try:
                    r.append(next(x[-1]))
                except StopIteration:
                    del r[-1], x[-1]
                else:
                    x.append(iter(pt[len(r)].difference(r)))
        except IndexError:
            assert len(r) == len(self.storage_list)
        x = {x.uuid: x for x in self.storage_list}
        self.storage_list[:] = (x[r] for r in r)
        return self.storage_list

class NEOThreadedTest(NeoTestBase):

    __run_count = {}

    def setupLog(self):
        test_id = self.id()
        i = self.__run_count.get(test_id, 0)
        self.__run_count[test_id] = 1 + i
        if i:
            test_id += '-%s' % i
        logging._nid_dict.clear()
        logging.setup(os.path.join(getTempDirectory(), test_id + '.log'))
        return LoggerThreadName()

    def _tearDown(self, success):
        super(NEOThreadedTest, self)._tearDown(success)
        ServerNode.resetPorts()
        if success and logging._max_size is not None:
            with logging as db:
                db.execute("UPDATE packet SET body=NULL")
                db.execute("VACUUM")

    tic = Serialized.tic

    @contextmanager
    def getLoopbackConnection(self):
        app = MasterApplication(address=BIND,
            ssl=NEOCluster.SSL, replicas=0, partitions=1)
        try:
            handler = EventHandler(app)
            app.listening_conn = ListeningConnection(app, handler, app.server)
            yield ClientConnection(app, handler, app.nm.createMaster(
                address=app.listening_conn.getAddress(), uuid=app.uuid))
        finally:
            app.close()

    def getUnpickler(self, conn):
        reader = conn._reader
        def unpickler(data, compression=False):
            if compression:
                data = decompress(data)
            obj = reader.getGhost(data)
            reader.setGhostState(obj, data)
            return obj
        return unpickler

    class newPausedThread(threading.Thread):

        def __init__(self, func, *args, **kw):
            threading.Thread.__init__(self)
            self.__target = func, args, kw
            self.daemon = True

        def run(self):
            try:
                self.__result = apply(*self.__target)
            except:
                self.__exc_info = sys.exc_info()
                if self.__exc_info[0] is NEOThreadedTest.failureException:
                    traceback.print_exception(*self.__exc_info)

        def join(self, timeout=None):
            threading.Thread.join(self, timeout)
            if not self.is_alive():
                try:
                    return self.__result
                except AttributeError:
                    etype, value, tb = self.__exc_info
                    del self.__exc_info
                    raise etype, value, tb

    class newThread(newPausedThread):

        def __init__(self, *args, **kw):
            NEOThreadedTest.newPausedThread.__init__(self, *args, **kw)
            self.start()
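
    # Illustrative use of newThread (hypothetical txn): run a blocking call in
    # parallel with the test; join() returns its result or re-raises.
    #
    #   t = self.newThread(txn.commit)
    #   ...          # act on the cluster while the commit is in progress
    #   t.join()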

    def commitWithStorageFailure(self, client, txn):
        with Patch(client, _getFinalTID=lambda *_: None):
            self.assertRaises(ConnectionClosed, txn.commit)

    def assertPartitionTable(self, cluster, expected, pt_node=None,
                                   sort_by_nid=False):
        if sort_by_nid:
            index = lambda x: x
        else:
            index = [x.uuid for x in cluster.storage_list].index
        super(NEOThreadedTest, self).assertPartitionTable(
            (pt_node or cluster.admin).pt, expected,
            lambda x: index(x.getUUID()))

    @staticmethod
    def noConnection(jar, storage):
        return Patch(jar.db().storage.app,
            getStorageConnection=lambda orig, node:
                None if node.getUUID() == storage.uuid else orig(node))

    @staticmethod
    def getConnectionApp(conn):
        return getattr(conn.getHandler(), 'app', None)

    @staticmethod
    def readCurrent(ob):
        ob._p_activate()
        ob._p_jar.readCurrent(ob)


class ThreadId(list):

    def __call__(self):
        try:
            return self.index(thread.get_ident())
        except ValueError:
            i = len(self)
            self.append(thread.get_ident())
            return i


@apply
class RandomConflictDict(dict):
    # One must not depend on how Python iterates over dict keys, because this
    # is implementation-defined behaviour. This patch makes sure that conflict
    # resolution does not rely on any particular iteration order.

    def __new__(cls):
        from neo.client.transactions import Transaction
        def __init__(orig, self, *args):
            orig(self, *args)
            assert self.conflict_dict == {}
            self.conflict_dict = dict.__new__(cls)
        return Patch(Transaction, __init__=__init__)

    def popitem(self):
        try:
            k = random.choice(list(self))
        except IndexError:
            raise KeyError
        return k, self.pop(k)


def predictable_random(seed=None):
    # Because we have 2 running threads when client works, we can't
    # patch neo.client.pool (and cluster should have 1 storage).
    from neo.master import backup_app
    from neo.master.handlers import administration
    from neo.storage import replicator
    def decorator(wrapped):
        def wrapper(*args, **kw):
            s = repr(time.time()) if seed is None else seed
            logging.info("using seed %r", s)
            r = random.Random(s)
            try:
                administration.random = backup_app.random = replicator.random \
                    = r
                return wrapped(*args, **kw)
            finally:
                administration.random = backup_app.random = replicator.random \
                    = random
        return wraps(wrapped)(wrapper)
    return decorator
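
# Illustrative use (hypothetical test method): make the random generators used
# by replication, backup and administration reproducible for one test.
#
#   @predictable_random(seed='1337')
#   def testReplicationOrder(self, cluster):
#       ...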

def with_cluster(serialized=True, start_cluster=True, **cluster_kw):
    def decorator(wrapped):
        def wrapper(self, *args, **kw):
            try:
                Serialized._disabled = not serialized
                with NEOCluster(**cluster_kw) as cluster:
                    if start_cluster:
                        cluster.start()
                    return wrapped(self, cluster, *args, **kw)
            finally:
                Serialized._disabled = False
        return wraps(wrapped)(wrapper)
    return decorator
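
# Illustrative use (hypothetical test): with_cluster creates a NEOCluster with
# the given keyword arguments, optionally starts it, and passes it to the test
# method right after self.
#
#   class MyTest(NEOThreadedTest):
#
#       @with_cluster(replicas=1, partitions=2)
#       def testSomething(self, cluster):
#           t, c = cluster.getTransaction()
#           ...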