/test/checkpoint_test.py
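"""Checkpoint round-trip tests: each test trains once uninterrupted, reruns
halfway and saves state, then restores from the checkpoint and finishes,
asserting that the resumed loss curve matches the uninterrupted run."""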
import copy
import os
import sys

import pytest
import torch

# make the project root and utils importable when running from test/
module_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_path + '/..')
sys.path.append(module_path + '/../utils')

import train
import util
import globals as g  # global state of the current run, shared between modules

# tiny one-layer model with a 4-token budget so each test finishes in seconds
simple_args_str = "--local --data=data --batch_size=1 --verbose_log_steps=0 --n_layer=1 --d_model=10 --d_inner=2 " \
                  "--max_tokens=4 --tgt_len=1 --scheduler=constant --log_interval=20"
simple_args = train.parse_args(simple_args_str.strip().split())


def test_checkpoint():
    g.args = copy.deepcopy(simple_args)
    train.logging_setup()
    train.data_setup()
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
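

# The save/restore protocol above repeats in every test below. As a reading
# aid, here is the shared shape factored into one hypothetical helper (a
# sketch only; the tests remain self-contained and do not call it):
def _checkpoint_roundtrip(full_tokens, half_tokens, state_fn='/tmp/state.pt'):
    # reference: one uninterrupted run
    g.args.max_tokens = full_tokens
    train.data_setup()
    losses_full = train.main_loop()

    # rerun halfway and save a checkpoint
    g.args.max_tokens = half_tokens
    g.args.save_state_fn = state_fn
    train.data_setup()  # reset iterators
    losses_half = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore and continue to the end
    g.args.max_tokens = full_tokens
    g.args.save_state_fn = None
    g.args.load_state_fn = state_fn
    train.data_setup()  # reset iterators
    losses_resumed = train.main_loop()

    # the resumed curve must splice seamlessly into the reference curve
    util.assert_close(losses_resumed[0], losses_full[len(losses_half)])
    util.assert_close(losses_resumed[-1], losses_full[-1])
    return losses_full, losses_half, losses_resumed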


def test_checkpoint_dropout():
    g.args = copy.deepcopy(simple_args)
    g.args.dropout = 0.5
    train.logging_setup()
    train.data_setup()
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


def test_checkpoint_lamb():
    g.args = copy.deepcopy(simple_args)
    g.args.dropout = 0.5
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 4
    g.args.optim = 'lamb'
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="fp16 tests require GPU")
def test_checkpoint_fp16_lamb():
    g.args = copy.deepcopy(simple_args)
    g.args.fp16 = True  # assumed intended given the test name; matches the other fp16 tests
    g.args.dropout = 0.5
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 40
    g.args.optim = 'lamb'
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 20
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 40
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


def test_checkpoint_wiki_small():
    # remove cached tokenizations so data_setup rebuilds them from scratch
    for fn in ('data/wikiextracted/AA/wiki_00.txt.tokenized',
               'data/wikiextracted/AA/wiki_01.txt.tokenized',
               'data/wikiextracted/AA/wiki_02.txt.tokenized',
               'data/wikiextracted/AA/wiki_03.txt.tokenized',
               'data/wikiextracted/cache.pt.bpe'):
        os.system(f'rm {fn}')

    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.data = 'data/wikiextracted'
    g.args.dataset = 'wiki'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # wiki requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 4
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


def test_checkpoint_wiki():
    # remove cached tokenizations so data_setup rebuilds them from scratch
    for fn in ('data/wikiextracted/AA/wiki_00.txt.tokenized',
               'data/wikiextracted/AA/wiki_01.txt.tokenized',
               'data/wikiextracted/AA/wiki_02.txt.tokenized',
               'data/wikiextracted/AA/wiki_03.txt.tokenized',
               'data/wikiextracted/cache.pt.bpe'):
        os.system(f'rm {fn}')

    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/wikiextracted'
    g.args.dataset = 'wiki'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # wiki requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()

    # 36 words total, 8 in the first file, so checkpointing at 10 tokens
    # below lands past the first file while staying within one epoch
    g.args.max_tokens = 30
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 10
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 30
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


def test_checkpoint_wiki_multiepoch():
    # remove cached tokenizations so data_setup rebuilds them from scratch
    for fn in ('data/wikiextracted/AA/wiki_00.txt.tokenized',
               'data/wikiextracted/AA/wiki_01.txt.tokenized',
               'data/wikiextracted/AA/wiki_02.txt.tokenized',
               'data/wikiextracted/AA/wiki_03.txt.tokenized',
               'data/wikiextracted/cache.pt.bpe'):
        os.system(f'rm {fn}')

    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/wikiextracted'
    g.args.dataset = 'wiki'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # wiki requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()

    # 36 words total, 8 in the first file, so a 50-token run wraps into a
    # second epoch
    g.args.max_tokens = 50
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 40
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 50
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
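

# Arithmetic behind the multiepoch variant above: the wiki fixture holds 36
# words in total, so the 40-token halfway run already wraps past the end of
# the corpus and the checkpoint lands inside the second epoch. Restoring it
# therefore has to recover the data-iterator position across an epoch
# boundary, not merely within one file.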


def test_checkpoint_git():
    # remove cached tokenizations so data_setup rebuilds them from scratch
    for fn in ('data/git/git_1.txt.tokenized',
               'data/git/git_2.txt.tokenized',
               'data/git/git_3.txt.tokenized',
               'data/git/git_4.txt.tokenized'):
        os.system(f'rm {fn}')

    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/git'
    g.args.dataset = 'git'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # git requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()

    # 49 words total, 9 in the first file, so checkpointing at 12 tokens
    # below lands past the first file while staying within one epoch
    g.args.max_tokens = 30
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 12
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 30
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


def test_checkpoint_git_multiepoch():
    # remove cached tokenizations so data_setup rebuilds them from scratch
    for fn in ('data/git/git_1.txt.tokenized',
               'data/git/git_2.txt.tokenized',
               'data/git/git_3.txt.tokenized',
               'data/git/git_4.txt.tokenized'):
        os.system(f'rm {fn}')

    g.args = copy.deepcopy(simple_args)
    g.args.test = 'yes'
    g.args.batch_size = 2
    g.args.tgt_len = 2
    g.args.data = 'data/git'
    g.args.dataset = 'git'
    g.args.dropatt = 0.1
    g.args.dropout = 0.1
    g.args.bpe = True  # git requires BPE
    g.args.optim = 'lamb'
    train.logging_setup()
    train.data_setup()

    # 49 words total, 9 in the first file, so the 55-token checkpoint below
    # lands in the second epoch (same boundary case as the wiki test above)
    g.args.max_tokens = 70
    losses1 = train.main_loop()
    print(losses1)

    # run halfway and save checkpoint
    g.args.max_tokens = 55
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 70
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses3 = train.main_loop()

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="fp16 tests require GPU")
def test_checkpoint_fp16():
    g.args = copy.deepcopy(simple_args)
    g.args.fp16 = True
    train.logging_setup()
    train.data_setup()
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 2
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 4
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    # train.data_setup()  # reset iterators
    losses3 = train.main_loop()
    print(losses1)
    print(losses2)
    print(losses3)

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="fp16 tests require GPU")
def test_checkpoint_fp16_dropout():
    g.args = copy.deepcopy(simple_args)
    g.args.fp16 = True
    g.args.dropout = 0.0
    train.logging_setup()
    train.data_setup()
    g.args.max_tokens = 40
    losses1 = train.main_loop()

    # run halfway and save checkpoint
    g.args.max_tokens = 20
    g.args.save_state_fn = '/tmp/state.pt'
    train.data_setup()  # reset iterators
    losses2 = train.main_loop()
    train.save_state(g.state, g.args.save_state_fn)

    # restore from checkpoint and continue to the end
    g.args.max_tokens = 40
    g.args.save_state_fn = None
    g.args.load_state_fn = '/tmp/state.pt'
    # train.data_setup()  # reset iterators
    losses3 = train.main_loop()
    print(losses1)
    print(losses2)
    print(losses3)

    util.assert_close(losses3[0], losses1[len(losses2)])
    util.assert_close(losses3[-1], losses1[-1])
    g.logger.info(f"Discrepancy was {(losses3[-1] - losses1[-1]) / losses1[-1]}")
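

# Running this file directly executes only the quick basic test below. For
# the full suite, invoke pytest on this file (path taken from the header):
#   pytest test/checkpoint_test.py -v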
if __name__ == '__main__':
    test_checkpoint()