/train.py

https://github.com/zekunhao1995/DualSDF · Python

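"""Training entry point for DualSDF.

Parses a YAML configuration file, sets up logging and checkpoint
directories, dynamically imports the data and trainer modules named in
the config, and runs the main training loop.

Usage (the config path below is an example; use a config shipped with
the repo):
    python train.py config/some_config.yaml
    python train.py config/some_config.yaml --resume --pretrained <ckpt>
"""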
import os
import sys
import json
import time
import argparse
import importlib
from shutil import copy2

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from tensorboardX import SummaryWriter
def get_args():
    # Command line args
    parser = argparse.ArgumentParser(description='DualSDF Training')
    parser.add_argument('config', type=str,
                        help='The configuration file.')

    # Resume:
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--pretrained', default=None, type=str,
                        help='pretrained model checkpoint')

    # For easy debugging:
    parser.add_argument('--test_run', default=False, action='store_true')
    parser.add_argument('--special', default=None, type=str,
                        help='Run special tasks')
    args = parser.parse_args()

    def dict2namespace(config):
        # Recursively convert a (nested) dict into an argparse.Namespace
        # so config entries can be accessed as attributes (cfg.data.type).
        namespace = argparse.Namespace()
        for key, value in config.items():
            if isinstance(value, dict):
                new_value = dict2namespace(value)
            else:
                new_value = value
            setattr(namespace, key, new_value)
        return namespace

    # Parse config file. safe_load avoids executing arbitrary YAML tags
    # and does not need the Loader argument that plain yaml.load requires
    # in PyYAML >= 5.1.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    config = dict2namespace(config)

    # Create log_name from the config file name and a timestamp.
    log_prefix = ''
    if args.test_run:
        log_prefix = 'tmp_'
    if args.special is not None:
        log_prefix = log_prefix + 'special_{}_'.format(args.special)
    cfg_file_name = os.path.splitext(os.path.basename(args.config))[0]
    run_time = time.strftime('%Y-%b-%d-%H-%M-%S')

    # Currently save_dir and log_dir are the same.
    config.log_name = "logs/{}{}_{}".format(log_prefix, cfg_file_name, run_time)
    config.save_dir = "logs/{}{}_{}/checkpoints".format(log_prefix, cfg_file_name, run_time)
    config.log_dir = "logs/{}{}_{}".format(log_prefix, cfg_file_name, run_time)

    # Snapshot the config file and the command line for reproducibility.
    os.makedirs(os.path.join(config.log_dir, 'config'))
    os.makedirs(config.save_dir)
    copy2(args.config, os.path.join(config.log_dir, 'config'))
    with open(os.path.join(config.log_dir, 'config', 'argv.json'), 'w') as f:
        json.dump(sys.argv, f)
    return args, config
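
# A minimal sketch of the config structure this script assumes; the key
# names are inferred from the attribute accesses in main() and the
# training loop (cfg.data.type, cfg.trainer.type, cfg.trainer.epochs,
# cfg.viz.log_interval, cfg.viz.save_interval, cfg.resume.dir). The
# module names and values below are hypothetical; see the configs that
# ship with the repo for real ones.
#
#   data:
#     type: datasets.some_dataset_module   # imported via importlib
#     ...                                  # dataset-specific options
#   trainer:
#     type: trainers.some_trainer_module   # must define a Trainer class
#     epochs: 3000
#   viz:
#     log_interval: 10       # steps between TensorBoard scalar logs
#     save_interval: 100     # epochs between checkpoints
#   resume:
#     dir: logs/some_run/checkpoints       # used by --resume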
def main(args, cfg):
    torch.backends.cudnn.benchmark = True
    writer = SummaryWriter(logdir=cfg.log_name)
    device = torch.device('cuda:0')

    # Load experimental settings: the dataset and trainer modules are
    # chosen dynamically from the config (cfg.data.type and
    # cfg.trainer.type name importable modules).
    data_lib = importlib.import_module(cfg.data.type)
    loaders = data_lib.get_data_loaders(cfg.data)
    train_loader = loaders['train_loader']
    train_shape_ids = loaders['train_shape_ids']
    cfg.train_shape_ids = train_shape_ids
    test_loader = loaders['test_loader']
    test_shape_ids = loaders['test_shape_ids']
    cfg.test_shape_ids = test_shape_ids

    trainer_lib = importlib.import_module(cfg.trainer.type)
    trainer = trainer_lib.Trainer(cfg, args, device)

    # Prepare for training
    start_epoch = 0
    trainer.prep_train()
    if args.resume:
        if args.pretrained is not None:
            start_epoch = trainer.resume(args.pretrained)
        else:
            start_epoch = trainer.resume(cfg.resume.dir)

    # Special tasks bypass the training loop: run the named Trainer
    # method once and exit.
    if args.special is not None:
        special_fun = getattr(trainer, args.special)
        special_fun(test_loader=test_loader, writer=writer)
        sys.exit()

    # Main training loop
    prev_time = time.time()
    print("[Train] Start epoch: %d End epoch: %d" % (start_epoch, cfg.trainer.epochs))
    step_cnt = 0
    for epoch in range(start_epoch, cfg.trainer.epochs):
        trainer.epoch_start(epoch)

        # Train for one epoch
        for bidx, data in enumerate(train_loader):
            step_cnt = bidx + len(train_loader) * epoch + 1
            logs_info = trainer.step(data)

            # Print timing info
            current_time = time.time()
            elapsed_time = current_time - prev_time
            prev_time = time.time()
            print('Epoch: {}; Iter: {}; Time: {:0.5f};'.format(
                epoch, bidx, elapsed_time))

            # Log scalars to TensorBoard
            if step_cnt % int(cfg.viz.log_interval) == 0:
                if writer is not None:
                    for k, v in logs_info.items():
                        writer.add_scalar(k, v, step_cnt)

        # Save checkpoints
        if (epoch + 1) % int(cfg.viz.save_interval) == 0:
            trainer.save(epoch=epoch, step=step_cnt)
        trainer.epoch_end(epoch, writer=writer)
        writer.flush()

    # Always save the last epoch, unless it was just saved above
    if (epoch + 1) % int(cfg.viz.save_interval) != 0:
        trainer.save(epoch=epoch, step=step_cnt)
    writer.close()
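
# Interface the dynamically imported Trainer is assumed to provide, as
# called above (this summarizes the call sites in main(); it is not a
# definition from the repo):
#   prep_train()                    -- one-time setup before training
#   resume(path_or_dir) -> epoch    -- restore a checkpoint, return start epoch
#   epoch_start(epoch) / epoch_end(epoch, writer=...)
#   step(data) -> dict              -- one training step; scalars for logging
#   save(epoch=..., step=...)       -- write a checkpoint
#   plus any method named by --special, taking (test_loader=..., writer=...)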
if __name__ == "__main__":
    # Command line args
    args, cfg = get_args()
    print("Arguments:")
    print(args)
    print("Configuration:")
    print(cfg)
    main(args, cfg)