/basicsr/models/video_gan_model.py

https://github.com/xinntao/EDVR
Python | 142 lines | 106 code | 22 blank | 14 comment | 18 complexity | 7fc9e8f0066864cbc866b20db2bc3c7f MD5 | raw file
  1. import importlib
  2. import torch
  3. from collections import OrderedDict
  4. from copy import deepcopy
  5. from basicsr.models.archs import define_network
  6. from basicsr.models.video_base_model import VideoBaseModel
  7. loss_module = importlib.import_module('basicsr.models.losses')
  8. class VideoGANModel(VideoBaseModel):
  9. """Video GAN model."""
  10. def init_training_settings(self):
  11. train_opt = self.opt['train']
  12. # define network net_d
  13. self.net_d = define_network(deepcopy(self.opt['network_d']))
  14. self.net_d = self.model_to_device(self.net_d)
  15. self.print_network(self.net_d)
  16. # load pretrained models
  17. load_path = self.opt['path'].get('pretrain_model_d', None)
  18. if load_path is not None:
  19. self.load_network(self.net_d, load_path,
  20. self.opt['path']['strict_load'])
  21. self.net_g.train()
  22. self.net_d.train()
  23. # define losses
  24. if train_opt.get('pixel_opt'):
  25. pixel_type = train_opt['pixel_opt'].pop('type')
  26. cri_pix_cls = getattr(loss_module, pixel_type)
  27. self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
  28. self.device)
  29. else:
  30. self.cri_pix = None
  31. if train_opt.get('perceptual_opt'):
  32. percep_type = train_opt['perceptual_opt'].pop('type')
  33. cri_perceptual_cls = getattr(loss_module, percep_type)
  34. self.cri_perceptual = cri_perceptual_cls(
  35. **train_opt['perceptual_opt']).to(self.device)
  36. else:
  37. self.cri_perceptual = None
  38. if train_opt.get('gan_opt'):
  39. gan_type = train_opt['gan_opt'].pop('type')
  40. cri_gan_cls = getattr(loss_module, gan_type)
  41. self.cri_gan = cri_gan_cls(**train_opt['gan_opt']).to(self.device)
  42. self.net_d_iters = train_opt.get('net_d_iters', 1)
  43. self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
  44. # set up optimizers and schedulers
  45. self.setup_optimizers()
  46. self.setup_schedulers()
  47. def setup_optimizers(self):
  48. train_opt = self.opt['train']
  49. # optimizer g
  50. optim_type = train_opt['optim_g'].pop('type')
  51. if optim_type == 'Adam':
  52. self.optimizer_g = torch.optim.Adam(self.net_g.parameters(),
  53. **train_opt['optim_g'])
  54. else:
  55. raise NotImplementedError(
  56. f'optimizer {optim_type} is not supperted yet.')
  57. self.optimizers.append(self.optimizer_g)
  58. # optimizer d
  59. optim_type = train_opt['optim_d'].pop('type')
  60. if optim_type == 'Adam':
  61. self.optimizer_d = torch.optim.Adam(self.net_d.parameters(),
  62. **train_opt['optim_d'])
  63. else:
  64. raise NotImplementedError(
  65. f'optimizer {optim_type} is not supperted yet.')
  66. self.optimizers.append(self.optimizer_d)
  67. def optimize_parameters(self, current_iter):
  68. # optimize net_g
  69. for p in self.net_d.parameters():
  70. p.requires_grad = False
  71. self.optimizer_g.zero_grad()
  72. self.output = self.net_g(self.lq)
  73. l_g_total = 0
  74. loss_dict = OrderedDict()
  75. if (current_iter % self.net_d_iters == 0
  76. and current_iter > self.net_d_init_iters):
  77. # pixel loss
  78. if self.cri_pix:
  79. l_g_pix = self.cri_pix(self.output, self.gt)
  80. l_g_total += l_g_pix
  81. loss_dict['l_g_pix'] = l_g_pix
  82. # perceptual loss
  83. if self.cri_perceptual:
  84. l_g_percep, l_g_style = self.cri_perceptual(
  85. self.output, self.gt)
  86. if l_g_percep is not None:
  87. l_g_total += l_g_percep
  88. loss_dict['l_g_percep'] = l_g_percep
  89. if l_g_style is not None:
  90. l_g_total += l_g_style
  91. loss_dict['l_g_style'] = l_g_style
  92. # gan loss
  93. fake_g_pred = self.net_d(self.output)
  94. l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
  95. l_g_total += l_g_gan
  96. loss_dict['l_g_gan'] = l_g_gan
  97. l_g_total.backward()
  98. self.optimizer_g.step()
  99. # optimize net_d
  100. for p in self.net_d.parameters():
  101. p.requires_grad = True
  102. self.optimizer_d.zero_grad()
  103. # real
  104. real_d_pred = self.net_d(self.gt)
  105. l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
  106. loss_dict['l_d_real'] = l_d_real
  107. loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
  108. l_d_real.backward()
  109. # fake
  110. fake_d_pred = self.net_d(self.output.detach())
  111. l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
  112. loss_dict['l_d_fake'] = l_d_fake
  113. loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
  114. l_d_fake.backward()
  115. self.optimizer_d.step()
  116. self.log_dict = self.reduce_loss_dict(loss_dict)
  117. def save(self, epoch, current_iter):
  118. self.save_network(self.net_g, 'net_g', current_iter)
  119. self.save_network(self.net_d, 'net_d', current_iter)
  120. self.save_training_state(epoch, current_iter)