PageRenderTime 63ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/tests/test_cluster.py

https://gitlab.com/africanresearchcloud/hpc
Python | 391 lines | 350 code | 19 blank | 22 comment | 4 complexity | efb0b54f8f750175379d855c44a8a0c1 MD5 | raw file
  1. #! /usr/bin/env python
  2. #
  3. # Copyright (C) 2013, 2016 S3IT, University of Zurich
  4. #
  5. # This program is free software: you can redistribute it and/or modify
  6. # it under the terms of the GNU General Public License as published by
  7. # the Free Software Foundation, either version 3 of the License, or
  8. # (at your option) any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU General Public License
  16. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. #
  18. from __future__ import absolute_import
  19. __author__ = 'Nicolas Baer <nicolas.baer@uzh.ch>'
  20. import json
  21. import os
  22. import shutil
  23. import tempfile
  24. import unittest
  25. from mock import Mock, MagicMock, patch
  26. from elasticluster.conf import Configurator
  27. from elasticluster.cluster import Cluster, Node
  28. from elasticluster.exceptions import ClusterError
  29. from elasticluster.providers.ec2_boto import BotoCloudProvider
  30. from elasticluster.repository import PickleRepository
  31. from _helpers.config import Configuration
  32. class TestCluster(unittest.TestCase):
  33. def setUp(self):
  34. self.storage_path = tempfile.mkdtemp()
  35. f, path = tempfile.mkstemp()
  36. self.path = path
  37. def tearDown(self):
  38. shutil.rmtree(self.storage_path)
  39. os.unlink(self.path)
  40. def get_cluster(self, cloud_provider=None, config=None, nodes=None):
  41. if not cloud_provider:
  42. cloud_provider = BotoCloudProvider("https://hobbes.gc3.uzh.ch/",
  43. "nova", "a-key", "s-key")
  44. if not config:
  45. config = Configuration().get_config(self.path)
  46. setup = Mock()
  47. configurator = Configurator(config)
  48. conf_login = configurator.cluster_conf['mycluster']['login']
  49. repository = PickleRepository(self.storage_path)
  50. cluster = Cluster(name="mycluster",
  51. cloud_provider=cloud_provider,
  52. setup_provider=setup,
  53. repository=repository,
  54. user_key_name=conf_login['user_key_name'],
  55. user_key_public=conf_login['user_key_public'],
  56. user_key_private=conf_login['user_key_private'],
  57. )
  58. if not nodes:
  59. nodes = {"compute": 2, "frontend": 1}
  60. for kind, num in nodes.iteritems():
  61. conf_kind = configurator.cluster_conf['mycluster']['nodes'][kind]
  62. cluster.add_nodes(kind, num, conf_kind['image_id'],
  63. conf_login['image_user'],
  64. conf_kind['flavor'],
  65. conf_kind['security_group'])
  66. return cluster
  67. def test_add_node(self):
  68. """
  69. Add node
  70. """
  71. cluster = self.get_cluster()
  72. # without name
  73. size = len(cluster.nodes['compute'])
  74. cluster.add_node("compute", 'image_id', 'image_user', 'flavor',
  75. 'security_group')
  76. self.assertEqual(size + 1, len(cluster.nodes['compute']))
  77. new_node = cluster.nodes['compute'][2]
  78. self.assertEqual(new_node.name, 'compute003')
  79. self.assertEqual(new_node.kind, 'compute')
  80. # with custom name
  81. name = "test-node"
  82. size = len(cluster.nodes['compute'])
  83. cluster.add_node("compute", 'image_id', 'image_user', 'flavor',
  84. 'security_group', image_userdata="", name=name)
  85. self.assertEqual(size + 1, len(cluster.nodes['compute']))
  86. self.assertEqual(cluster.nodes['compute'][3].name, name)
  87. def test_remove_node(self):
  88. """
  89. Remove node
  90. """
  91. cluster = self.get_cluster()
  92. size = len(cluster.nodes['compute'])
  93. cluster.remove_node(cluster.nodes['compute'][1])
  94. self.assertEqual(size - 1, len(cluster.nodes['compute']))
  95. def test_start(self):
  96. """
  97. Start cluster
  98. """
  99. cloud_provider = MagicMock()
  100. cloud_provider.start_instance.return_value = u'test-id'
  101. cloud_provider.get_ips.return_value = ['127.0.0.1']
  102. cloud_provider.is_instance_running.return_value = True
  103. cluster = self.get_cluster(cloud_provider=cloud_provider)
  104. cluster.repository = MagicMock()
  105. cluster.repository.storage_path = '/foo/bar'
  106. ssh_mock = MagicMock()
  107. with patch('paramiko.SSHClient') as ssh_mock:
  108. cluster.start()
  109. cluster.repository.save_or_update.assert_called_with(cluster)
  110. for node in cluster.get_all_nodes():
  111. assert node.instance_id == u'test-id'
  112. assert node.ips == ['127.0.0.1']
  113. def test_check_cluster_size(self):
  114. nodes = {"compute": 3, "frontend": 1}
  115. nodes_min = {"compute": 1, "frontend": 3}
  116. cluster = self.get_cluster(nodes=nodes)
  117. cluster._check_cluster_size(nodes_min)
  118. self.assertEqual(len(cluster.nodes["frontend"]), 3)
  119. self.assertTrue(len(cluster.nodes["compute"]) >= 1)
  120. # not satisfiable cluster setup
  121. nodes = {"compute": 3, "frontend": 1}
  122. nodes_min = {"compute": 5, "frontend": 3}
  123. cluster = self.get_cluster(nodes=nodes)
  124. self.failUnlessRaises(ClusterError, cluster._check_cluster_size,
  125. min_nodes=nodes_min)
  126. def test_get_all_nodes(self):
  127. """
  128. Get all nodes
  129. """
  130. cluster = self.get_cluster()
  131. self.assertEqual(len(cluster.get_all_nodes()), 3)
  132. def test_stop(self):
  133. cloud_provider = MagicMock()
  134. cloud_provider.start_instance.return_value = u'test-id'
  135. cloud_provider.get_ips.return_value = ('127.0.0.1', '127.0.0.1')
  136. states = [True, True, True, True, True, False, False, False, False,
  137. False]
  138. def is_running(instance_id):
  139. return states.pop()
  140. cloud_provider.is_instance_running.side_effect = is_running
  141. cluster = self.get_cluster(cloud_provider=cloud_provider)
  142. for node in cluster.get_all_nodes():
  143. node.instance_id = u'test-id'
  144. cluster.repository = MagicMock()
  145. cluster.repository.storage_path = '/foo/bar'
  146. cluster.stop()
  147. cloud_provider.stop_instance.assert_called_with(u'test-id')
  148. cluster.repository.delete.assert_called_once_with(cluster)
  149. def test_get_frontend_node(self):
  150. """
  151. Get frontend node
  152. """
  153. config = Configuration().get_config(self.path)
  154. ssh_to = "frontend"
  155. config["mycluster"]["cluster"]["ssh_to"] = ssh_to
  156. cluster = self.get_cluster(config=config)
  157. cluster.ssh_to = ssh_to
  158. frontend = cluster.get_frontend_node()
  159. self.assertEqual(cluster.nodes['frontend'][0], frontend)
  160. def test_setup(self):
  161. """
  162. Setup the nodes of a cluster
  163. """
  164. setup_provider = MagicMock()
  165. setup_provider.setup_cluster.return_value = True
  166. cluster = self.get_cluster()
  167. cluster._setup_provider = setup_provider
  168. cluster.setup()
  169. setup_provider.setup_cluster.assert_called_once_with(cluster, tuple())
  170. def test_update(self):
  171. storage = MagicMock()
  172. cloud_provider = MagicMock()
  173. ip = '127.0.0.1'
  174. cloud_provider.get_ips.return_value = (ip, ip)
  175. cluster = self.get_cluster(cloud_provider=cloud_provider)
  176. cluster.repository = storage
  177. cluster.update()
  178. for node in cluster.get_all_nodes():
  179. self.assertEqual(node.ips[0], ip)
  180. def test_dict_mixin(self):
  181. """Check that the node class can be seen as dictionary"""
  182. config = Configuration().get_config(self.path)
  183. ssh_to = "frontend"
  184. config["mycluster"]["cluster"]["ssh_to"] = ssh_to
  185. cluster = self.get_cluster(config=config)
  186. cluster.ssh_to = ssh_to
  187. frontend = cluster.get_frontend_node()
  188. dcluster = dict(cluster)
  189. self.assertEqual(dcluster['ssh_to'], ssh_to)
  190. self.assertEqual(dcluster['nodes'].keys(), cluster.nodes.keys())
  191. self.failUnlessRaises(KeyError, lambda x: x['_cloud_provider'], dcluster)
  192. self.assertEqual(cluster['_cloud_provider'], cluster._cloud_provider)
  193. class TestNode(unittest.TestCase):
  194. def setUp(self):
  195. f, path = tempfile.mkstemp()
  196. self.path = path
  197. self.cluster_name = "cluster"
  198. self.name = "test"
  199. self.node_kind = "frontend"
  200. self.user_key_public = self.path
  201. self.user_key_private = self.path
  202. self.user_key_name = "key"
  203. self.image_user = "gc3-user"
  204. self.security_group = "security"
  205. self.image = "ami-000000"
  206. self.flavor = "m1.tiny"
  207. self.image_userdata = None
  208. def tearDown(self):
  209. os.unlink(self.path)
  210. def get_node(self):
  211. cloud_provider = MagicMock()
  212. node = Node(self.name, self.cluster_name, self.node_kind,
  213. cloud_provider, self.user_key_public,
  214. self.user_key_private, self.user_key_name,
  215. self.image_user, self.security_group,
  216. self.image, self.flavor,
  217. self.image_userdata)
  218. return node
  219. def test_start(self):
  220. """
  221. Start node
  222. """
  223. node = self.get_node()
  224. instance_id = "test-id"
  225. cloud_provider = node._cloud_provider
  226. cloud_provider.start_instance.return_value = instance_id
  227. node.start()
  228. node_name = "%s-%s" % (self.cluster_name, node.name)
  229. cloud_provider.start_instance.assert_called_once_with(
  230. self.user_key_name, self.user_key_public, self.user_key_private,
  231. self.security_group, self.flavor, self.image,
  232. self.image_userdata, username=self.image_user, node_name=node_name)
  233. self.assertEqual(node.instance_id, instance_id)
  234. def test_stop(self):
  235. """
  236. Stop Node
  237. """
  238. node = self.get_node()
  239. instance_id = "test-id"
  240. node.instance_id = instance_id
  241. node.stop()
  242. cloud_provider = node._cloud_provider
  243. cloud_provider.stop_instance.assert_called_once_with(instance_id)
  244. def test_is_alive(self):
  245. """
  246. Node is alive
  247. """
  248. # check without having any knowlegde of the node (e.g. instance id)
  249. node = self.get_node()
  250. self.assertFalse(node.is_alive())
  251. # check with knowledge and cloud provider and mock ip update
  252. instance_id = "test-id"
  253. node.instance_id = instance_id
  254. provider = node._cloud_provider
  255. provider.is_instance_running.return_value = True
  256. provider.get_ips.return_value = ['127.0.0.1', '127.0.0.1']
  257. node.is_alive()
  258. provider.is_instance_running.assert_called_once_with(instance_id)
  259. def test_connect(self):
  260. """
  261. Connect to node
  262. """
  263. node = self.get_node()
  264. # check without any ips set on the host
  265. self.assertEqual(node.connect(), None)
  266. # check with mocking the ssh connection
  267. ssh_mock = MagicMock()
  268. with patch('elasticluster.cluster.paramiko.SSHClient') as ssh_mock:
  269. ssh_mock.connect.return_value = True
  270. node.connect()
  271. def test_update_ips(self):
  272. """
  273. Update node ip address
  274. """
  275. # check without any ip addresses set
  276. node = self.get_node()
  277. instance_id = "test-id"
  278. node.instance_id = instance_id
  279. provider = node._cloud_provider
  280. ips = ['127.0.0.1', '127.0.0.2']
  281. node.ips = ips
  282. provider.get_ips.return_value = ips
  283. node.update_ips()
  284. self.assertEqual(node.ips, ips)
  285. provider.get_ips.assert_called_once_with(instance_id)
  286. def test_dict_mixin(self):
  287. """Check that the node class can be seen as dictionary"""
  288. node = self.get_node()
  289. # Setup node with dummy values
  290. instance_id = "test-id"
  291. node.instance_id = instance_id
  292. ips = ['127.0.0.1', '127.0.0.2']
  293. node.ips = ips
  294. dnode = dict(node)
  295. self.assertEqual(node['instance_id'], instance_id)
  296. self.assertEqual(node['ips'], ips)
  297. self.failUnlessRaises(KeyError, lambda x: x['_cloud_provider'], dnode)
  298. self.assertEqual(node['_cloud_provider'], node._cloud_provider)
  299. if __name__ == "__main__":
  300. pytest.main(['-v', __file__])