import logging

from neo.protocol import UP_TO_DATE_STATE, OUT_OF_DATE_STATE, FEEDING_STATE, \
        DISCARDED_STATE, RUNNING_STATE, DOWN_STATE, BROKEN_STATE

class Cell(object):
    """One entry of a partition table row: a node paired with a cell state."""

    def __init__(self, node, state=UP_TO_DATE_STATE):
        self.node, self.state = node, state

    def getState(self):
        """Return the state of this cell."""
        return self.state

    def setState(self, state):
        """Replace the state of this cell with ``state``."""
        self.state = state

    def getNode(self):
        """Return the node this cell belongs to."""
        return self.node

    def getNodeState(self):
        """Short hand for the state of the underlying node."""
        return self.getNode().getState()

    def getUUID(self):
        """Short hand for the UUID of the underlying node."""
        return self.getNode().getUUID()

class PartitionTable(object):
    """This class manages a partition table.

    The table has ``np`` rows, one per partition offset.  A row is ``None``
    until it is known, and afterwards a list of ``Cell`` objects, at most one
    per node.  ``count_dict`` maps each node to the number of cells it holds
    that are not in FEEDING_STATE; it is used to spread cells across nodes.
    """

    def __init__(self, num_partitions, num_replicas):
        # Number of partitions, i.e. rows of the table.
        self.np = num_partitions
        # Number of replicas; a complete row holds nr + 1 non-feeding cells.
        self.nr = num_replicas
        self.num_filled_rows = 0
        self.partition_list = [None] * num_partitions
        self.count_dict = {}

    def clear(self):
        """Forget an existing partition table."""
        self.num_filled_rows = 0
        self.partition_list = [None] * self.np
        # The per-node counters describe the forgotten rows, so drop them too.
        self.count_dict.clear()

    def make(self, node_list):
        """Make a new partition table from scratch.

        Usable nodes (RUNNING_STATE with a known UUID) are assigned to rows in
        round-robin order, ``min(nr + 1, number of usable nodes)`` distinct
        nodes per row, each cell in the default UP_TO_DATE_STATE.

        :raises RuntimeError: if ``node_list`` contains no usable node.
        """
        # First, filter the list of nodes.
        node_list = [n for n in node_list
                     if n.getState() == RUNNING_STATE and n.getUUID() is not None]
        if not node_list:
            # Impossible.
            raise RuntimeError(
                'cannot make a partition table with an empty storage node list')

        # Every row is rebuilt below, so counters of a previous table must not
        # be carried over.
        self.count_dict.clear()

        # Take it into account that the number of storage nodes may be less
        # than the number of replicas.
        repeats = min(self.nr + 1, len(node_list))
        index = 0
        for offset in range(self.np):
            row = []
            for _ in range(repeats):
                node = node_list[index]
                row.append(Cell(node))
                self.count_dict[node] = self.count_dict.get(node, 0) + 1
                # Consecutive indices modulo len(node_list) are distinct
                # because repeats <= len(node_list).
                index = (index + 1) % len(node_list)
            self.partition_list[offset] = row

        self.num_filled_rows = self.np

    def _removeCell(self, offset, node):
        """Remove the cell of ``node`` from row ``offset``.

        Keeps ``count_dict`` in sync with the removed cell's own state.
        Returns the removed Cell, or None if the row holds no such cell.
        """
        row = self.partition_list[offset]
        if row is None:
            return None
        for index, cell in enumerate(row):
            if cell.getNode() == node:
                del row[index]
                # Feeding cells are never counted, so only undo a real count.
                if cell.getState() != FEEDING_STATE:
                    self.count_dict[node] = self.count_dict.get(node, 0) - 1
                return cell
        return None

    def setCell(self, offset, node, state):
        """Record that ``node`` holds a cell in ``state`` for row ``offset``.

        DISCARDED_STATE removes the node's cell from the row instead of
        storing it.  Otherwise nothing is recorded for a node in BROKEN_STATE
        or DOWN_STATE; any previous cell of the node in the row is replaced.
        """
        if state == DISCARDED_STATE:
            # Discarded cells must not stay in a row: tweak() would otherwise
            # classify them as up-to-date.
            self._removeCell(offset, node)
            return
        if node.getState() in (BROKEN_STATE, DOWN_STATE):
            return

        row = self.partition_list[offset]
        if row is None:
            # Create a new row.
            row = []
            self.partition_list[offset] = row
            self.num_filled_rows += 1
        else:
            # XXX this can be slow, but it is necessary to remove a duplicate,
            # if any.
            self._removeCell(offset, node)
        row.append(Cell(node, state))
        if state != FEEDING_STATE:
            self.count_dict[node] = self.count_dict.get(node, 0) + 1

    def filled(self):
        """Tell whether every row of the table is known."""
        return self.num_filled_rows == self.np

    def hasOffset(self, offset):
        """Tell whether row ``offset`` is known."""
        return self.partition_list[offset] is not None

    def operational(self):
        """Tell whether every row has a usable cell.

        True only if the table is filled and each row has at least one cell
        in UP_TO_DATE_STATE or FEEDING_STATE on a node in RUNNING_STATE.
        """
        if not self.filled():
            return False

        # FIXME it is better to optimize this code, as this could be extremely
        # slow. The possible fix is to have a handler to notify a change on
        # a node state, and record which rows are ready.
        for row in self.partition_list:
            if row is None or not any(
                    cell.getState() in (UP_TO_DATE_STATE, FEEDING_STATE)
                    and cell.getNodeState() == RUNNING_STATE
                    for cell in row):
                return False
        return True

    def findLeastUsedNode(self, excluded_node_list=()):
        """Return the running node of ``count_dict`` with the lowest count.

        Nodes in ``excluded_node_list`` are skipped.  Returns None when no
        node qualifies.
        """
        # A node holds at most one cell per row, so np + 1 exceeds any count.
        min_count = self.np + 1
        min_node = None
        for node, count in self.count_dict.items():
            if min_count > count \
                    and node not in excluded_node_list \
                    and node.getState() == RUNNING_STATE:
                min_node = node
                min_count = count
        return min_node

    def dropNode(self, node):
        """Remove every cell of ``node``, replacing each one when possible.

        For each row that held a cell of ``node``, a new OUT_OF_DATE_STATE
        cell is added on the least used running node not already in that row.

        Returns the changes as a list of ``(offset, uuid, state)`` tuples.
        """
        cell_list = []
        uuid = node.getUUID()
        for offset, row in enumerate(self.partition_list):
            if row is None or self._removeCell(offset, node) is None:
                continue
            cell_list.append((offset, uuid, DISCARDED_STATE))
            # Use a separate name: ``node`` must keep designating the dropped
            # node for the following rows and for the final cleanup.
            excluded_node_list = [node] + [cell.getNode() for cell in row]
            new_node = self.findLeastUsedNode(excluded_node_list)
            if new_node is not None:
                row.append(Cell(new_node, OUT_OF_DATE_STATE))
                self.count_dict[new_node] += 1
                cell_list.append((offset, new_node.getUUID(),
                                  OUT_OF_DATE_STATE))

        # The dropped node no longer holds any cell.
        self.count_dict.pop(node, None)
        return cell_list

    def getRow(self, offset):
        """Return ``(uuid, state)`` pairs for row ``offset``.

        An unknown row yields an empty tuple.
        """
        row = self.partition_list[offset]
        if row is None:
            return ()
        return [(cell.getUUID(), cell.getState()) for cell in row]

    def _trimRow(self, offset, row):
        """Remove unwanted cells from ``row``; return the DISCARDED changes."""
        removed_cell_list = []
        feeding_cell = None
        out_of_date_cell_list = []
        up_to_date_cell_list = []
        for cell in row:
            if cell.getNodeState() == BROKEN_STATE:
                # Remove a broken cell.
                removed_cell_list.append(cell)
            elif cell.getState() == FEEDING_STATE:
                if feeding_cell is None:
                    feeding_cell = cell
                else:
                    # Remove an excessive feeding cell.
                    removed_cell_list.append(cell)
            elif cell.getState() == OUT_OF_DATE_STATE:
                out_of_date_cell_list.append(cell)
            else:
                up_to_date_cell_list.append(cell)

        # If all cells are up-to-date, a feeding cell is not required.
        if not out_of_date_cell_list and feeding_cell is not None:
            removed_cell_list.append(feeding_cell)

        ideal_num = self.nr + 1
        while len(out_of_date_cell_list) + len(up_to_date_cell_list) > ideal_num:
            # This row contains too many cells.
            if len(up_to_date_cell_list) > 1:
                # There are multiple up-to-date cells, so choose whatever
                # used too much.
                cell_list = out_of_date_cell_list + up_to_date_cell_list
            else:
                # Keep the up-to-date cell, if any; drop an out-of-date cell.
                cell_list = out_of_date_cell_list

            # Drop the cell whose node holds the most cells (first on ties).
            chosen_cell = max(
                cell_list, key=lambda c: self.count_dict.get(c.getNode(), 0))
            removed_cell_list.append(chosen_cell)
            # Shrinking the candidate lists is what makes this loop end.
            if chosen_cell in out_of_date_cell_list:
                out_of_date_cell_list.remove(chosen_cell)
            else:
                up_to_date_cell_list.remove(chosen_cell)

        # Now remove cells really.
        changed_cell_list = []
        for cell in removed_cell_list:
            row.remove(cell)
            if cell.getState() != FEEDING_STATE:
                node = cell.getNode()
                self.count_dict[node] = self.count_dict.get(node, 0) - 1
            changed_cell_list.append((offset, cell.getUUID(), DISCARDED_STATE))
        return changed_cell_list

    def _fillRow(self, offset, row):
        """Add OUT_OF_DATE_STATE cells until ``row`` holds ``nr + 1``
        non-feeding cells or no node is available; return the changes."""
        changed_cell_list = []
        num_cells = sum(1 for cell in row if cell.getState() != FEEDING_STATE)
        while num_cells < self.nr + 1:
            node = self.findLeastUsedNode([cell.getNode() for cell in row])
            if node is None:
                break
            row.append(Cell(node, OUT_OF_DATE_STATE))
            changed_cell_list.append((offset, node.getUUID(), OUT_OF_DATE_STATE))
            self.count_dict[node] += 1
            num_cells += 1
        return changed_cell_list

    def tweak(self):
        """Test if nodes are distributed uniformly. Otherwise, correct the
        partition table.

        Every row is trimmed first. Trimming removes cells of broken nodes,
        excessive or unneeded feeding cells, and cells beyond ``nr + 1``.
        Rows with fewer than ``nr + 1`` non-feeding cells then get new
        OUT_OF_DATE_STATE cells on the least used running nodes.  Unknown
        rows are skipped.

        Returns the changes as a list of ``(offset, uuid, state)`` tuples.
        """
        changed_cell_list = []

        # Trim all rows before adding any cell, so that the counters used to
        # pick new nodes reflect every removal.
        for offset, row in enumerate(self.partition_list):
            if row is not None:
                changed_cell_list.extend(self._trimRow(offset, row))

        # Add cells, if a row contains less than the number of replicas.
        for offset, row in enumerate(self.partition_list):
            if row is not None:
                changed_cell_list.extend(self._fillRow(offset, row))

        # FIXME still not enough. It is necessary to check if it is possible
        # to reduce differences between frequently used nodes and rarely used
        # nodes by replacing cells.

        return changed_cell_list