/odo/core.py

https://github.com/blaze/odo · Python · 200 lines · 127 code · 44 blank · 29 comment · 30 complexity · 61da1d7a0bc5a06a4d38e721814dd966 MD5 · raw file

  1. from __future__ import absolute_import, division, print_function
  2. from collections import namedtuple, Iterator
  3. from contextlib import contextmanager
  4. from warnings import warn
  5. from datashape import discover
  6. import networkx as nx
  7. import numpy as np
  8. from toolz import concatv
  9. from .compatibility import map, adjacency
  10. from .utils import expand_tuples, ignoring
# Out-of-Core types.  `path` uses this set as its default `ooc_types`: when
# both the source and the target are subclasses of types in it, the search
# is restricted to graph nodes that are also subclasses of those types.
ooc_types = set()
  12. class FailedConversionWarning(UserWarning):
  13. def __init__(self, src, dest, exc):
  14. self.src = src
  15. self.dest = dest
  16. self.exc = exc
  17. def __str__(self):
  18. return 'Failed on %s -> %s. Working around\nError message:\n%s' % (
  19. self.src.__name__, self.dest.__name__, self.exc,
  20. )
  21. class IterProxy(object):
  22. """An proxy to another iterator to support swapping the underlying stream
  23. mid-iteration.
  24. Parameters
  25. ----------
  26. it : iterable
  27. The iterable to proxy.
  28. Attributes
  29. ----------
  30. it : iterable
  31. The iterable being proxied. This can be reassigned to change the
  32. underlying stream.
  33. """
  34. def __init__(self, it):
  35. self._it = iter(it)
  36. def __next__(self):
  37. return next(self.it)
  38. next = __next__ # py2 compat
  39. def __iter__(self):
  40. return self
  41. @property
  42. def it(self):
  43. return self._it
  44. @it.setter
  45. def it(self, value):
  46. self._it = iter(value)
  47. class NetworkDispatcher(object):
  48. def __init__(self, name):
  49. self.name = name
  50. self.graph = nx.DiGraph()
  51. def register(self, a, b, cost=1.0):
  52. sigs = expand_tuples([a, b])
  53. def _(func):
  54. for a, b in sigs:
  55. self.graph.add_edge(b, a, cost=cost, func=func)
  56. return func
  57. return _
  58. def path(self, *args, **kwargs):
  59. return path(self.graph, *args, **kwargs)
  60. def __call__(self, *args, **kwargs):
  61. return _transform(self.graph, *args, **kwargs)
  62. def _transform(graph, target, source, excluded_edges=None, ooc_types=ooc_types,
  63. **kwargs):
  64. """ Transform source to target type using graph of transformations """
  65. # take a copy so we can mutate without affecting the input
  66. excluded_edges = (excluded_edges.copy()
  67. if excluded_edges is not None else
  68. set())
  69. with ignoring(NotImplementedError):
  70. if 'dshape' not in kwargs or kwargs['dshape'] is None:
  71. kwargs['dshape'] = discover(source)
  72. pth = path(graph, type(source), target,
  73. excluded_edges=excluded_edges,
  74. ooc_types=ooc_types)
  75. x = source
  76. path_proxy = IterProxy(pth)
  77. for convert_from, convert_to, f, cost in path_proxy:
  78. try:
  79. x = f(x, excluded_edges=excluded_edges, **kwargs)
  80. except NotImplementedError as e:
  81. if kwargs.get('raise_on_errors'):
  82. raise
  83. warn(FailedConversionWarning(convert_from, convert_to, e))
  84. # exclude the broken edge
  85. excluded_edges |= {(convert_from, convert_to)}
  86. # compute the path from `source` to `target` excluding
  87. # the edge that broke
  88. fresh_path = list(path(graph, type(source), target,
  89. excluded_edges=excluded_edges,
  90. ooc_types=ooc_types))
  91. fresh_path_cost = path_cost(fresh_path)
  92. # compute the path from the current `convert_from` type
  93. # to the `target`
  94. try:
  95. greedy_path = list(path(graph, convert_from, target,
  96. excluded_edges=excluded_edges,
  97. ooc_types=ooc_types))
  98. except nx.exception.NetworkXNoPath:
  99. greedy_path_cost = np.inf
  100. else:
  101. greedy_path_cost = path_cost(greedy_path)
  102. if fresh_path_cost < greedy_path_cost:
  103. # it is faster to start over from `source` with a new path
  104. x = source
  105. pth = fresh_path
  106. else:
  107. # it is faster to work around our broken edge from our
  108. # current location
  109. pth = greedy_path
  110. path_proxy.it = pth
  111. return x
# One step of a conversion path.  `path` builds these from a pair of
# consecutive nodes plus the 'func' and 'cost' attributes stored on the edge
# between them.
PathPart = namedtuple('PathPart', 'convert_from convert_to func cost')
# Classes that `path` checks, after the source type's MRO, when looking for
# a class that is a node of the graph.
_virtual_superclasses = (Iterator,)
  114. def path(graph, source, target, excluded_edges=None, ooc_types=ooc_types):
  115. """ Path of functions between two types """
  116. if not isinstance(source, type):
  117. source = type(source)
  118. if not isinstance(target, type):
  119. target = type(target)
  120. for cls in concatv(source.mro(), _virtual_superclasses):
  121. if cls in graph:
  122. source = cls
  123. break
  124. # If both source and target are Out-Of-Core types then restrict ourselves
  125. # to the graph of out-of-core types
  126. if ooc_types:
  127. oocs = tuple(ooc_types)
  128. if issubclass(source, oocs) and issubclass(target, oocs):
  129. graph = graph.subgraph([n for n in graph.nodes()
  130. if issubclass(n, oocs)])
  131. with without_edges(graph, excluded_edges) as g:
  132. pth = nx.shortest_path(g, source=source, target=target, weight='cost')
  133. edge = adjacency(graph)
  134. def path_part(src, tgt):
  135. node = edge[src][tgt]
  136. return PathPart(src, tgt, node['func'], node['cost'])
  137. return map(path_part, pth, pth[1:])
  138. def path_cost(path):
  139. """Calculate the total cost of a path.
  140. """
  141. return sum(p.cost for p in path)
  142. @contextmanager
  143. def without_edges(g, edges):
  144. edges = edges or []
  145. held = dict()
  146. _g_edge = adjacency(g)
  147. for a, b in edges:
  148. held[(a, b)] = _g_edge[a][b]
  149. g.remove_edge(a, b)
  150. try:
  151. yield g
  152. finally:
  153. for (a, b), kwargs in held.items():
  154. g.add_edge(a, b, **kwargs)