PageRenderTime 47ms CodeModel.GetById 17ms RepoModel.GetById 1ms app.codeStats 0ms

/brewery/tests/test_node_stream.py

https://bitbucket.org/Stiivi/brewery/
Python | 221 lines | 162 code | 46 blank | 13 comment | 4 complexity | 3ef08cd5e1db8e8f43607244249e8727 MD5 | raw file
Possible License(s): LGPL-3.0
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import brewery
  4. import brewery.ds as ds
  5. import unittest
  6. import logging
  7. import time
  8. import StringIO
  9. from brewery.streams import *
  10. from brewery.nodes import *
  11. logging.basicConfig(level=logging.WARN)
  12. class StreamBuildingTestCase(unittest.TestCase):
  13. def setUp(self):
  14. # Stream we have here:
  15. #
  16. # source ---+---> csv_target
  17. # |
  18. # +---> sample ----> html_target
  19. self.stream = Stream()
  20. self.node1 = Node()
  21. self.node1.description = "source"
  22. self.stream.add(self.node1, "source")
  23. self.node2 = Node()
  24. self.node2.description = "csv_target"
  25. self.stream.add(self.node2, "csv_target")
  26. self.node4 = Node()
  27. self.node4.description = "html_target"
  28. self.stream.add(self.node4, "html_target")
  29. self.node3 = Node()
  30. self.node3.description = "sample"
  31. self.stream.add(self.node3, "sample")
  32. self.stream.connect("source", "sample")
  33. self.stream.connect("source", "csv_target")
  34. self.stream.connect("sample", "html_target")
  35. def test_connections(self):
  36. self.assertEqual(4, len(self.stream.nodes))
  37. self.assertEqual(3, len(self.stream.connections))
  38. self.assertRaises(KeyError, self.stream.connect, "sample", "unknown")
  39. node = Node()
  40. self.assertRaises(KeyError, self.stream.add, node, "sample")
  41. self.stream.remove("sample")
  42. self.assertEqual(3, len(self.stream.nodes))
  43. self.assertEqual(1, len(self.stream.connections))
  44. def test_node_sort(self):
  45. sorted_nodes = self.stream.sorted_nodes()
  46. nodes = [self.node1, self.node3, self.node2, self.node4]
  47. self.assertEqual(self.node1, sorted_nodes[0])
  48. self.assertEqual(self.node4, sorted_nodes[-1])
  49. self.stream.connect("html_target", "source")
  50. self.assertRaises(Exception, self.stream.sorted_nodes)
  51. def test_update(self):
  52. stream_desc = {
  53. "nodes": {
  54. "source": {"type": "row_list_source"},
  55. "target": {"type": "record_list_target"},
  56. "aggtarget": {"type": "record_list_target"},
  57. "sample": {"type": "sample"},
  58. "map": {"type": "field_map"},
  59. "aggregate": {"type": "aggregate", "keys": ["str"] }
  60. },
  61. "connections": [
  62. ("source", "sample"),
  63. ("sample", "map"),
  64. ("map", "target"),
  65. ("source", "aggregate"),
  66. ("aggregate", "aggtarget")
  67. ]
  68. }
  69. stream = Stream()
  70. stream.update(stream_desc)
  71. self.assertTrue(isinstance(stream.node("source"), Node))
  72. self.assertTrue(isinstance(stream.node("aggregate"), AggregateNode))
  73. node = stream.node("aggregate")
  74. self.assertEqual(["str"], node.keys)
  75. class FailNode(Node):
  76. __node_info__ = {
  77. "attributes": [ {"name":"message"} ]
  78. }
  79. def __init__(self):
  80. self.message = "This is fail node and it failed as expected"
  81. def run(self):
  82. logging.debug("intentionally failing a node")
  83. raise Exception(self.message)
  84. class SlowSourceNode(Node):
  85. @property
  86. def output_fields(self):
  87. return brewery.fieldlist(["i"])
  88. def run(self):
  89. for cycle in range(0,10):
  90. for i in range(0, 1000):
  91. self.put([i])
  92. time.sleep(0.05)
  93. class StreamInitializationTestCase(unittest.TestCase):
  94. def setUp(self):
  95. # Stream we have here:
  96. #
  97. # source ---+---> aggregate ----> aggtarget
  98. # |
  99. # +---> sample ----> map ----> target
  100. self.fields = brewery.fieldlist(["a", "b", "c", "str"])
  101. self.src_list = [[1,2,3,"a"], [4,5,6,"b"], [7,8,9,"a"]]
  102. self.target_list = []
  103. self.aggtarget_list = []
  104. nodes = {
  105. "source": RowListSourceNode(self.src_list, self.fields),
  106. "target": RecordListTargetNode(self.target_list),
  107. "aggtarget": RecordListTargetNode(self.aggtarget_list),
  108. "sample": SampleNode("sample"),
  109. "map": FieldMapNode(drop_fields = ["c"]),
  110. "aggregate": AggregateNode(keys = ["str"])
  111. }
  112. connections = [
  113. ("source", "sample"),
  114. ("sample", "map"),
  115. ("map", "target"),
  116. ("source", "aggregate"),
  117. ("aggregate", "aggtarget")
  118. ]
  119. self.stream = Stream(nodes, connections)
  120. def test_initialization(self):
  121. self.stream._initialize()
  122. target = self.stream.node("map")
  123. names = target.output_fields.names()
  124. self.assertEqual(['a', 'b', 'str'], names)
  125. agg = self.stream.node("aggregate")
  126. names = agg.output_fields.names()
  127. self.assertEqual(['str', 'record_count'], names)
  128. def test_run(self):
  129. self.stream.run()
  130. target = self.stream.node("target")
  131. data = target.list
  132. expected = [{'a': 1, 'b': 2, 'str': 'a'},
  133. {'a': 4, 'b': 5, 'str': 'b'},
  134. {'a': 7, 'b': 8, 'str': 'a'}]
  135. self.assertEqual(expected, data)
  136. target = self.stream.node("aggtarget")
  137. data = target.list
  138. expected = [{'record_count': 2, 'str': 'a'}, {'record_count': 1, 'str': 'b'}]
  139. self.assertEqual(expected, data)
  140. def test_run_removed(self):
  141. self.stream.remove("aggregate")
  142. self.stream.remove("aggtarget")
  143. self.stream.run()
  144. def test_fail_run(self):
  145. nodes = {
  146. "source": RowListSourceNode(self.src_list, self.fields),
  147. "fail": FailNode(),
  148. "target": RecordListTargetNode(self.target_list)
  149. }
  150. connections = [
  151. ("source", "fail"),
  152. ("fail", "target")
  153. ]
  154. stream = Stream(nodes, connections)
  155. self.assertRaisesRegexp(StreamRuntimeError, "This is fail node", stream.run)
  156. nodes["fail"].message = u"Unicode message: čučoriedka ľúbivo ťukala"
  157. try:
  158. stream.run()
  159. except StreamRuntimeError, e:
  160. handle = StringIO.StringIO()
  161. # This should not raise an exception
  162. e.print_exception(handle)
  163. handle.close()
  164. def test_fail_with_slow_source(self):
  165. nodes = {
  166. "source": SlowSourceNode(),
  167. "fail": FailNode(),
  168. "target": RecordListTargetNode(self.target_list)
  169. }
  170. connections = [
  171. ("source", "fail"),
  172. ("fail", "target")
  173. ]
  174. stream = Stream(nodes, connections)
  175. self.assertRaises(StreamRuntimeError, stream.run)