
/trainit.py

https://gitlab.com/smartllvlab/cluster-fsl

import torch
import argparse
import pandas as pd
import os
import json
import pprint

from src import utils as ut
from src import datasets, models
from src.models import backbones

from torch.utils.data import DataLoader
from haven import haven_utils as hu
from haven import haven_chk as hc

def trainval(exp_dict, savedir_base, datadir, reset=False,
             num_workers=0, title=None,
             ckpt=None):
    # bookkeeping
    # ---------------

    # get experiment directory
    # exp_id = hu.hash_dict(exp_dict) + '-' + exp_dict['name'].rpartition('.')[0]
    exp_id = title
    savedir = os.path.join(savedir_base, exp_id)
    os.makedirs(savedir, exist_ok=True)
    ut.setup_logger(os.path.join(savedir, 'train_log.txt'))

    if reset:
        # delete and backup experiment
        hc.delete_experiment(savedir, backup_flag=True)

    # create folder and save the experiment dictionary
    os.makedirs(savedir, exist_ok=True)
    hu.save_json(os.path.join(savedir, 'exp_dict.json'), exp_dict)
    pprint.pprint(exp_dict)
    print('Experiment saved in %s' % savedir)

    # load datasets
    # ==========================
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset_train"],
                                     data_root=os.path.join(datadir, exp_dict["dataset_train_root"]),
                                     split="train",
                                     transform=exp_dict["transform_train"],
                                     classes=exp_dict["classes_train"],
                                     support_size=exp_dict["support_size_train"],
                                     query_size=exp_dict["query_size_train"],
                                     n_iters=exp_dict["train_iters"],
                                     unlabeled_size=exp_dict["unlabeled_size_train"])

    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset_val"],
                                   data_root=os.path.join(datadir, exp_dict["dataset_val_root"]),
                                   split="val",
                                   transform=exp_dict["transform_val"],
                                   classes=exp_dict["classes_val"],
                                   support_size=exp_dict["support_size_val"],
                                   query_size=exp_dict["query_size_val"],
                                   n_iters=exp_dict.get("val_iters", None),
                                   unlabeled_size=exp_dict["unlabeled_size_val"])

    test_set = datasets.get_dataset(dataset_name=exp_dict["dataset_test"],
                                    data_root=os.path.join(datadir, exp_dict["dataset_test_root"]),
                                    split="test",
                                    transform=exp_dict["transform_val"],
                                    classes=exp_dict["classes_test"],
                                    support_size=exp_dict["support_size_test"],
                                    query_size=exp_dict["query_size_test"],
                                    n_iters=exp_dict["test_iters"],
                                    unlabeled_size=exp_dict["unlabeled_size_test"])

    # get dataloaders
    # ==========================
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=exp_dict["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        collate_fn=ut.get_collate(exp_dict["collate_fn"]),
        drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=lambda x: x,
        drop_last=True)

    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=lambda x: x,
        drop_last=True)

    # create model and trainer
    # ==========================
    # Create model, opt, wrapper
    backbone = backbones.get_backbone(backbone_name=exp_dict['model']["backbone"], exp_dict=exp_dict)
    model = models.get_model(model_name=exp_dict["model"]['name'], backbone=backbone,
                             n_classes=exp_dict["n_classes"],
                             exp_dict=exp_dict,
                             pretrained_weights_dir=None,
                             savedir_base=savedir_base)

    if ckpt is not None:
        print('=> Model from `{}` loaded'.format(ckpt))
        # with strict=False, load_state_dict returns (missing_keys, unexpected_keys)
        missing, unexpected = model.model.load_state_dict(
            torch.load(ckpt, map_location='cpu')['model'], strict=False)
        if missing:
            print('Missing keys:', missing)
        if unexpected:
            print('Unexpected keys:', unexpected)

    # Checkpoint
    # -----------
    checkpoint_path = os.path.join(savedir, 'checkpoint.pth')
    score_list_path = os.path.join(savedir, 'score_list.pkl')

    if os.path.exists(score_list_path):
        # resume experiment
        model.load_state_dict(hu.torch_load(checkpoint_path))
        score_list = hu.load_pkl(score_list_path)
        s_epoch = score_list[-1]['epoch'] + 1
    else:
        # restart experiment
        score_list = []
        s_epoch = 0

    # Run training and validation
    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        score_dict = {"epoch": epoch}
        score_dict.update(model.get_lr())

        # train
        score_dict.update(model.train_on_loader(train_loader))

        # validate
        score_dict.update(model.val_on_loader(val_loader))
        # score_dict.update(model.test_on_loader(test_loader))

        # Add score_dict to score_list
        score_list += [score_dict]

        # Report
        score_df = pd.DataFrame(score_list)
        print(score_df.tail())

        # Save checkpoint
        hu.save_pkl(score_list_path, score_list)
        hu.torch_save(checkpoint_path, model.get_state_dict())
        print("Saved: %s" % savedir)

        if "accuracy" in exp_dict["target_loss"]:
            is_best = score_dict[exp_dict["target_loss"]] >= score_df[exp_dict["target_loss"]][:-1].max()
        else:
            is_best = score_dict[exp_dict["target_loss"]] <= score_df[exp_dict["target_loss"]][:-1].min()

        # Save best checkpoint
        if is_best:
            hu.save_pkl(os.path.join(savedir, "score_list_best.pkl"), score_list)
            hu.torch_save(os.path.join(savedir, "checkpoint_best.pth"), model.get_state_dict())
            print("Saved Best: %s" % savedir)

        # Check for end of training conditions
        if model.is_end_of_training():
            return


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('cfg', type=str, help='json config path')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='model checkpoint to resume from')
    parser.add_argument('-sb', '--savedir_base', required=True)
    parser.add_argument('-d', '--datadir', default='data/')
    parser.add_argument('-nw', '--num_workers', default=2, type=int)
    parser.add_argument('-t', '--title', default=None, type=str)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    with open(args.cfg) as f:
        cfg = json.load(f)

    trainval(exp_dict=cfg,
             savedir_base=args.savedir_base,
             reset=False,
             datadir=args.datadir,
             num_workers=args.num_workers,
             title=args.title,
             ckpt=args.ckpt)
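
For orientation, the JSON file passed as `cfg` has to supply every key that trainval reads above. The sketch below is a minimal, hypothetical layout of such a config written as a Python dict: the key names are taken from the code, but all values (dataset names, episode sizes, backbone/model names, target loss) are placeholders and would need to match whatever src.datasets, src.models, and src.models.backbones actually register in this repository.

# Hypothetical config sketch; keys come from trainval() above, values are placeholders.
example_exp_dict = {
    # train split
    "dataset_train": "episodic_dataset",      # placeholder dataset name
    "dataset_train_root": "train_root",
    "transform_train": "basic",
    "classes_train": 5,
    "support_size_train": 5,
    "query_size_train": 15,
    "train_iters": 100,
    "unlabeled_size_train": 0,
    # val split
    "dataset_val": "episodic_dataset",
    "dataset_val_root": "val_root",
    "transform_val": "basic",
    "classes_val": 5,
    "support_size_val": 5,
    "query_size_val": 15,
    "val_iters": 100,                          # optional; read with .get()
    "unlabeled_size_val": 0,
    # test split
    "dataset_test": "episodic_dataset",
    "dataset_test_root": "test_root",
    "classes_test": 5,
    "support_size_test": 5,
    "query_size_test": 15,
    "test_iters": 100,
    "unlabeled_size_test": 0,
    # loaders, model, and training loop
    "batch_size": 1,
    "collate_fn": "default",                   # placeholder; resolved by ut.get_collate
    "model": {"name": "some_model", "backbone": "some_backbone"},
    "n_classes": 64,
    "max_epoch": 200,
    "target_loss": "val_accuracy",             # must be a key produced in score_dict
}

With such a config saved as, say, configs/example.json (a hypothetical path), the script would be launched roughly as `python trainit.py configs/example.json -sb ./results -d ./data -t example_run`, optionally adding `--ckpt path/to/checkpoint.pth` to warm-start from an existing checkpoint.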