from graphs import MutableGraph, Directed, Undirected

class DictOfDicts(MutableGraph):
    """Graphs, represented using a dict-of-dicts representation.

    For example, a graph with:
     - nodes {0, 1, 2, 3, 4, 5}
     - edges {(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 2), (4, 5)}
    and all edge labels equal to 1 is represented as
    {0: {1: 1, 2: 1},
     1: {2: 1, 3: 1},
     2: {3: 1},
     3: {2: 1},
     4: {5: 1},
     5: {}}
    i.e. ``data[n1][n2]`` is the label of the edge from n1 to n2, and every
    node is a key of ``data`` even when it has no outgoing edges.

    This representation can be accessed via the ``data`` attribute. The
    ``nodes`` and ``edges`` views wrap the very same dict, so a change made
    through any one of them is visible through the others.

    For undirected graphs, each edge is stored in the adjacency mappings of
    both of its endpoints.
    """

    def __init__(self, data=None):
        """Wrap `data` (used as-is, not copied), or a new empty dict if None."""
        if data is None:
            data = {}
        # self.NodesView / self.EdgesView resolve to a subclass's nested view
        # classes when the subclass overrides them.
        self.nodes = self.NodesView(data)
        self.edges = self.EdgesView(data)
        self.data = data

    def neighbors(self, node):
        """Return the adjacency dict {neighbor: edge label} of `node`.

        The stored dict itself is returned (not a copy). Raises KeyError if
        `node` is not in the graph.
        """
        return self.data[node]

    def incoming(self, node):
        """Return {n: label} for every node n having an edge n -> `node`.

        Scans every adjacency dict once, so it is linear in the number of
        nodes.
        """
        return {n: adj[node] for n, adj in self.data.items() if node in adj}

    class NodesView(MutableGraph.NodesView):
        """Set-like view of the graph's nodes, i.e. the keys of ``data``."""

        def __init__(self, data):
            self.data = data

        def add(self, node):
            """Add `node` without edges; no-op if it is already present."""
            self.data.setdefault(node, {})

        def discard(self, node):
            """Remove `node` and every edge to or from it; no-op if absent."""
            data = self.data
            if node in data:
                del data[node]
                # Edges *to* node live in the other nodes' adjacency dicts.
                for neighbordict in data.values():
                    neighbordict.pop(node, None)

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            return iter(self.data)

        def __contains__(self, node):
            return node in self.data

        @classmethod
        def _from_iterable(cls, it):
            # Needed by the Set abstract base class, which builds the results
            # of its operators (e.g. the ``-`` in __iand__) via this hook.
            # Those results are plain sets, detached from any graph.
            # NOTE(review): the alternative would be
            #     return cls({node: {} for node in it})
            # i.e. a view over a new edgeless graph; confirm which is wanted.
            return set(it)

        # Optimization for speed: compare the key views of the two dicts
        # directly instead of testing membership node by node.
        def __le__(self, other):
            # `other` is normally another nodes view, as in
            # ``g1.nodes <= g2.nodes``. A whole DictOfDicts graph is still
            # accepted for backward compatibility. Both expose ``.data``.
            if isinstance(other, (DictOfDicts.NodesView, DictOfDicts)):
                return self.data.keys() <= other.data.keys()
            return super().__le__(other)

        # Fixes an error in early Python 2.6 and 3.0 implementations:
        # MutableMapping implementations would raise an error when
        # trying to modify the dictionary while iterating on them.
        def __iand__(self, c):
            # `self - c` is materialized (see _from_iterable) before any
            # node is discarded, so the dict is never mutated mid-iteration.
            for node in self - c:
                self.discard(node)
            return self

    class EdgesView(MutableGraph.EdgesView):
        """Base edges view; edges are addressed as (n1, n2) pairs."""

        def __init__(self, data):
            self.data = data

        def __contains__(self, edge):
            """True if n2 is in n1's adjacency dict; False if n1 is unknown."""
            n1, n2 = edge
            return n1 in self.data and n2 in self.data[n1]

        def __getitem__(self, edge):
            """Return the label of edge (n1, n2); KeyError if it is absent."""
            n1, n2 = edge
            return self.data[n1][n2]

class Directed_DictOfDicts(Directed, DictOfDicts):
    """Directed graph stored as a dict-of-dicts: data[source][target] = label."""

    class EdgesView(DictOfDicts.EdgesView):
        """Edges of a directed graph, addressed and iterated as (source, target)."""

        def __len__(self):
            # Each directed edge has exactly one entry, in its source's dict.
            return sum(len(targets) for targets in self.data.values())

        def __setitem__(self, edge, value):
            source, target = edge
            adjacency = self.data
            # The target becomes a node even if it has no outgoing edges.
            # It is registered before the source, matching the original
            # insertion order.
            if target not in adjacency:
                adjacency[target] = {}
            if source not in adjacency:
                adjacency[source] = {}
            adjacency[source][target] = value

        def __delitem__(self, edge):
            source, target = edge
            targets = self.data[source]
            del targets[target]

        def __iter__(self):
            for source, targets in self.data.items():
                for target in targets:
                    yield source, target

class Undirected_DictOfDicts(Undirected, DictOfDicts):
    """Undirected graph stored as a symmetric dict-of-dicts.

    An edge {n1, n2} with n1 != n2 is stored twice, as ``data[n1][n2]`` and
    ``data[n2][n1]``, with the same label. A self-loop {n} is stored once,
    as ``data[n][n]``.
    """

    class NodesView(DictOfDicts.NodesView):

        # Optimization for speed: since edges are stored symmetrically, only
        # the removed node's own neighbors can refer to it, so there is no
        # need to scan every adjacency dict.
        def discard(self, node):
            """Remove `node` and all edges touching it; no-op if absent."""
            data = self.data
            if node not in data:
                return
            # Detach the node's adjacency dict before walking it. With a
            # self-loop, deleting data[n][node] for n == node would otherwise
            # mutate the very dict being iterated (RuntimeError).
            neighbors = data.pop(node)
            for n in neighbors:
                if n != node:  # the self-loop entry went away with `neighbors`
                    del data[n][node]

    class EdgesView(DictOfDicts.EdgesView):
        """Edges of an undirected graph.

        An edge may be addressed by an (n1, n2) pair in either order.
        Iteration yields each edge once, as a frozenset of its endpoints:
        {n} for a self-loop on n. Such a frozenset is also accepted as a key.
        """

        @staticmethod
        def _endpoints(edge):
            """Return the endpoints of `edge` as an (n1, n2) tuple.

            A one-element set/frozenset {n}, which is how __iter__ reports a
            self-loop, is expanded to (n, n). Anything else must unpack into
            exactly two nodes.
            """
            if isinstance(edge, (set, frozenset)) and len(edge) == 1:
                (node,) = edge
                return node, node
            n1, n2 = edge
            return n1, n2

        def __contains__(self, edge):
            n1, n2 = self._endpoints(edge)
            data = self.data
            return n1 in data and n2 in data[n1]

        def __getitem__(self, edge):
            n1, n2 = self._endpoints(edge)
            return self.data[n1][n2]

        def __len__(self):
            data = self.data
            # A non-loop edge is counted in both endpoints' dicts and a
            # self-loop only once. Adding the loops once more makes every
            # edge count exactly twice.
            selfloops = sum(node in ndict for node, ndict in data.items())
            totedges = sum(map(len, data.values()))
            return (totedges + selfloops) // 2

        def __setitem__(self, edge, value):
            n1, n2 = self._endpoints(edge)
            data = self.data
            data.setdefault(n1, {})[n2] = value
            data.setdefault(n2, {})[n1] = value

        def __delitem__(self, edge):
            n1, n2 = self._endpoints(edge)
            data = self.data
            del data[n1][n2]
            # A self-loop has a single entry; deleting it a second time would
            # raise KeyError.
            if n1 != n2:
                del data[n2][n1]

        def __iter__(self):
            seen = set()
            for n1, ndict in self.data.items():
                for n2 in ndict:
                    # Edges to already-visited nodes were yielded from the
                    # other endpoint.
                    if n2 not in seen:
                        yield frozenset((n1, n2))
                seen.add(n1)

if __name__ == '__main__':
    # Smoke tests for both graph flavours.
    ug = Undirected_DictOfDicts()
    dg = Directed_DictOfDicts()
    # Each graph class must wire in its own nested EdgesView.
    assert type(dg.edges) is Directed_DictOfDicts.EdgesView
    ug.edges[1, 2] = 1
    ug.edges.add(2, 3)
    # Undirected edges are symmetric; add() is expected to store label 1.
    assert ug.edges[3, 2] == 1
    ug.edges.discard(3, 2)
    # Was `assert 2, 3 not in ug`, i.e. `assert 2` with a message: always true.
    assert (2, 3) not in ug.edges
    dg.edges[1, 2] = 1
    assert dg.data == {1: {2: 1}, 2: {}}
    assert (2, 1) not in dg.edges
    assert (2, 1) in ug.edges
    emptygraph = Undirected_DictOfDicts()
    assert emptygraph != ug
    assert emptygraph.nodes <= ug.nodes
    ug.nodes &= {1, 5}
    assert ug.nodes == {1}
    ug.nodes.clear()
    assert emptygraph.nodes <= ug.nodes
    assert emptygraph == ug

