Commit e9e2e77d authored by Bram Schoenmakers's avatar Bram Schoenmakers

Add a class for a directed graph + tests.

parent ae73424d
""" Contains the class for a directed graph. """
class DirectedGraph(object):
"""
Represents a simple directed graph, used for tracking todo
dependencies. The nodes are very simple: just integers.
"""
def __init__(self):
self._edges = {}
def add_node(self, p_id):
""" Adds a node to the graph. """
if not self.has_node(p_id):
self._edges[p_id] = set()
def has_node(self, p_id):
""" Returns true iff the graph has the given node. """
return p_id in self._edges
def add_edge(self, p_from, p_to):
"""
Adds an edge to the graph. The nodes will be added if they don't
exist.
"""
if not self.has_node(p_from):
self.add_node(p_from)
if not self.has_node(p_to):
self.add_node(p_to)
self._edges[p_from].add(p_to)
def has_path(self, p_from, p_to):
"""
Returns true iff there is a path from the first node to the second.
"""
return p_to in self.reachable_nodes(p_from)
def incoming_neighbors(self, p_id, p_recursive=False):
"""
Returns a set of the direct neighbors that can reach the given
node.
"""
return self.reachable_nodes_reverse(p_id, p_recursive)
def outgoing_neighbors(self, p_id, p_recursive=False):
"""
Returns the set of the direct neighbors that the given node can
reach.
"""
return self.reachable_nodes(p_id, p_recursive)
def reachable_nodes(self, p_id, p_recursive=True, p_reverse=False):
"""
Returns the set of all neighbors that the given node can reach.
If recursive, it will also return the neighbor's neighbors, etc.
If reverse, the arrows are reversed and then the reachable neighbors
are located.
"""
stack = [p_id]
visited = set()
result = set()
while len(stack):
current = stack.pop()
if current in visited or current not in self._edges:
continue
visited.add(current)
if p_reverse:
parents = [node for node, neighbors in self._edges.iteritems() \
if current in neighbors]
stack = stack + parents
result = result.union(parents)
else:
stack = stack + list(self._edges[current])
result = result.union(self._edges[current])
if not p_recursive:
break
return result
def reachable_nodes_reverse(self, p_id, p_recursive=True):
""" Find neighbors in the inverse graph. """
return self.reachable_nodes(p_id, p_recursive, True)
def remove_node(self, p_id, remove_unconnected_nodes=True):
""" Removes a node from the graph. """
if self.has_node(p_id):
for neighbor in self.incoming_neighbors(p_id):
self._edges[neighbor].remove(p_id)
neighbors = set()
if remove_unconnected_nodes:
neighbors = self.outgoing_neighbors(p_id)
del self._edges[p_id]
for neighbor in neighbors:
if self.is_isolated(neighbor):
self.remove_node(neighbor)
def is_isolated(self, p_id):
"""
Returns True iff the given node has no incoming or outgoing edges.
"""
return len(self.incoming_neighbors(p_id)) == 0 \
and len(self.outgoing_neighbors(p_id)) == 0
def has_edge(self, p_from, p_to):
""" Returns True when the graph has the given edge. """
return p_from in self._edges and p_to in self._edges[p_from]
def remove_edge(self, p_from, p_to, remove_unconnected_nodes=True):
"""
Removes an edge from the graph.
When remove_unconnected_nodes is True, then the nodes are also removed
if they become isolated.
"""
if self.has_edge(p_from, p_to):
self._edges[p_from].remove(p_to)
if remove_unconnected_nodes:
if self.is_isolated(p_from):
self.remove_node(p_from)
if self.is_isolated(p_to):
self.remove_node(p_to)
import unittest
import Graph
class GraphTest(unittest.TestCase):
def setUp(self):
self.graph = Graph.DirectedGraph()
self.graph.add_edge(1, 2)
self.graph.add_edge(2, 4)
self.graph.add_edge(4, 3)
self.graph.add_edge(4, 6)
self.graph.add_edge(6, 2)
self.graph.add_edge(1, 3)
self.graph.add_edge(3, 5)
# 1
# / \
# v v
# />2 />3
# / | / |
# / v / v
# 6 <- 4 5
def test_has_nodes(self):
for i in range(1, 7):
self.assertTrue(self.graph.has_node(i))
def test_incoming_neighbors1(self):
self.assertEquals(self.graph.incoming_neighbors(1), set())
def test_incoming_neighbors2(self):
self.assertEquals(self.graph.incoming_neighbors(2), set([1, 6]))
def test_incoming_neighbors3(self):
self.assertEquals(self.graph.incoming_neighbors(1, True), set())
def test_incoming_neighbors4(self):
self.assertEquals(self.graph.incoming_neighbors(5, True), set([1, 2, 3, 4, 6]))
def test_outgoing_neighbors1(self):
self.assertEquals(self.graph.outgoing_neighbors(1), set([2, 3]))
def test_outgoing_neighbors2(self):
self.assertEquals(self.graph.outgoing_neighbors(2), set([4]))
def test_outgoing_neighbors3(self):
self.assertEquals(self.graph.outgoing_neighbors(1, True), set([2, 3, 4, 5, 6]))
def test_outgoing_neighbors4(self):
self.assertEquals(self.graph.outgoing_neighbors(3), set([5]))
def test_outgoing_neighbors5(self):
self.assertEquals(self.graph.outgoing_neighbors(5), set([]))
def test_remove_edge1(self):
self.graph.remove_edge(1, 2)
self.assertFalse(self.graph.has_path(1, 4))
self.assertTrue(self.graph.has_path(2, 4))
def test_remove_edge2(self):
self.graph.remove_edge(3, 5, True)
self.assertFalse(self.graph.has_path(1, 5))
self.assertFalse(self.graph.has_node(5))
def test_remove_edge3(self):
self.graph.remove_edge(3, 5, False)
self.assertFalse(self.graph.has_path(1, 5))
self.assertTrue(self.graph.has_node(5))
def test_remove_edge4(self):
""" Remove non-existing edge. """
self.graph.remove_edge(4, 5)
def test_remove_edge5(self):
self.graph.remove_edge(3, 5, True)
self.assertFalse(self.graph.has_path(1, 5))
self.assertFalse(self.graph.has_node(5))
def test_remove_edge6(self):
self.graph.remove_edge(1, 3, True)
self.assertTrue(self.graph.has_path(1, 5))
def test_remove_node1(self):
self.graph.remove_node(2)
self.assertTrue(self.graph.has_node(1))
self.assertTrue(self.graph.has_node(4))
self.assertTrue(self.graph.has_node(6))
self.assertFalse(self.graph.has_node(2))
self.assertFalse(self.graph.has_edge(2, 4))
self.assertFalse(self.graph.has_edge(1, 2))
def test_remove_node2(self):
self.graph.remove_node(3, True)
self.assertFalse(self.graph.has_node(5))
self.assertFalse(self.graph.has_edge(1, 3))
self.assertFalse(self.graph.has_edge(3, 5))
self.assertFalse(self.graph.has_path(1, 5))
def test_remove_node3(self):
self.graph.remove_node(3, False)
self.assertTrue(self.graph.has_node(5))
self.assertFalse(self.graph.has_edge(1, 3))
self.assertFalse(self.graph.has_edge(3, 5))
self.assertFalse(self.graph.has_path(1, 5))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment