#!/usr/bin/python

import atexit, errno, imp, os, pprint, random, re, socket, shlex, shutil
import signal, string, subprocess, sys, threading, time, urlparse, xmlrpclib

# Last line printed by 'svn up': "At revision N." or "Updated to revision N.".
# The final dot is escaped so that a line lacking it can not match with the
# last digit of the revision consumed as "any character".
SVN_UP_REV = re.compile(r'^(?:At|Updated to) revision (\d+)\.$')
# "Last Changed Rev: N" lines of 'svn info' (one per path given to svn).
SVN_CHANGED_REV = re.compile(r'^Last Changed Rev.*:\s*(\d+)', re.MULTILINE)

def killallIfParentDies():
  """Tie the life of this process, and of its process group, to its parent

  The process is moved to a new session, an exit handler sends signal 9 to
  the whole process group (pid 0), and prctl(2) is asked to deliver SIGINT
  to this process when its parent dies.

  Raises OSError if prctl reports a failure.
  """
  os.setsid()
  atexit.register(lambda: os.kill(0, 9))
  from ctypes import cdll
  prctl = cdll.LoadLibrary('libc.so.6').prctl
  # Option number 1 of prctl is PR_SET_PDEATHSIG.
  PR_SET_PDEATHSIG = 1
  if prctl(PR_SET_PDEATHSIG, signal.SIGINT):
    raise OSError

# Finds a character that requires the word to be quoted for a POSIX shell:
# whitespace, globbing/expansion/control characters, both kinds of quotes
# and the backslash.
_format_command_search = re.compile(r"""[\[\s$({?*\\`"#~';<>&|]""").search
# Wraps a word in single quotes, any embedded single quote becoming '\''
_format_command_escape = lambda s: "'%s'" % r"'\''".join(s.split("'"))

def _format_command_quote(word):
  """Return `word` as a single shell word, quoted only if required"""
  # An empty word must be quoted too, otherwise it would disappear from
  # the command line.
  if not word or _format_command_search(word):
    return _format_command_escape(word)
  return word

def format_command(*args, **kw):
  """Render a command as a string that can be pasted in a shell

  args -- words of the command (strings)
  kw -- environment variables, rendered as sorted NAME=value assignments
        in front of the command

  Returns the words joined by single spaces.
  """
  cmdline = ['%s=%s' % (k, _format_command_quote(v))
             for k, v in sorted(kw.items())]
  cmdline.extend([_format_command_quote(v) for v in args])
  return ' '.join(cmdline)

def subprocess_capture(p, quiet=False):
  """Echo the output of a process while it runs, and return it once it ended

  p -- process object with 'stdout' and 'stderr' attributes (each one either
       false or an object with a readline method) and a wait method
  quiet -- if true, the standard output is captured without being echoed

  Returns a (stdout, stderr) pair: the concatenation of all lines read from
  each stream, or the false value of the attribute if there was no stream.
  """
  def pump(stream, echo, chunks):
    # Forward and record lines until the end of the stream.
    line = stream.readline()
    while line:
      echo(line)
      chunks.append(line)
      line = stream.readline()

  def start_reader(stream, echo):
    chunks = []
    reader = threading.Thread(target=pump, args=(stream, echo, chunks))
    reader.setDaemon(True)
    reader.start()
    return reader, chunks

  # All readers are started before any of them is waited for, so that both
  # streams are consumed concurrently.
  reader_list = []
  out_chunks = err_chunks = None
  if p.stdout:
    if quiet:
      echo = lambda data: None
    else:
      echo = sys.stdout.write
    reader, out_chunks = start_reader(p.stdout, echo)
    reader_list.append(reader)
  if p.stderr:
    reader, err_chunks = start_reader(p.stderr, sys.stderr.write)
    reader_list.append(reader)
  for reader in reader_list:
    reader.join()
  p.wait()
  return (p.stdout and ''.join(out_chunks),
          p.stderr and ''.join(err_chunks))


class Persistent(object):
  """Very simple persistent data storage for optimization purposes

  This tool should become a standalone daemon communicating only with an ERP5
  instance. But for the moment, it only executes 1 test suite and exits,
  and test suite classes may want some information from previous runs.

  Attributes whose name does not start with an underscore are written by
  sync() to the given file, as a pretty-printed dict. The file is opened,
  and its content loaded, the first time a missing attribute is accessed.
  """

  def __init__(self, filename):
    # Nothing is opened here: see __getattr__.
    self._filename = filename

  def __getattr__(self, attr):
    # Only called when the normal attribute lookup failed.
    if attr == '_db':
      # First access to the storage: open the file for reading and writing,
      # creating it if it does not exist.
      try:
        db = file(self._filename, 'r+')
      except IOError, e:
        if e.errno != errno.ENOENT:
          raise
        db = file(self._filename, 'w+')
      else:
        # NOTE(review): eval() executes whatever the file contains. The file
        # is expected to have been written by sync(); it must not come from
        # an untrusted source.
        try:
          self.__dict__.update(eval(db.read()))
        except StandardError:
          # Content that can not be evaluated or loaded is ignored.
          pass
      # Once stored in the instance dict, '_db' is found by the normal
      # lookup and this branch is not entered again.
      self._db = db
      return db
    # Any other missing attribute: make sure the file has been loaded
    # (accessing self._db triggers the branch above if needed) ...
    self._db
    # ... then retry the normal lookup, which raises AttributeError if the
    # attribute was not in the file either.
    return super(Persistent, self).__getattribute__(attr)

  def sync(self):
    """Write all public attributes to the file, replacing its content"""
    self._db.seek(0)
    # Keep only attributes whose name does not start with an underscore.
    db = dict(x for x in self.__dict__.iteritems() if x[0][:1] != '_')
    pprint.pprint(db, self._db)
    # Drop what remains of a previous, longer content.
    self._db.truncate()


class SubprocessError(EnvironmentError):
  """Error raised when a spawned command exits with a non-zero status

  status_dict -- dict describing the run; its items can also be read as
                 attributes of the exception (e.g. e.stderr, e.status_code)
  """

  def __init__(self, status_dict):
    self.status_dict = status_dict

  def __getattr__(self, name):
    # Only called for attributes that are not found normally. A missing
    # 'status_dict' (instance created without __init__) must not be looked
    # up in itself: that would recurse without end.
    if name == 'status_dict':
      raise AttributeError(name)
    try:
      return self.status_dict[name]
    except KeyError:
      # Follow the attribute protocol, so that hasattr() and
      # getattr(..., default) work for missing items.
      raise AttributeError(name)

  def __str__(self):
    return 'Error %i' % self.status_code


class Updater(object):

  _git_cache = {}
  realtime_output = True
  stdin = file(os.devnull)

  def __init__(self, revision=None):
    self.revision = revision
    self._path_list = []

  def deletePycFiles(self, path):
    """Delete *.pyc files so that deleted/moved files can not be imported"""
    for path, dir_list, file_list in os.walk(path):
      for file in file_list:
        if file[-4:] in ('.pyc', '.pyo'):
          # allow several processes clean the same folder at the same time
          try:
            os.remove(os.path.join(path, file))
          except OSError, e:
            if e.errno != errno.ENOENT:
              raise

  def spawn(self, *args, **kw):
    quiet = kw.pop('quiet', False)
    env = kw and dict(os.environ, **kw) or None
    command = format_command(*args, **kw)
    print '\n$ ' + command
    sys.stdout.flush()
    p = subprocess.Popen(args, stdin=self.stdin, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE, env=env)
    if self.realtime_output:
      stdout, stderr = subprocess_capture(p, quiet)
    else:
      stdout, stderr = p.communicate()
      if not quiet:
        sys.stdout.write(stdout)
      sys.stderr.write(stderr)
    result = dict(status_code=p.returncode, command=command,
                  stdout=stdout, stderr=stderr)
    if p.returncode:
      raise SubprocessError(result)
    return result

  def _git(self, *args, **kw):
    return self.spawn('git', *args, **kw)['stdout'].strip()

  def _git_find_rev(self, ref):
    try:
      return self._git_cache[ref]
    except KeyError:
      if os.path.exists('.git/svn'):
        r = self._git('svn', 'find-rev', ref)
        assert r
        self._git_cache[ref[0] != 'r' and 'r%u' % int(r) or r] = ref
      else:
        r = self._git('rev-list', '--topo-order', '--count', ref), ref
      self._git_cache[ref] = r
      return r

  def getRevision(self, *path_list):
    if not path_list:
      path_list = self._path_list
    if os.path.isdir('.git'):
      h = self._git('log', '-1', '--format=%H', '--', *path_list)
      return self._git_find_rev(h)
    if os.path.isdir('.svn'):
      stdout = self.spawn('svn', 'info', *path_list)['stdout']
      return str(max(map(int, SVN_CHANGED_REV.findall(stdout))))
    raise NotImplementedError

  def checkout(self, *path_list):
    if not path_list:
      path_list = '.',
    revision = self.revision
    if os.path.isdir('.git'):
      # edit .git/info/sparse-checkout if you want sparse checkout
      if revision:
        if type(revision) is str:
          h = self._git_find_rev('r' + revision)
        else:
          h = revision[1]
        if h != self._git('rev-parse', 'HEAD'):
          self.deletePycFiles('.')
          self._git('reset', '--merge', h)
      else:
        self.deletePycFiles('.')
        if os.path.exists('.git/svn'):
          self._git('svn', 'rebase')
        else:
          self._git('pull', '--ff-only')
        self.revision = self._git_find_rev(self._git('rev-parse', 'HEAD'))
    elif os.path.isdir('.svn'):
      # following code allows sparse checkout
      def svn_mkdirs(path):
        path = os.path.dirname(path)
        if path and not os.path.isdir(path):
          svn_mkdirs(path)
          self.spawn(*(args + ['--depth=empty', path]))
      for path in path_list:
        args = ['svn', 'up', '--force', '--non-interactive']
        if revision:
          args.append('-r%s' % revision)
        svn_mkdirs(path)
        args += '--set-depth=infinity', path
        self.deletePycFiles(path)
        try:
          status_dict = self.spawn(*args)
        except SubprocessError, e:
          if 'cleanup' not in e.stderr:
            raise
          self.spawn('svn', 'cleanup', path)
          status_dict = self.spawn(*args)
        if not revision:
          self.revision = revision = SVN_UP_REV.findall(
            status_dict['stdout'].splitlines()[-1])[0]
    else:
      raise NotImplementedError
    self._path_list += path_list


class TestSuite(Updater):
  """Base class of test suites, running tests in background threads

  Subclasses must implement run() and getTestList(), and may override
  setup() and update().
  """

  mysql_db_count = 1
  allow_restart = False

  def __init__(self, max_instance_count, **kw):
    """Prepare a pool of `max_instance_count` instances

    Extra keyword arguments are set as attributes of the suite.
    """
    self.__dict__.update(kw)
    self._path_list = ['tests']
    # acquire/release count the instances that may still be used.
    semaphore = threading.Semaphore(max_instance_count)
    self.acquire = semaphore.acquire
    self.release = semaphore.release
    # Per-thread storage for the id of the instance used by the thread.
    self._instance = threading.local()
    single = max_instance_count == 1
    if single:
      # A single instance has no number.
      self._pool = [None]
    else:
      self._pool = range(1, max_instance_count + 1)
    self.running = {}
    self._ready = set()
    if not single:
      self.realtime_output = False
    elif os.isatty(1):
      self.realtime_output = True
    self.persistent = Persistent(
      'run_test_suite-%s.tmp' % self.__class__.__name__)

  @property
  def instance(self):
    return self._instance.id

  def start(self, test, on_stop=None):
    """Run `test` in a new daemon thread, on the first available instance

    on_stop -- if given, called with the result of run() when the test ends

    Once the test is finished, the instance is put back in the pool and
    release() is called.
    """
    assert test not in self.running
    instance = self._pool.pop(0)
    self.running[test] = instance
    def worker():
      self._instance.id = instance
      # setup() is called once per instance, before its first test.
      if instance not in self._ready:
        self._ready.add(instance)
        self.setup()
      status_dict = self.run(test)
      if on_stop is not None:
        on_stop(status_dict)
      self._pool.append(self.running.pop(test))
      self.release()
    thread = threading.Thread(target=worker)
    thread.setDaemon(True)
    thread.start()

  def update(self):
    self.checkout() # by default, update everything

  def setup(self):
    pass

  def run(self, test):
    raise NotImplementedError

  def getTestList(self):
    raise NotImplementedError


class ERP5TypeTestSuite(TestSuite):

  RUN_RE = re.compile(
    r'Ran (?P<all_tests>\d+) tests? in (?P<seconds>\d+\.\d+)s',
    re.DOTALL)

  STATUS_RE = re.compile(r"""
    (OK|FAILED)\s+\(
      (failures=(?P<failures>\d+),?\s*)?
      (errors=(?P<errors>\d+),?\s*)?
      (skipped=(?P<skips>\d+),?\s*)?
      (expected\s+failures=(?P<expected_failures>\d+),?\s*)?
      (unexpected\s+successes=(?P<unexpected_successes>\d+),?\s*)?
    \)
    """, re.DOTALL | re.VERBOSE)

  def setup(self):
    instance_home = self.instance and 'unit_test.%u' % self.instance \
                                   or 'unit_test'
    tests = os.path.join(instance_home, 'tests')
    if os.path.exists(tests):
      shutil.rmtree(instance_home + '.previous', True)
      shutil.move(tests, instance_home + '.previous')

  def run(self, test):
    return self.runUnitTest(test)

  def runUnitTest(self, *args, **kw):
    if self.instance:
      args = ('--instance_home=unit_test.%u' % self.instance,) + args
    mysql_db_list = [string.Template(x).substitute(I=self.instance or '1')
                      for x in self.mysql_db_list]
    if len(mysql_db_list) > 1:
      kw['extra_sql_connection_string_list'] = ','.join(mysql_db_list[1:])
    try:
      runUnitTest = os.environ.get('RUN_UNIT_TEST',
                                   'Products/ERP5Type/tests/runUnitTest.py')
      args = tuple(shlex.split(runUnitTest)) \
           + ('--verbose', '--erp5_sql_connection_string=' + mysql_db_list[0]) \
           + args
      status_dict = self.spawn(*args, **kw)
    except SubprocessError, e:
      status_dict = e.status_dict
    test_log = status_dict['stderr']
    search = self.RUN_RE.search(test_log)
    if search:
      groupdict = search.groupdict()
      status_dict.update(duration=float(groupdict['seconds']),
                         test_count=int(groupdict['all_tests']))
    search = self.STATUS_RE.search(test_log)
    if search:
      groupdict = search.groupdict()
      status_dict.update(error_count=int(groupdict['errors'] or 0),
                         failure_count=int(groupdict['failures'] or 0),
                         skip_count=int(groupdict['skips'] or 0)
                                   +int(groupdict['expected_failures'] or 0)
                                   +int(groupdict['unexpected_successes'] or 0))
    return status_dict


# Example of a suite calling runUnitTest with --save in setup() and with
# --load for each test (it requires 'import glob'):
#
#class LoadSaveExample(ERP5TypeTestSuite):
#  def getTestList(self):
#    return [test_path.split(os.sep)[-1][:-3]
#            for test_path in glob.glob('tests/test*.py')]
#
#  def setup(self):
#    TestSuite.setup(self)
#    return self.runUnitTest('--save', 'testFoo')
#
#  def run(self, test):
#    return self.runUnitTest('--load', test)


# Register an in-memory module named 'test_suite' exposing the classes below,
# so that 'from test_suite import ERP5TypeTestSuite' works without any file
# of that name (presumably for the modules of the 'tests' package that define
# suites - TODO confirm).
sys.modules['test_suite'] = module = imp.new_module('test_suite')
for var in SubprocessError, TestSuite, ERP5TypeTestSuite:
  setattr(module, var.__name__, var)


class DummyTaskDistributionTool(object):
  """Local replacement for the task distribution tool of a master node

  It hands out the tests of a single list, one at a time, and discards
  results.
  """

  def __init__(self):
    # Protects the list of remaining tests.
    self.lock = threading.Lock()

  def createTestResult(self, name, revision, test_name_list, allow_restart):
    """Remember the tests to run and return a (None, revision) pair"""
    self.test_name_list = list(test_name_list)
    return None, revision

  def startUnitTest(self, test_result_path, exclude_list=()):
    """Remove and return, as (None, test), the first test not excluded

    Returns None if there is no such test.
    """
    self.lock.acquire()
    try:
      for test in self.test_name_list:
        if test in exclude_list:
          continue
        # The loop is left right after the list is modified.
        self.test_name_list.remove(test)
        return None, test
      return None
    finally:
      self.lock.release()

  def stopUnitTest(self, test_path, status_dict):
    pass


def safeRpcCall(function, *args):
  """Call an XML-RPC method, retrying for ever on network or protocol errors

  function -- method of an xmlrpclib.ServerProxy
  args -- arguments of the call

  On socket.error or xmlrpclib.ProtocolError, the error is printed on the
  standard error, the arguments are dumped to a file named after the remote
  method (in the current directory) and the call is retried after a delay
  that starts at 64 seconds and grows by half each time.

  Returns the result of the first successful call. Other exceptions are
  not caught.
  """
  retry = 64
  while True:
    try:
      return function(*args)
    except (socket.error, xmlrpclib.ProtocolError), e:
      print >>sys.stderr, e
      # '_Method__name' is the name-mangled private attribute holding the
      # name of the remote method.
      dump = file(function._Method__name, 'w')
      try:
        pprint.pprint(args, dump)
      finally:
        # Do not rely on garbage collection to flush and close the dump.
        dump.close()
      time.sleep(retry)
      retry += retry >> 1

def getOptionParser():
  """Return the parser of command-line options (--master, --mysql_db_list)"""
  from optparse import OptionParser
  parser = OptionParser(
    usage="%prog [options] [SUITE_CLASS@]<SUITE_NAME>[=<MAX_INSTANCES>]")
  option_list = (
    ("--master", "URL of ERP5 instance, used as master node"),
    ("--mysql_db_list", "comma-separated list of connection strings"),
  )
  for flag, description in option_list:
    parser.add_option(flag, help=description)
  return parser

def main():
  """Run a test suite, reporting to a master node if --master is given

  The single positional argument is [SUITE_CLASS@]<SUITE_NAME>[=<MAX_INSTANCES>].
  Tests to run are requested one by one, from the master or, without master,
  from a DummyTaskDistributionTool, and run concurrently on up to
  MAX_INSTANCES instances (1 by default).
  """
  # Spawned commands inherit this environment (presumably so that the
  # messages of svn match SVN_UP_REV and SVN_CHANGED_REV - TODO confirm).
  os.environ['LC_ALL'] = 'C'

  parser = getOptionParser()
  options, args = parser.parse_args()
  try:
    # ValueError is raised if there is not exactly 1 argument,
    # or if MAX_INSTANCES is not an integer.
    name, = args
    if '=' in name:
      name, max_instance_count = name.rsplit('=', 1)
      max_instance_count = int(max_instance_count)
    else:
      max_instance_count = 1
    # Without SUITE_CLASS, the name of the suite is also the name of its class.
    if '@' in name:
      suite_class_name, name = name.split('@', 1)
    else:
      suite_class_name = name
  except ValueError:
    parser.error("invalid arguments")
  db_list = options.mysql_db_list
  if db_list:
    db_list = db_list.split(',')
    multi = max_instance_count != 1
    try:
      # With several instances, each connection string must change when $I
      # is substituted. substitute() itself raises KeyError for any other
      # placeholder.
      for db in db_list:
        if db == string.Template(db).substitute(I=1) and multi:
          raise KeyError
    except KeyError:
      parser.error("invalid value for --mysql_db_list")
  else:
    # Default: a single connection string (note the trailing comma).
    db_list = (max_instance_count == 1 and 'test test' or 'test$I test'),

  def makeSuite(revision=None):
    # 'portal_url' is assigned below, before the first call of this function.
    updater = Updater(revision)
    if portal_url:
      # Update the code of test suites first.
      updater.checkout('tests')
    # Forget any module of the 'tests' package that is already imported,
    # so that the import below loads the files that are now on disk.
    for k in sys.modules.keys():
      if k == 'tests' or k.startswith('tests.'):
        del sys.modules[k]
    # The class is looked up in the 'tests' package: suite_class_name may
    # be prefixed by the dotted path of a submodule.
    module_name, class_name = ('tests.' + suite_class_name).rsplit('.', 1)
    try:
      suite_class = getattr(__import__(module_name, None, None, [class_name]),
                            class_name)
    except (AttributeError, ImportError):
      parser.error("unknown test suite")
    if len(db_list) < suite_class.mysql_db_count:
      parser.error("%r suite needs %u DB (only %u given)" %
        (name, suite_class.mysql_db_count, len(db_list)))
    suite = suite_class(revision=updater.revision,
                        max_instance_count=max_instance_count,
                        mysql_db_list=db_list[:suite_class.mysql_db_count])
    if portal_url:
      suite.update()
    return suite

  portal_url = options.master
  if portal_url:
    if portal_url[-1] != '/':
      portal_url += '/'
    portal = xmlrpclib.ServerProxy(portal_url, allow_none=1)
    master = portal.portal_task_distribution
    assert master.getProtocolRevision() == 1
  else:
    # No master: nothing is checked out and tests are distributed locally.
    master = DummyTaskDistributionTool()

  suite = makeSuite()
  revision = suite.getRevision()
  test_result = safeRpcCall(master.createTestResult,
    name, revision, suite.getTestList(), suite.allow_restart)
  if test_result:
    # The revision to test may differ from the one that is checked out.
    test_result_path, test_revision = test_result
    if portal_url: # for buildbot
      # Print the URL of the test result, without the 'user:password@' part
      # of the network location, if any.
      url_parts = list(urlparse.urlparse(portal_url + test_result_path))
      url_parts[1] = url_parts[1].split('@')[-1]
      print 'ERP5_TEST_URL %s OK' % urlparse.urlunparse(url_parts)
    # acquire() blocks until an instance is available.
    while suite.acquire():
      # Ask for a test that is not already running here.
      test = safeRpcCall(master.startUnitTest, test_result_path,
                         suite.running.keys())
      if test:
        if revision != test_revision:
          # Switch to the requested revision. The new suite has its own
          # semaphore, hence the new acquire().
          # NOTE(review): this branch is reached for the first test, if
          # ever - confirm that no test can still be running in the
          # suite that is replaced.
          suite = makeSuite(test_revision)
          revision = test_revision
          suite.acquire()
        # test is a (test path, test name) pair. The thread started by
        # start() reports the result and releases what was acquired above.
        suite.start(test[1], lambda status_dict, __test_path=test[0]:
          safeRpcCall(master.stopUnitTest, __test_path, status_dict))
      elif not suite.running:
        # No test left to start and none running: the suite is finished.
        break
      # We are finishing the suite. Let's disable idle nodes.
      # (What was acquired for this iteration is not released when no test
      # was started, so the next acquire() waits for a running test to end.)


if __name__ == '__main__':
  # Replace the first entry of the import path by '', so that modules (like
  # the 'tests' package imported by main) are searched in the current
  # working directory.
  sys.path[0] = ''
  # When the standard input is not a terminal, make sure that this process
  # and its children do not outlive the parent process.
  if not os.isatty(0):
    killallIfParentDies()
  sys.exit(main())