/test/test_spawn.py

https://code.google.com/p/mpi4py/ · Python · 160 lines · 141 code · 17 blank · 2 comment · 22 complexity · d149456a022ac81e4d418b630d44bcdf MD5 · raw file

  1. import sys, os, mpi4py
  2. from mpi4py import MPI
  3. import mpiunittest as unittest
  4. MPI4PYPATH = os.path.abspath(os.path.dirname(mpi4py.__path__[0]))
  5. CHILDSCRIPT = os.path.abspath(
  6. os.path.join(os.path.dirname(__file__), 'spawn_child.py')
  7. )
  8. class BaseTestSpawn(object):
  9. COMM = MPI.COMM_NULL
  10. COMMAND = sys.executable
  11. ARGS = [CHILDSCRIPT, MPI4PYPATH]
  12. MAXPROCS = 1
  13. INFO = MPI.INFO_NULL
  14. ROOT = 0
  15. def testCommSpawn(self):
  16. child = self.COMM.Spawn(self.COMMAND, self.ARGS, self.MAXPROCS,
  17. info=self.INFO, root=self.ROOT)
  18. local_size = child.Get_size()
  19. remote_size = child.Get_remote_size()
  20. child.Barrier()
  21. child.Disconnect()
  22. self.assertEqual(local_size, self.COMM.Get_size())
  23. self.assertEqual(remote_size, self.MAXPROCS)
  24. def testReturnedErrcodes(self):
  25. errcodes = []
  26. child = self.COMM.Spawn(self.COMMAND, self.ARGS, self.MAXPROCS,
  27. info=self.INFO, root=self.ROOT,
  28. errcodes=errcodes)
  29. child.Barrier()
  30. child.Disconnect()
  31. rank = self.COMM.Get_rank()
  32. self.assertEqual(len(errcodes), self.MAXPROCS)
  33. for errcode in errcodes:
  34. self.assertEqual(errcode, MPI.SUCCESS)
  35. def testArgsOnlyAtRoot(self):
  36. self.COMM.Barrier()
  37. rank = self.COMM.Get_rank()
  38. if rank == self.ROOT:
  39. child = self.COMM.Spawn(self.COMMAND, self.ARGS, self.MAXPROCS,
  40. info=self.INFO, root=self.ROOT)
  41. else:
  42. child = self.COMM.Spawn(None, None, -1,
  43. info=None, root=self.ROOT)
  44. child.Barrier()
  45. child.Disconnect()
  46. self.COMM.Barrier()
  47. def testCommSpawnMultiple(self):
  48. COMMAND = [self.COMMAND] * 3
  49. ARGS = [self.ARGS] * len(COMMAND)
  50. MAXPROCS = [self.MAXPROCS] * len(COMMAND)
  51. INFO = [self.INFO] * len(COMMAND)
  52. child = self.COMM.Spawn_multiple(
  53. COMMAND, ARGS, MAXPROCS,
  54. info=INFO, root=self.ROOT)
  55. local_size = child.Get_size()
  56. remote_size = child.Get_remote_size()
  57. child.Barrier()
  58. child.Disconnect()
  59. self.assertEqual(local_size, self.COMM.Get_size())
  60. self.assertEqual(remote_size, sum(MAXPROCS))
  61. def testReturnedErrcodesMultiple(self):
  62. COMMAND = [self.COMMAND]*3
  63. ARGS = [self.ARGS]*len(COMMAND)
  64. MAXPROCS = range(1, len(COMMAND)+1)
  65. INFO = MPI.INFO_NULL
  66. errcodelist = []
  67. child = self.COMM.Spawn_multiple(
  68. COMMAND, ARGS, MAXPROCS,
  69. info=INFO, root=self.ROOT,
  70. errcodes=errcodelist)
  71. child.Barrier()
  72. child.Disconnect()
  73. rank = self.COMM.Get_rank()
  74. self.assertEqual(len(errcodelist), len(COMMAND))
  75. for i, errcodes in enumerate(errcodelist):
  76. self.assertEqual(len(errcodes), MAXPROCS[i])
  77. for errcode in errcodes:
  78. self.assertEqual(errcode, MPI.SUCCESS)
  79. def testArgsOnlyAtRootMultiple(self):
  80. self.COMM.Barrier()
  81. rank = self.COMM.Get_rank()
  82. if rank == self.ROOT:
  83. COMMAND = [self.COMMAND] * 3
  84. ARGS = [self.ARGS] * len(COMMAND)
  85. MAXPROCS = range(2, len(COMMAND)+2)
  86. INFO = [MPI.INFO_NULL] * len(COMMAND)
  87. child = self.COMM.Spawn_multiple(
  88. COMMAND, ARGS, MAXPROCS,
  89. info=INFO, root=self.ROOT)
  90. else:
  91. child = self.COMM.Spawn_multiple(
  92. None, None, -1,
  93. info=None, root=self.ROOT)
  94. child.Barrier()
  95. child.Disconnect()
  96. self.COMM.Barrier()
  97. class TestSpawnSelf(BaseTestSpawn, unittest.TestCase):
  98. COMM = MPI.COMM_SELF
  99. class TestSpawnWorld(BaseTestSpawn, unittest.TestCase):
  100. COMM = MPI.COMM_WORLD
  101. class TestSpawnSelfMany(BaseTestSpawn, unittest.TestCase):
  102. COMM = MPI.COMM_SELF
  103. MAXPROCS = MPI.COMM_WORLD.Get_size()
  104. class TestSpawnWorldMany(BaseTestSpawn, unittest.TestCase):
  105. COMM = MPI.COMM_WORLD
  106. MAXPROCS = MPI.COMM_WORLD.Get_size()
  107. _SKIP_TEST = False
  108. _name, _version = MPI.get_vendor()
  109. if _name == 'Open MPI':
  110. if _version < (1, 5, 0):
  111. _SKIP_TEST = True
  112. elif _version < (1, 4, 0):
  113. _SKIP_TEST = True
  114. if 'win' in sys.platform:
  115. _SKIP_TEST = True
  116. elif _name == 'MPICH2':
  117. if _version < (1, 0, 6):
  118. _SKIP_TEST = True
  119. if 'win' in sys.platform:
  120. _SKIP_TEST = True
  121. elif _name == 'Microsoft MPI':
  122. _SKIP_TEST = True
  123. elif _name == 'HP MPI':
  124. _SKIP_TEST = True
  125. elif MPI.Get_version() < (2, 0):
  126. _SKIP_TEST = True
  127. if _SKIP_TEST:
  128. del BaseTestSpawn
  129. del TestSpawnSelf
  130. del TestSpawnWorld
  131. del TestSpawnSelfMany
  132. del TestSpawnWorldMany
  133. elif _name == 'MPICH2':
  134. if _version > (1, 2):
  135. # Up to mpich2-1.3.1 when running under Hydra process manager,
  136. # spawn fails for the singleton init case
  137. if MPI.COMM_WORLD.Get_attr(MPI.APPNUM) is None:
  138. del TestSpawnSelf
  139. del TestSpawnWorld
  140. del TestSpawnSelfMany
  141. del TestSpawnWorldMany
  142. if __name__ == '__main__':
  143. unittest.main()