PageRenderTime 61ms CodeModel.GetById 25ms RepoModel.GetById 0ms app.codeStats 0ms

/test/test_subclass.py

https://bitbucket.org/svenx/mpi4py
Python | 303 lines | 215 code | 80 blank | 8 comment | 22 complexity | 00a44b48e47e9a8eab919c8b432f5e26 MD5 | raw file
  1. from mpi4py import MPI
  2. import mpiunittest as unittest
  3. import sys
  4. # ---
  5. class MyBaseComm(object):
  6. def free(self):
  7. if self != MPI.COMM_NULL:
  8. MPI.Comm.Free(self)
  9. class BaseTestBaseComm(object):
  10. def setUp(self):
  11. self.comm = self.CommType(self.COMM_BASE)
  12. def testSubType(self):
  13. self.assertTrue(type(self.comm) not in [
  14. MPI.Comm,
  15. MPI.Intracomm,
  16. MPI.Cartcomm,
  17. MPI.Graphcomm,
  18. MPI.Distgraphcomm,
  19. MPI.Intercomm])
  20. self.assertTrue(isinstance(self.comm, self.CommType))
  21. def testCloneFree(self):
  22. if self.COMM_BASE != MPI.COMM_NULL:
  23. comm = self.comm.Clone()
  24. else:
  25. comm = self.CommType()
  26. self.assertTrue(isinstance(comm, MPI.Comm))
  27. self.assertTrue(isinstance(comm, self.CommType))
  28. comm.free()
  29. def tearDown(self):
  30. self.comm.free()
  31. # ---
  32. class MyComm(MPI.Comm, MyBaseComm):
  33. def __new__(cls, comm=None):
  34. if comm is not None:
  35. if comm != MPI.COMM_NULL:
  36. comm = comm.Clone()
  37. return super(MyComm, cls).__new__(cls, comm)
  38. class BaseTestMyComm(BaseTestBaseComm):
  39. CommType = MyComm
  40. class TestMyCommNULL(BaseTestMyComm, unittest.TestCase):
  41. COMM_BASE = MPI.COMM_NULL
  42. class TestMyCommSELF(BaseTestMyComm, unittest.TestCase):
  43. COMM_BASE = MPI.COMM_SELF
  44. class TestMyCommWORLD(BaseTestMyComm, unittest.TestCase):
  45. COMM_BASE = MPI.COMM_WORLD
  46. # ---
  47. class MyIntracomm(MPI.Intracomm, MyBaseComm):
  48. def __new__(cls, comm=None):
  49. if comm is not None:
  50. if comm != MPI.COMM_NULL:
  51. comm = comm.Dup()
  52. return super(MyIntracomm, cls).__new__(cls, comm)
  53. class BaseTestMyIntracomm(BaseTestBaseComm):
  54. CommType = MyIntracomm
  55. class TestMyIntracommNULL(BaseTestMyIntracomm, unittest.TestCase):
  56. COMM_BASE = MPI.COMM_NULL
  57. class TestMyIntracommSELF(BaseTestMyIntracomm, unittest.TestCase):
  58. COMM_BASE = MPI.COMM_SELF
  59. class TestMyIntracommWORLD(BaseTestMyIntracomm, unittest.TestCase):
  60. COMM_BASE = MPI.COMM_WORLD
  61. # ---
  62. class MyCartcomm(MPI.Cartcomm, MyBaseComm):
  63. def __new__(cls, comm=None):
  64. if comm is not None:
  65. if comm != MPI.COMM_NULL:
  66. dims = [comm.size]
  67. comm = comm.Create_cart(dims)
  68. return super(MyCartcomm, cls).__new__(cls, comm)
  69. class BaseTestMyCartcomm(BaseTestBaseComm):
  70. CommType = MyCartcomm
  71. class TestMyCartcommNULL(BaseTestMyCartcomm, unittest.TestCase):
  72. COMM_BASE = MPI.COMM_NULL
  73. class TestMyCartcommSELF(BaseTestMyCartcomm, unittest.TestCase):
  74. COMM_BASE = MPI.COMM_SELF
  75. class TestMyCartcommWORLD(BaseTestMyCartcomm, unittest.TestCase):
  76. COMM_BASE = MPI.COMM_WORLD
  77. # ---
  78. class MyGraphcomm(MPI.Graphcomm, MyBaseComm):
  79. def __new__(cls, comm=None):
  80. if comm is not None:
  81. if comm != MPI.COMM_NULL:
  82. index = list(range(0, comm.size+1))
  83. edges = list(range(0, comm.size))
  84. comm = comm.Create_graph(index, edges)
  85. return super(MyGraphcomm, cls).__new__(cls, comm)
  86. class BaseTestMyGraphcomm(BaseTestBaseComm):
  87. CommType = MyGraphcomm
  88. class TestMyGraphcommNULL(BaseTestMyGraphcomm, unittest.TestCase):
  89. COMM_BASE = MPI.COMM_NULL
  90. class TestMyGraphcommSELF(BaseTestMyGraphcomm, unittest.TestCase):
  91. COMM_BASE = MPI.COMM_SELF
  92. class TestMyGraphcommWORLD(BaseTestMyGraphcomm, unittest.TestCase):
  93. COMM_BASE = MPI.COMM_WORLD
  94. # ---
  95. class MyRequest(MPI.Request):
  96. def __new__(cls, request=None):
  97. return super(MyRequest, cls).__new__(cls, request)
  98. def test(self):
  99. return super(type(self), self).Test()
  100. def wait(self):
  101. return super(type(self), self).Wait()
  102. class MyPrequest(MPI.Prequest):
  103. def __new__(cls, request=None):
  104. return super(MyPrequest, cls).__new__(cls, request)
  105. def test(self):
  106. return super(type(self), self).Test()
  107. def wait(self):
  108. return super(type(self), self).Wait()
  109. def start(self):
  110. return super(type(self), self).Start()
  111. class MyGrequest(MPI.Grequest):
  112. def __new__(cls, request=None):
  113. return super(MyGrequest, cls).__new__(cls, request)
  114. def test(self):
  115. return super(type(self), self).Test()
  116. def wait(self):
  117. return super(type(self), self).Wait()
  118. class BaseTestMyRequest(object):
  119. def setUp(self):
  120. self.req = self.MyRequestType(MPI.REQUEST_NULL)
  121. def testSubType(self):
  122. self.assertTrue(type(self.req) is not self.MPIRequestType)
  123. self.assertTrue(isinstance(self.req, self.MPIRequestType))
  124. self.assertTrue(isinstance(self.req, self.MyRequestType))
  125. self.req.test()
  126. class TestMyRequest(BaseTestMyRequest, unittest.TestCase):
  127. MPIRequestType = MPI.Request
  128. MyRequestType = MyRequest
  129. class TestMyPrequest(BaseTestMyRequest, unittest.TestCase):
  130. MPIRequestType = MPI.Prequest
  131. MyRequestType = MyPrequest
  132. class TestMyGrequest(BaseTestMyRequest, unittest.TestCase):
  133. MPIRequestType = MPI.Grequest
  134. MyRequestType = MyGrequest
  135. class TestMyRequest2(TestMyRequest):
  136. def setUp(self):
  137. req = MPI.COMM_SELF.Isend(
  138. [MPI.BOTTOM, 0, MPI.BYTE],
  139. dest=MPI.PROC_NULL, tag=0)
  140. self.req = MyRequest(req)
  141. class TestMyPrequest2(TestMyPrequest):
  142. def setUp(self):
  143. req = MPI.COMM_SELF.Send_init(
  144. [MPI.BOTTOM, 0, MPI.BYTE],
  145. dest=MPI.PROC_NULL, tag=0)
  146. self.req = MyPrequest(req)
  147. def tearDown(self):
  148. self.req.Free()
  149. def testStart(self):
  150. for i in range(5):
  151. self.req.start()
  152. self.req.test()
  153. self.req.start()
  154. self.req.wait()
  155. # ---
  156. class MyWin(MPI.Win):
  157. def __new__(cls, win=None):
  158. return MPI.Win.__new__(cls, win)
  159. def free(self):
  160. if self != MPI.WIN_NULL:
  161. MPI.Win.Free(self)
  162. class BaseTestMyWin(object):
  163. def setUp(self):
  164. w = MPI.Win.Create(MPI.BOTTOM)
  165. self.win = MyWin(w)
  166. def tearDown(self):
  167. self.win.free()
  168. def testSubType(self):
  169. self.assertTrue(type(self.win) is not MPI.Win)
  170. self.assertTrue(isinstance(self.win, MPI.Win))
  171. self.assertTrue(isinstance(self.win, MyWin))
  172. def testFree(self):
  173. self.assertTrue(self.win)
  174. self.win.free()
  175. self.assertFalse(self.win)
  176. class TestMyWin(BaseTestMyWin, unittest.TestCase):
  177. pass
  178. SpectrumMPI = MPI.get_vendor()[0] == 'Spectrum MPI'
  179. try:
  180. if SpectrumMPI: raise NotImplementedError
  181. MPI.Win.Create(MPI.BOTTOM).Free()
  182. except NotImplementedError:
  183. unittest.disable(BaseTestMyWin, 'mpi-win')
  184. # ---
  185. import os, tempfile
  186. class MyFile(MPI.File):
  187. def __new__(cls, file=None):
  188. return MPI.File.__new__(cls, file)
  189. def close(self):
  190. if self != MPI.FILE_NULL:
  191. MPI.File.Close(self)
  192. class BaseTestMyFile(object):
  193. def openfile(self):
  194. fd, fname = tempfile.mkstemp(prefix='mpi4py')
  195. os.close(fd)
  196. amode = MPI.MODE_RDWR | MPI.MODE_CREATE | MPI.MODE_DELETE_ON_CLOSE
  197. try:
  198. self.file = MPI.File.Open(MPI.COMM_SELF, fname, amode, MPI.INFO_NULL)
  199. return self.file
  200. except Exception:
  201. os.remove(fname)
  202. raise
  203. def setUp(self):
  204. f = self.openfile()
  205. self.file = MyFile(f)
  206. def tearDown(self):
  207. self.file.close()
  208. def testSubType(self):
  209. self.assertTrue(type(self.file) is not MPI.File)
  210. self.assertTrue(isinstance(self.file, MPI.File))
  211. self.assertTrue(isinstance(self.file, MyFile))
  212. def testFree(self):
  213. self.assertTrue(self.file)
  214. self.file.close()
  215. self.assertFalse(self.file)
  216. class TestMyFile(BaseTestMyFile, unittest.TestCase):
  217. pass
  218. try:
  219. BaseTestMyFile().openfile().Close()
  220. except NotImplementedError:
  221. unittest.disable(BaseTestMyFile, 'mpi-file')
  222. if __name__ == '__main__':
  223. unittest.main()