/test/test_spawn.py
https://code.google.com/p/mpi4py/ · Python · 160 lines · 141 code · 17 blank · 2 comment · 22 complexity · d149456a022ac81e4d418b630d44bcdf MD5 · raw file
- import sys, os, mpi4py
- from mpi4py import MPI
- import mpiunittest as unittest
- MPI4PYPATH = os.path.abspath(os.path.dirname(mpi4py.__path__[0]))
- CHILDSCRIPT = os.path.abspath(
- os.path.join(os.path.dirname(__file__), 'spawn_child.py')
- )
- class BaseTestSpawn(object):
- COMM = MPI.COMM_NULL
- COMMAND = sys.executable
- ARGS = [CHILDSCRIPT, MPI4PYPATH]
- MAXPROCS = 1
- INFO = MPI.INFO_NULL
- ROOT = 0
- def testCommSpawn(self):
- child = self.COMM.Spawn(self.COMMAND, self.ARGS, self.MAXPROCS,
- info=self.INFO, root=self.ROOT)
- local_size = child.Get_size()
- remote_size = child.Get_remote_size()
- child.Barrier()
- child.Disconnect()
- self.assertEqual(local_size, self.COMM.Get_size())
- self.assertEqual(remote_size, self.MAXPROCS)
- def testReturnedErrcodes(self):
- errcodes = []
- child = self.COMM.Spawn(self.COMMAND, self.ARGS, self.MAXPROCS,
- info=self.INFO, root=self.ROOT,
- errcodes=errcodes)
- child.Barrier()
- child.Disconnect()
- rank = self.COMM.Get_rank()
- self.assertEqual(len(errcodes), self.MAXPROCS)
- for errcode in errcodes:
- self.assertEqual(errcode, MPI.SUCCESS)
- def testArgsOnlyAtRoot(self):
- self.COMM.Barrier()
- rank = self.COMM.Get_rank()
- if rank == self.ROOT:
- child = self.COMM.Spawn(self.COMMAND, self.ARGS, self.MAXPROCS,
- info=self.INFO, root=self.ROOT)
- else:
- child = self.COMM.Spawn(None, None, -1,
- info=None, root=self.ROOT)
- child.Barrier()
- child.Disconnect()
- self.COMM.Barrier()
- def testCommSpawnMultiple(self):
- COMMAND = [self.COMMAND] * 3
- ARGS = [self.ARGS] * len(COMMAND)
- MAXPROCS = [self.MAXPROCS] * len(COMMAND)
- INFO = [self.INFO] * len(COMMAND)
- child = self.COMM.Spawn_multiple(
- COMMAND, ARGS, MAXPROCS,
- info=INFO, root=self.ROOT)
- local_size = child.Get_size()
- remote_size = child.Get_remote_size()
- child.Barrier()
- child.Disconnect()
- self.assertEqual(local_size, self.COMM.Get_size())
- self.assertEqual(remote_size, sum(MAXPROCS))
- def testReturnedErrcodesMultiple(self):
- COMMAND = [self.COMMAND]*3
- ARGS = [self.ARGS]*len(COMMAND)
- MAXPROCS = range(1, len(COMMAND)+1)
- INFO = MPI.INFO_NULL
- errcodelist = []
- child = self.COMM.Spawn_multiple(
- COMMAND, ARGS, MAXPROCS,
- info=INFO, root=self.ROOT,
- errcodes=errcodelist)
- child.Barrier()
- child.Disconnect()
- rank = self.COMM.Get_rank()
- self.assertEqual(len(errcodelist), len(COMMAND))
- for i, errcodes in enumerate(errcodelist):
- self.assertEqual(len(errcodes), MAXPROCS[i])
- for errcode in errcodes:
- self.assertEqual(errcode, MPI.SUCCESS)
- def testArgsOnlyAtRootMultiple(self):
- self.COMM.Barrier()
- rank = self.COMM.Get_rank()
- if rank == self.ROOT:
- COMMAND = [self.COMMAND] * 3
- ARGS = [self.ARGS] * len(COMMAND)
- MAXPROCS = range(2, len(COMMAND)+2)
- INFO = [MPI.INFO_NULL] * len(COMMAND)
- child = self.COMM.Spawn_multiple(
- COMMAND, ARGS, MAXPROCS,
- info=INFO, root=self.ROOT)
- else:
- child = self.COMM.Spawn_multiple(
- None, None, -1,
- info=None, root=self.ROOT)
- child.Barrier()
- child.Disconnect()
- self.COMM.Barrier()
- class TestSpawnSelf(BaseTestSpawn, unittest.TestCase):
- COMM = MPI.COMM_SELF
- class TestSpawnWorld(BaseTestSpawn, unittest.TestCase):
- COMM = MPI.COMM_WORLD
- class TestSpawnSelfMany(BaseTestSpawn, unittest.TestCase):
- COMM = MPI.COMM_SELF
- MAXPROCS = MPI.COMM_WORLD.Get_size()
- class TestSpawnWorldMany(BaseTestSpawn, unittest.TestCase):
- COMM = MPI.COMM_WORLD
- MAXPROCS = MPI.COMM_WORLD.Get_size()
- _SKIP_TEST = False
- _name, _version = MPI.get_vendor()
- if _name == 'Open MPI':
- if _version < (1, 5, 0):
- _SKIP_TEST = True
- elif _version < (1, 4, 0):
- _SKIP_TEST = True
- if 'win' in sys.platform:
- _SKIP_TEST = True
- elif _name == 'MPICH2':
- if _version < (1, 0, 6):
- _SKIP_TEST = True
- if 'win' in sys.platform:
- _SKIP_TEST = True
- elif _name == 'Microsoft MPI':
- _SKIP_TEST = True
- elif _name == 'HP MPI':
- _SKIP_TEST = True
- elif MPI.Get_version() < (2, 0):
- _SKIP_TEST = True
- if _SKIP_TEST:
- del BaseTestSpawn
- del TestSpawnSelf
- del TestSpawnWorld
- del TestSpawnSelfMany
- del TestSpawnWorldMany
- elif _name == 'MPICH2':
- if _version > (1, 2):
- # Up to mpich2-1.3.1 when running under Hydra process manager,
- # spawn fails for the singleton init case
- if MPI.COMM_WORLD.Get_attr(MPI.APPNUM) is None:
- del TestSpawnSelf
- del TestSpawnWorld
- del TestSpawnSelfMany
- del TestSpawnWorldMany
- if __name__ == '__main__':
- unittest.main()