/test/data_join/test_raw_data_manifest_manager.py

https://github.com/bytedance/fedlearner · Python · 291 lines · 261 code · 16 blank · 14 comment · 32 complexity · 3c2876c4ec56fb3d9a53d59a5d31717e MD5 · raw file

  1. # Copyright 2020 The FedLearner Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # coding: utf-8
  15. import os
  16. import unittest
  17. from tensorflow.compat.v1 import gfile
  18. from google.protobuf import text_format, timestamp_pb2
  19. from fedlearner.common import db_client
  20. from fedlearner.common import data_join_service_pb2 as dj_pb
  21. from fedlearner.common import common_pb2 as common_pb
  22. from fedlearner.data_join import raw_data_manifest_manager, common
  class TestRawDataManifestManager(unittest.TestCase):
      """Scenario test for ``RawDataManifestManager``.

      The same scenario (``_raw_data_manifest_manager``) is run twice, once
      with ``DBClient('etcd', True)`` and once with ``DBClient('nfs', True)``.
      """

      @staticmethod
      def _raw_data_metas(*specs):
          """Build a list of ``RawDataMeta`` from ``(file_path, seconds)`` pairs."""
          return [
              dj_pb.RawDataMeta(
                  file_path=file_path,
                  timestamp=timestamp_pb2.Timestamp(seconds=seconds))
              for file_path, seconds in specs
          ]

      def _check_manifests(self, manifest_manager, partition_num,
                           sync=None, join=None, finished=None):
          """Assert the per-partition state returned by ``list_all_manifest``.

          Args:
              manifest_manager: the manager whose manifests are inspected.
              partition_num: number of partitions; every id in
                  ``range(partition_num)`` must be present in the map.
              sync: dict ``partition id -> (state, rank_id)`` expected in
                  ``sync_example_id_rep``. Partitions not listed must be
                  ``(UnSynced, -1)``.
              join: dict ``partition id -> (state, rank_id)`` expected in
                  ``join_example_rep``. Partitions not listed must be
                  ``(UnJoined, -1)``.
              finished: dict ``partition id -> bool`` expected value of
                  ``finished``. Partitions not listed are not checked.
          """
          sync = sync or {}
          join = join or {}
          finished = finished or {}
          default_sync = (dj_pb.SyncExampleIdState.UnSynced, -1)
          default_join = (dj_pb.JoinExampleState.UnJoined, -1)

          manifest_map = manifest_manager.list_all_manifest()
          for i in range(partition_num):
              self.assertIn(i, manifest_map)
              manifest = manifest_map[i]

              sync_state, sync_rank = sync.get(i, default_sync)
              self.assertEqual(manifest.sync_example_id_rep.state, sync_state)
              self.assertEqual(manifest.sync_example_id_rep.rank_id, sync_rank)

              join_state, join_rank = join.get(i, default_join)
              self.assertEqual(manifest.join_example_rep.state, join_state)
              self.assertEqual(manifest.join_example_rep.rank_id, join_rank)

              if i in finished:
                  if finished[i]:
                      self.assertTrue(manifest.finished)
                  else:
                      self.assertFalse(manifest.finished)

      def _raw_data_manifest_manager(self, cli):
          """Run the manifest scenario against ``cli`` and release its pool.

          ``cli.destroy_client_pool()`` is called in a ``finally`` clause so
          the client is released even when an assertion in the scenario fails.

          Args:
              cli: a ``db_client.DBClient`` instance used as the manager's
                  backing store.
          """
          try:
              self._run_manifest_scenario(cli)
          finally:
              cli.destroy_client_pool()

      def _run_manifest_scenario(self, cli):
          """Drive one partition through sync/join allocation and completion.

          The scenario allocates example-id syncing for one partition, then
          example joining for two partitions, adds raw data, and finishes raw
          data, syncing and joining, checking every partition's manifest after
          each step.
          """
          partition_num = 4
          rank_id = 2
          data_source = common_pb.DataSource()
          data_source.data_source_meta.name = "milestone-x"
          data_source.data_source_meta.partition_num = partition_num
          data_source.role = common_pb.FLRole.Leader
          # Remove anything stored under this data source's base dir before
          # the manager is created.
          cli.delete_prefix(
              common.data_source_kvstore_base_dir(
                  data_source.data_source_meta.name))
          manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
              cli, data_source)
          # Until finish_raw_data is called below, no partition is finished.
          unfinished = dict.fromkeys(range(partition_num), False)

          # Step 0: nothing allocated yet.
          self._check_manifests(manifest_manager, partition_num,
                                finished=unfinished)

          # Step 1: allocate example-id syncing; the manager picks the partition.
          manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
          self.assertIsNotNone(manifest)
          partition_id = manifest.partition_id
          sync = {partition_id: (dj_pb.SyncExampleIdState.Syncing, rank_id)}
          self._check_manifests(manifest_manager, partition_num,
                                sync=sync, finished=unfinished)

          # Step 2: allocate joining on another partition. With ids 0..3,
          # ``3 - partition_id`` can never equal ``partition_id``.
          partition_id2 = 3 - partition_id
          rank_id2 = 100
          manifest_manager.alloc_join_example(rank_id2, partition_id2)
          join = {partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2)}
          self._check_manifests(manifest_manager, partition_num,
                                sync=sync, join=join, finished=unfinished)

          # Each of these finish calls is expected to raise at this point.
          # The test pins only that they raise; the reasons live in the
          # manager and are not visible from here.
          self.assertRaises(Exception, manifest_manager.finish_join_example,
                            rank_id, partition_id)
          self.assertRaises(Exception, manifest_manager.finish_join_example,
                            rank_id2, partition_id2)
          self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                            -rank_id, partition_id)
          self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                            rank_id2, partition_id2)

          # Step 3: allocate joining on the syncing partition as well.
          rank_id3 = 0
          manifest_manager.alloc_join_example(rank_id3, partition_id)
          # partition_id is written last so it wins if the two ids ever
          # coincided, matching the original if/elif priority.
          join = {
              partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
              partition_id: (dj_pb.JoinExampleState.Joining, rank_id3),
          }
          self._check_manifests(manifest_manager, partition_num,
                                sync=sync, join=join, finished=unfinished)
          # Syncing cannot be finished yet, even with the owning rank.
          self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                            rank_id, partition_id)

          # Step 4: add raw data. The batch repeats file 'a'; it must be
          # rejected when the third argument is False and accepted when True
          # (presumably a de-duplication flag -- confirm in the manager).
          raw_data_metas = self._raw_data_metas(('a', 3), ('a', 3), ('c', 1))
          self.assertRaises(Exception, manifest_manager.add_raw_data,
                            partition_id, raw_data_metas, False)
          manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
          latest_ts = manifest_manager.get_raw_date_latest_timestamp(
              partition_id)
          self.assertEqual(latest_ts.seconds, 3)
          self.assertEqual(latest_ts.nanos, 0)
          manifest = manifest_manager.get_manifest(partition_id)
          self.assertEqual(manifest.next_process_index, 2)

          # A second batch overlapping the first ('a', 'c') plus new files.
          raw_data_metas = self._raw_data_metas(
              ('a', 3), ('a', 3), ('b', 2), ('c', 1), ('d', 4))
          manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
          latest_ts = manifest_manager.get_raw_date_latest_timestamp(
              partition_id)
          self.assertEqual(latest_ts.seconds, 4)
          self.assertEqual(latest_ts.nanos, 0)
          manifest_map = manifest_manager.list_all_manifest()
          for i in range(partition_num):
              self.assertIn(i, manifest_map)
              expected_index = 4 if i == partition_id else 0
              self.assertEqual(manifest_map[i].next_process_index,
                               expected_index)

          # Step 5: finish raw data, then syncing. Each finish call is issued
          # twice; the second call must not raise.
          manifest_manager.finish_raw_data(partition_id)
          manifest_manager.finish_raw_data(partition_id)
          # NOTE(review): unlike the add_raw_data calls above, this one passes
          # only two arguments and an int instead of a list of RawDataMeta, so
          # the assertion may be satisfied by an argument error rather than by
          # a "raw data already finished" check -- confirm against the
          # add_raw_data signature.
          self.assertRaises(Exception, manifest_manager.add_raw_data,
                            partition_id, 200)
          manifest_manager.finish_sync_example_id(rank_id, partition_id)
          manifest_manager.finish_sync_example_id(rank_id, partition_id)
          sync = {partition_id: (dj_pb.SyncExampleIdState.Synced, rank_id)}
          self._check_manifests(manifest_manager, partition_num,
                                sync=sync, join=join,
                                finished={partition_id: True})

          # Step 6: finish joining on the synced partition (again twice).
          manifest_manager.finish_join_example(rank_id3, partition_id)
          manifest_manager.finish_join_example(rank_id3, partition_id)
          join = {
              partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
              partition_id: (dj_pb.JoinExampleState.Joined, rank_id3),
          }
          self._check_manifests(manifest_manager, partition_num,
                                sync=sync, join=join)

      def test_raw_data_manifest_manager_with_db(self):
          """Run the scenario with ``DBClient('etcd', True)``."""
          cli = db_client.DBClient('etcd', True)
          self._raw_data_manifest_manager(cli)

      def test_raw_data_manifest_manager_with_nfs(self):
          """Run the scenario with ``DBClient('nfs', True)``.

          ``STORAGE_ROOT_PATH`` is pointed at a scratch directory for the
          duration of the test. The directory is removed and the variable is
          restored afterwards, whether or not the scenario passes.
          """
          root_dir = "test_fedlearner"
          previous_root = os.environ.get("STORAGE_ROOT_PATH")
          os.environ["STORAGE_ROOT_PATH"] = root_dir
          try:
              cli = db_client.DBClient('nfs', True)
              self._raw_data_manifest_manager(cli)
          finally:
              try:
                  if gfile.Exists(root_dir):
                      gfile.DeleteRecursively(root_dir)
              finally:
                  # Put the environment back exactly as it was found.
                  if previous_root is None:
                      os.environ.pop("STORAGE_ROOT_PATH", None)
                  else:
                      os.environ["STORAGE_ROOT_PATH"] = previous_root
  if __name__ == '__main__':
      # Executed directly as a script: let unittest discover and run the
      # test cases defined in this module.
      unittest.main()