/test/checkpoint_test.py

https://github.com/cybertronai/transformer-xl
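"""Checkpoint round-trip tests for save_state/load_state.

Each test follows the same three-phase pattern:
  1. run the training loop to max_tokens and record the losses,
  2. rerun partway, then save a checkpoint with train.save_state,
  3. restore from the checkpoint, train to the end, and check that the
     resumed losses line up with the uninterrupted run
     (losses3[0] vs losses1[len(losses2)], and the final losses).
"""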

import copy
import os
import sys

module_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_path + '/..')
sys.path.append(module_path + '/../utils')

import pytest
import torch

import train
import util
import globals as g  # global state for the current run, shared between modules

simple_args_str = "--local --data=data --batch_size=1 --verbose_log_steps=0 --n_layer=1 --d_model=10 --d_inner=2 " \
                  "--max_tokens=4 --tgt_len=1 --scheduler=constant --log_interval=20"
simple_args = train.parse_args(simple_args_str.strip().split())

def test_checkpoint():
    g.args = copy.deepcopy(simple_args)
    train.logging_setup()
    train.data_setup()
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
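
# Same round trip with dropout enabled; the resumed run can only match the
# uninterrupted one if the checkpoint also restores the RNG state that
# drives the dropout masks.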
def test_checkpoint_dropout():
    g.args = copy.deepcopy(simple_args)
    g.args.dropout = 0.5
    train.logging_setup()
    train.data_setup()
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
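
# Same round trip with the LAMB optimizer, so optimizer state (not just
# model weights) has to survive the save/restore.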
def test_checkpoint_lamb():
    g.args = copy.deepcopy(simple_args)
    g.args.dropout = 0.5
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 4
    g.args.optim = 'lamb'
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
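
# fp16 together with LAMB: both the optimizer state and any loss-scaling
# state have to round-trip through the checkpoint.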
@pytest.mark.skipif(not torch.cuda.is_available(), reason="fp16 tests require GPU")
def test_checkpoint_fp16_lamb():
    g.args = copy.deepcopy(simple_args)
    g.args.fp16 = True  # run in fp16 to match the test name (as in test_checkpoint_fp16)
    g.args.dropout = 0.5
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 40
    g.args.optim = 'lamb'
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 20
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 40
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
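
# Wiki dataset variant: clears any cached tokenizations first so data_setup
# re-tokenizes from scratch, then does the same save/restore round trip.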
def test_checkpoint_wiki_small():
    os.system('rm data/wikiextracted/AA/wiki_00.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_01.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_02.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_03.txt.tokenized')
    os.system('rm data/wikiextracted/cache.pt.bpe')
    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.data = 'data/wikiextracted'
    g.args.dataset = 'wiki'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # wiki requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 4
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
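
# Larger wiki variant with batch_size=2 and tgt_len=2; the checkpoint lands
# past the first file (8 of 36 words), so the restored iterator position
# across files matters.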
def test_checkpoint_wiki():
    os.system('rm data/wikiextracted/AA/wiki_00.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_01.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_02.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_03.txt.tokenized')
    os.system('rm data/wikiextracted/cache.pt.bpe')
    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/wikiextracted'
    g.args.dataset = 'wiki'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # wiki requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()
    # 36 words total, 8 in first file
    g.args.max_tokens = 30
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 10
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 30
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
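
# Multi-epoch wiki variant: max_tokens exceeds the 36-word corpus, so
# training wraps into a second epoch and the checkpoint (taken at 40 tokens)
# must restore the position within that epoch.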
def test_checkpoint_wiki_multiepoch():
    os.system('rm data/wikiextracted/AA/wiki_00.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_01.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_02.txt.tokenized')
    os.system('rm data/wikiextracted/AA/wiki_03.txt.tokenized')
    os.system('rm data/wikiextracted/cache.pt.bpe')
    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/wikiextracted'
    g.args.dataset = 'wiki'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # wiki requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()
    # 36 words total, 8 in first file
    g.args.max_tokens = 50
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 40
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 50
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
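
# Same round trip on the git dataset (49 words across four files).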
def test_checkpoint_git():
    os.system('rm data/git/git_1.txt.tokenized')
    os.system('rm data/git/git_2.txt.tokenized')
    os.system('rm data/git/git_3.txt.tokenized')
    os.system('rm data/git/git_4.txt.tokenized')
    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/git'
    g.args.dataset = 'git'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # git requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()
    # 49 words total, 9 in first file
    g.args.max_tokens = 30
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 12
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 30
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
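
# Multi-epoch variant for the git dataset: 70 tokens over a 49-word corpus,
# with the checkpoint taken after the epoch boundary (55 tokens).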
def test_checkpoint_git_multiepoch():
    os.system('rm data/git/git_1.txt.tokenized')
    os.system('rm data/git/git_2.txt.tokenized')
    os.system('rm data/git/git_3.txt.tokenized')
    os.system('rm data/git/git_4.txt.tokenized')
    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/git'
    g.args.dataset = 'git'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # git requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()
    # 49 words total, 9 in first file
    g.args.max_tokens = 70
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 55
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 70
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
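
# Basic fp16 round trip; note that data_setup is not rerun before the
# restore leg here, unlike in the earlier tests.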
@pytest.mark.skipif(not torch.cuda.is_available(), reason="fp16 tests require GPU")
def test_checkpoint_fp16():
    g.args = copy.deepcopy(simple_args)
    g.args.fp16 = True
    train.logging_setup()
    train.data_setup()
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    # data_setup()  # reset iterators
    losses3 = train.main_loop()
    print(losses1)
    print(losses2)
    print(losses3)

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
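
# fp16 with dropout explicitly set to 0.0, over a longer run (40 tokens).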
@pytest.mark.skipif(not torch.cuda.is_available(), reason="fp16 tests require GPU")
def test_checkpoint_fp16_dropout():
    g.args = copy.deepcopy(simple_args)
    g.args.fp16 = True
    g.args.dropout = 0.0
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 40
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 20
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 40
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    # data_setup()  # reset iterators
    losses3 = train.main_loop()
    print(losses1)
    print(losses2)
    print(losses3)

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
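
# Running this file directly only exercises the basic test; run the whole
# suite with `pytest test/checkpoint_test.py`.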
if __name__ == '__main__':
    test_checkpoint()