PageRenderTime 364ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 0ms

/test/test_subclass.py

https://code.google.com/p/mpi4py/
Python | 307 lines | 220 code | 78 blank | 9 comment | 23 complexity | 6fb919c3b1fa73a7c5392a298e5902fc MD5 | raw file
Possible License(s): BSD-3-Clause
  1. from mpi4py import MPI
  2. import mpiunittest as unittest
  3. # ---
  4. class MyBaseComm(object):
  5. def free(self):
  6. if self != MPI.COMM_NULL:
  7. super(type(self), self).Free()
  8. class BaseTestBaseComm(object):
  9. def setUp(self):
  10. self.comm = self.CommType(self.COMM_BASE)
  11. def testSubType(self):
  12. self.assertTrue(type(self.comm) not in [
  13. MPI.Comm,
  14. MPI.Intracomm,
  15. MPI.Cartcomm,
  16. MPI.Graphcomm,
  17. MPI.Distgraphcomm,
  18. MPI.Intercomm])
  19. self.assertTrue(isinstance(self.comm, self.CommType))
  20. def testCloneFree(self):
  21. if self.COMM_BASE != MPI.COMM_NULL:
  22. comm = self.comm.Clone()
  23. else:
  24. comm = self.CommType()
  25. self.assertTrue(isinstance(comm, MPI.Comm))
  26. self.assertTrue(isinstance(comm, self.CommType))
  27. comm.free()
  28. def tearDown(self):
  29. self.comm.free()
  30. # ---
  31. class MyComm(MPI.Comm, MyBaseComm):
  32. def __new__(cls, comm=None):
  33. if comm is not None:
  34. if comm != MPI.COMM_NULL:
  35. comm = comm.Clone()
  36. return super(MyComm, cls).__new__(cls, comm)
  37. class BaseTestMyComm(BaseTestBaseComm):
  38. CommType = MyComm
  39. class TestMyCommNULL(BaseTestMyComm, unittest.TestCase):
  40. COMM_BASE = MPI.COMM_NULL
  41. class TestMyCommSELF(BaseTestMyComm, unittest.TestCase):
  42. COMM_BASE = MPI.COMM_SELF
  43. class TestMyCommWORLD(BaseTestMyComm, unittest.TestCase):
  44. COMM_BASE = MPI.COMM_WORLD
  45. # ---
  46. class MyIntracomm(MPI.Intracomm, MyBaseComm):
  47. def __new__(cls, comm=None):
  48. if comm is not None:
  49. if comm != MPI.COMM_NULL:
  50. comm = comm.Dup()
  51. return super(MyIntracomm, cls).__new__(cls, comm)
  52. class BaseTestMyIntracomm(BaseTestBaseComm):
  53. CommType = MyIntracomm
  54. class TestMyIntracommNULL(BaseTestMyIntracomm, unittest.TestCase):
  55. COMM_BASE = MPI.COMM_NULL
  56. class TestMyIntracommSELF(BaseTestMyIntracomm, unittest.TestCase):
  57. COMM_BASE = MPI.COMM_SELF
  58. class TestMyIntracommWORLD(BaseTestMyIntracomm, unittest.TestCase):
  59. COMM_BASE = MPI.COMM_WORLD
  60. # ---
  61. class MyCartcomm(MPI.Cartcomm, MyBaseComm):
  62. def __new__(cls, comm=None):
  63. if comm is not None:
  64. if comm != MPI.COMM_NULL:
  65. dims = [comm.size]
  66. comm = comm.Create_cart(dims)
  67. return super(MyCartcomm, cls).__new__(cls, comm)
  68. class BaseTestMyCartcomm(BaseTestBaseComm):
  69. CommType = MyCartcomm
  70. class TestMyCartcommNULL(BaseTestMyCartcomm, unittest.TestCase):
  71. COMM_BASE = MPI.COMM_NULL
  72. class TestMyCartcommSELF(BaseTestMyCartcomm, unittest.TestCase):
  73. COMM_BASE = MPI.COMM_SELF
  74. class TestMyCartcommWORLD(BaseTestMyCartcomm, unittest.TestCase):
  75. COMM_BASE = MPI.COMM_WORLD
  76. # ---
  77. class MyGraphcomm(MPI.Graphcomm, MyBaseComm):
  78. def __new__(cls, comm=None):
  79. if comm is not None:
  80. if comm != MPI.COMM_NULL:
  81. index = list(range(0, comm.size+1))
  82. edges = list(range(0, comm.size))
  83. comm = comm.Create_graph(index, edges)
  84. return super(MyGraphcomm, cls).__new__(cls, comm)
  85. class BaseTestMyGraphcomm(BaseTestBaseComm):
  86. CommType = MyGraphcomm
  87. class TestMyGraphcommNULL(BaseTestMyGraphcomm, unittest.TestCase):
  88. COMM_BASE = MPI.COMM_NULL
  89. class TestMyGraphcommSELF(BaseTestMyGraphcomm, unittest.TestCase):
  90. COMM_BASE = MPI.COMM_SELF
  91. class TestMyGraphcommWORLD(BaseTestMyGraphcomm, unittest.TestCase):
  92. COMM_BASE = MPI.COMM_WORLD
  93. # ---
  94. class MyRequest(MPI.Request):
  95. def __new__(cls, request=None):
  96. return super(MyRequest, cls).__new__(cls, request)
  97. def test(self):
  98. return super(type(self), self).Test()
  99. def wait(self):
  100. return super(type(self), self).Wait()
  101. class MyPrequest(MPI.Prequest):
  102. def __new__(cls, request=None):
  103. return super(MyPrequest, cls).__new__(cls, request)
  104. def test(self):
  105. return super(type(self), self).Test()
  106. def wait(self):
  107. return super(type(self), self).Wait()
  108. def start(self):
  109. return super(type(self), self).Start()
  110. class MyGrequest(MPI.Grequest):
  111. def __new__(cls, request=None):
  112. return super(MyGrequest, cls).__new__(cls, request)
  113. def test(self):
  114. return super(type(self), self).Test()
  115. def wait(self):
  116. return super(type(self), self).Wait()
  117. class BaseTestMyRequest(object):
  118. def setUp(self):
  119. self.req = self.MyRequestType(MPI.REQUEST_NULL)
  120. def testSubType(self):
  121. self.assertTrue(type(self.req) is not self.MPIRequestType)
  122. self.assertTrue(isinstance(self.req, self.MPIRequestType))
  123. self.assertTrue(isinstance(self.req, self.MyRequestType))
  124. self.req.test()
  125. class TestMyRequest(BaseTestMyRequest, unittest.TestCase):
  126. MPIRequestType = MPI.Request
  127. MyRequestType = MyRequest
  128. class TestMyPrequest(BaseTestMyRequest, unittest.TestCase):
  129. MPIRequestType = MPI.Prequest
  130. MyRequestType = MyPrequest
  131. class TestMyGrequest(BaseTestMyRequest, unittest.TestCase):
  132. MPIRequestType = MPI.Grequest
  133. MyRequestType = MyGrequest
  134. class TestMyRequest2(TestMyRequest):
  135. def setUp(self):
  136. req = MPI.COMM_SELF.Isend(
  137. [MPI.BOTTOM, 0, MPI.BYTE],
  138. dest=MPI.PROC_NULL, tag=0)
  139. self.req = MyRequest(req)
  140. class TestMyPrequest2(TestMyPrequest):
  141. def setUp(self):
  142. req = MPI.COMM_SELF.Send_init(
  143. [MPI.BOTTOM, 0, MPI.BYTE],
  144. dest=MPI.PROC_NULL, tag=0)
  145. self.req = MyPrequest(req)
  146. def tearDown(self):
  147. self.req.Free()
  148. def testStart(self):
  149. for i in range(5):
  150. self.req.start()
  151. self.req.test()
  152. self.req.start()
  153. self.req.wait()
  154. # ---
  155. class MyWin(MPI.Win):
  156. def __new__(cls, win=None):
  157. return MPI.Win.__new__(cls, win)
  158. def free(self):
  159. MPI.Win.Free(self)
  160. class BaseTestMyWin(object):
  161. def setUp(self):
  162. w = MPI.Win.Create(MPI.BOTTOM)
  163. self.win = MyWin(w)
  164. def tearDown(self):
  165. if self.win:
  166. self.win.Free()
  167. def testSubType(self):
  168. self.assertTrue(type(self.win) is not MPI.Win)
  169. self.assertTrue(isinstance(self.win, MPI.Win))
  170. self.assertTrue(isinstance(self.win, MyWin))
  171. def testFree(self):
  172. self.assertTrue(self.win)
  173. self.win.free()
  174. self.assertFalse(self.win)
  175. class TestMyWin(BaseTestMyWin, unittest.TestCase):
  176. pass
  177. try:
  178. w = MPI.Win.Create(MPI.BOTTOM).Free()
  179. except NotImplementedError:
  180. del TestMyWin
  181. # ---
  182. import os, tempfile
  183. class MyFile(MPI.File):
  184. def __new__(cls, file=None):
  185. return MPI.File.__new__(cls, file)
  186. def close(self):
  187. MPI.File.Close(self)
  188. class BaseTestMyFile(object):
  189. def setUp(self):
  190. fd, fname = tempfile.mkstemp(prefix='mpi4py')
  191. os.close(fd)
  192. amode = MPI.MODE_RDWR | MPI.MODE_CREATE | MPI.MODE_DELETE_ON_CLOSE
  193. try:
  194. f = MPI.File.Open(MPI.COMM_SELF,
  195. fname, amode,
  196. MPI.INFO_NULL)
  197. self.file = MyFile(f)
  198. except Exception:
  199. os.remove(fname)
  200. raise
  201. def tearDown(self):
  202. if self.file:
  203. self.file.Close()
  204. def testSubType(self):
  205. self.assertTrue(type(self.file) is not MPI.File)
  206. self.assertTrue(isinstance(self.file, MPI.File))
  207. self.assertTrue(isinstance(self.file, MyFile))
  208. def testFree(self):
  209. self.assertTrue(self.file)
  210. self.file.close()
  211. self.assertFalse(self.file)
  212. class TestMyFile(BaseTestMyFile, unittest.TestCase):
  213. pass
  214. import sys
  215. _name, _version = MPI.get_vendor()
  216. if _name == 'Open MPI':
  217. if 'win' in sys.platform:
  218. def _dummy(*args): raise NotImplementedError
  219. BaseTestMyFile.setUp = _dummy
  220. try:
  221. dummy = BaseTestMyFile()
  222. dummy.setUp()
  223. dummy.tearDown()
  224. del dummy
  225. except NotImplementedError:
  226. del TestMyFile
  227. # ---
  228. if __name__ == '__main__':
  229. unittest.main()