/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py

https://github.com/PaddlePaddle/Paddle · Python · 192 lines · 140 code · 39 blank · 13 comment · 6 complexity · 8c7a1e9b47a892e231bade6d42f7a618 MD5 · raw file

  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
  15. import os
  16. import unittest
  17. import tempfile
  18. from test_dist_fleet_base import TestFleetBase
  19. class TestDistMnistSync2x2(TestFleetBase):
  20. def _setup_config(self):
  21. self._mode = "sync"
  22. self._reader = "pyreader"
  23. def check_with_place(self,
  24. model_file,
  25. delta=1e-3,
  26. check_error_log=False,
  27. need_envs={}):
  28. required_envs = {
  29. "PATH": os.getenv("PATH", ""),
  30. "PYTHONPATH": os.getenv("PYTHONPATH", ""),
  31. "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
  32. "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
  33. "http_proxy": "",
  34. "CPU_NUM": "2"
  35. }
  36. required_envs.update(need_envs)
  37. if check_error_log:
  38. required_envs["GLOG_v"] = "3"
  39. required_envs["GLOG_logtostderr"] = "1"
  40. tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
  41. def test_dist_train(self):
  42. self.check_with_place(
  43. "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
  44. class TestDistMnistAuto2x2(TestFleetBase):
  45. def _setup_config(self):
  46. self._mode = "auto"
  47. self._reader = "pyreader"
  48. def check_with_place(self,
  49. model_file,
  50. delta=1e-3,
  51. check_error_log=False,
  52. need_envs={}):
  53. required_envs = {
  54. "PATH": os.getenv("PATH", ""),
  55. "PYTHONPATH": os.getenv("PYTHONPATH", ""),
  56. "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
  57. "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
  58. "http_proxy": "",
  59. "CPU_NUM": "2"
  60. }
  61. required_envs.update(need_envs)
  62. if check_error_log:
  63. required_envs["GLOG_v"] = "3"
  64. required_envs["GLOG_logtostderr"] = "1"
  65. tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
  66. def test_dist_train(self):
  67. self.check_with_place(
  68. "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
  69. class TestDistMnistAsync2x2(TestFleetBase):
  70. def _setup_config(self):
  71. self._mode = "async"
  72. self._reader = "pyreader"
  73. def check_with_place(self,
  74. model_file,
  75. delta=1e-3,
  76. check_error_log=False,
  77. need_envs={}):
  78. required_envs = {
  79. "PATH": os.getenv("PATH", ""),
  80. "PYTHONPATH": os.getenv("PYTHONPATH", ""),
  81. "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
  82. "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
  83. "http_proxy": "",
  84. "CPU_NUM": "2"
  85. }
  86. required_envs.update(need_envs)
  87. if check_error_log:
  88. required_envs["GLOG_v"] = "3"
  89. required_envs["GLOG_logtostderr"] = "1"
  90. tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
  91. def test_dist_train(self):
  92. self.check_with_place(
  93. "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
  94. @unittest.skip(reason="Skip unstable ut, reader need to be rewrite")
  95. class TestDistMnistAsyncDataset2x2(TestFleetBase):
  96. def _setup_config(self):
  97. self._mode = "async"
  98. self._reader = "dataset"
  99. def check_with_place(self,
  100. model_file,
  101. delta=1e-3,
  102. check_error_log=False,
  103. need_envs={}):
  104. required_envs = {
  105. "PATH": os.getenv("PATH", ""),
  106. "PYTHONPATH": os.getenv("PYTHONPATH", ""),
  107. "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
  108. "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
  109. "http_proxy": "",
  110. "SAVE_MODEL": "1",
  111. "dump_param": "concat_0.tmp_0",
  112. "dump_fields": "dnn-fc-3.tmp_0,dnn-fc-3.tmp_0@GRAD",
  113. "dump_fields_path": tempfile.mkdtemp(),
  114. "Debug": "1"
  115. }
  116. required_envs.update(need_envs)
  117. if check_error_log:
  118. required_envs["GLOG_v"] = "3"
  119. required_envs["GLOG_logtostderr"] = "1"
  120. tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
  121. def test_dist_train(self):
  122. self.check_with_place(
  123. "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
  124. class TestDistCtrHalfAsync2x2(TestFleetBase):
  125. def _setup_config(self):
  126. self._mode = "async"
  127. self._reader = "pyreader"
  128. def check_with_place(self,
  129. model_file,
  130. delta=1e-3,
  131. check_error_log=False,
  132. need_envs={}):
  133. required_envs = {
  134. "PATH": os.getenv("PATH", ""),
  135. "PYTHONPATH": os.getenv("PYTHONPATH", ""),
  136. "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
  137. "FLAGS_rpc_deadline": "30000", # 5sec to fail fast
  138. "http_proxy": "",
  139. "FLAGS_communicator_send_queue_size": "2",
  140. "FLAGS_communicator_max_merge_var_num": "2",
  141. "CPU_NUM": "2",
  142. "SAVE_MODEL": "0"
  143. }
  144. required_envs.update(need_envs)
  145. if check_error_log:
  146. required_envs["GLOG_v"] = "3"
  147. required_envs["GLOG_logtostderr"] = "1"
  148. tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
  149. def test_dist_train(self):
  150. self.check_with_place(
  151. "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
  152. if __name__ == "__main__":
  153. unittest.main()