/training_utils/train_utils.py

https://github.com/TropComplique/image-classification-caltech-256

import numpy as np
import torch
import torch.nn.functional as F
import time
import copy
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

# compute top-k accuracy with pytorch
def accuracy(true, pred, top_k=(1,)):

    max_k = max(top_k)
    batch_size = true.size(0)

    # indices of the max_k highest-scoring classes per sample, shape (max_k, batch)
    _, pred = pred.topk(max_k, 1)
    pred = pred.t()
    correct = pred.eq(true.view(1, -1).expand_as(pred))

    result = []
    for k in top_k:
        # a sample counts as correct if the true class is among its top k predictions
        correct_k = correct[:k].reshape(-1).float().sum(0)
        result.append(correct_k.div_(batch_size).item())
    return result
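
# Illustrative call (shapes assumed for the example, not from the original file):
#   true = torch.tensor([3, 1, 0])     # (batch,) integer class labels
#   scores = torch.randn(3, 10)        # (batch, n_classes) scores or probabilities
#   top1, top5 = accuracy(true, scores, top_k=(1, 5))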

def optimization_step(model, criterion, optimizer, x_batch, y_batch):

    x_batch = x_batch.cuda()
    y_batch = y_batch.cuda(non_blocking=True)

    logits = model(x_batch)

    # compute logloss
    loss = criterion(logits, y_batch)
    batch_loss = loss.item()

    # compute accuracies
    pred = F.softmax(logits, dim=1)
    batch_accuracy, batch_top5_accuracy = accuracy(y_batch, pred, top_k=(1, 5))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return batch_loss, batch_accuracy, batch_top5_accuracy
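
# Note: optimization_step assumes `model` already lives on the GPU and performs
# one forward/backward/update pass per call. Since softmax is monotonic, passing
# probabilities instead of raw logits to accuracy() picks out the same top-k classes.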

def evaluate(model, criterion, val_iterator, n_batches):

    loss = 0.0
    acc = 0.0  # accuracy
    top5_accuracy = 0.0
    total_samples = 0

    with torch.no_grad():
        for j, (x_batch, y_batch) in enumerate(val_iterator):

            x_batch = x_batch.cuda()
            y_batch = y_batch.cuda(non_blocking=True)
            n_batch_samples = y_batch.size(0)
            logits = model(x_batch)

            # compute logloss
            batch_loss = criterion(logits, y_batch).item()

            # compute accuracies
            pred = F.softmax(logits, dim=1)
            batch_accuracy, batch_top5_accuracy = accuracy(y_batch, pred, top_k=(1, 5))

            # weight batch statistics by the number of samples in the batch
            loss += batch_loss*n_batch_samples
            acc += batch_accuracy*n_batch_samples
            top5_accuracy += batch_top5_accuracy*n_batch_samples
            total_samples += n_batch_samples

            if j >= n_batches:
                break

    return loss/total_samples, acc/total_samples, top5_accuracy/total_samples
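
# Note: evaluate() weights every batch by its sample count, so the returned
# values are per-sample averages rather than per-batch averages; `n_batches`
# truncates the loop so a large validation loader can be subsampled.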

def train(model, criterion, optimizer,
          train_iterator, n_epochs, n_batches,
          val_iterator, validation_step, n_validation_batches,
          saving_step, lr_scheduler=None):

    all_losses = []
    all_models = []

    is_reduce_on_plateau = isinstance(lr_scheduler, ReduceLROnPlateau)

    running_loss = 0.0
    running_accuracy = 0.0
    running_top5_accuracy = 0.0
    start = time.time()
    model.train()

    for epoch in range(0, n_epochs):
        for step, (x_batch, y_batch) in enumerate(train_iterator, 1 + epoch*n_batches):

            if lr_scheduler is not None and not is_reduce_on_plateau:
                optimizer = lr_scheduler(optimizer, step)

            batch_loss, batch_accuracy, batch_top5_accuracy = optimization_step(
                model, criterion, optimizer, x_batch, y_batch
            )
            running_loss += batch_loss
            running_accuracy += batch_accuracy
            running_top5_accuracy += batch_top5_accuracy

            if step % validation_step == 0:
                model.eval()
                test_loss, test_accuracy, test_top5_accuracy = evaluate(
                    model, criterion, val_iterator, n_validation_batches
                )
                end = time.time()

                print('{0:.2f} {1:.3f} {2:.3f} {3:.3f} {4:.3f} {5:.3f} {6:.3f} {7:.3f}'.format(
                    step/n_batches, running_loss/validation_step, test_loss,
                    running_accuracy/validation_step, test_accuracy,
                    running_top5_accuracy/validation_step, test_top5_accuracy,
                    end - start
                ))
                all_losses += [(
                    step/n_batches,
                    running_loss/validation_step, test_loss,
                    running_accuracy/validation_step, test_accuracy,
                    running_top5_accuracy/validation_step, test_top5_accuracy
                )]

                if is_reduce_on_plateau:
                    lr_scheduler.step(test_accuracy)

                running_loss = 0.0
                running_accuracy = 0.0
                running_top5_accuracy = 0.0
                start = time.time()
                model.train()

            if saving_step is not None and step % saving_step == 0:
                print('saving')
                model.cpu()
                clone = copy.deepcopy(model)
                all_models += [clone.state_dict()]
                model.cuda()

    return all_losses, all_models
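
# The printed columns are: epoch progress (step/n_batches), running train loss,
# validation loss, running train top-1 accuracy, validation top-1 accuracy,
# running top-5 accuracy, validation top-5 accuracy, and seconds elapsed since
# the previous validation.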

def predict(model, val_iterator_no_shuffle, return_erroneous=False):

    val_predictions = []
    val_true_targets = []
    if return_erroneous:
        erroneous_samples = []
        erroneous_targets = []
        erroneous_predictions = []

    model.eval()
    with torch.no_grad():
        for x_batch, y_batch in tqdm(val_iterator_no_shuffle):

            x_batch = x_batch.cuda()
            y_batch = y_batch.cuda()
            logits = model(x_batch)

            # compute probabilities
            probs = F.softmax(logits, dim=1)

            if return_erroneous:
                _, argmax = probs.max(1)
                # boolean mask of misclassified samples
                miss = argmax.ne(y_batch)
                if miss.any():
                    erroneous_samples += [x_batch[miss].cpu().numpy()]
                    erroneous_targets += [y_batch[miss].cpu().numpy()]
                    erroneous_predictions += [probs[miss].cpu().numpy()]

            val_predictions += [probs.cpu().numpy()]
            val_true_targets += [y_batch.cpu().numpy()]

    val_predictions = np.concatenate(val_predictions, axis=0)
    val_true_targets = np.concatenate(val_true_targets, axis=0)

    if return_erroneous:
        erroneous_samples = np.concatenate(erroneous_samples, axis=0)
        erroneous_targets = np.concatenate(erroneous_targets, axis=0)
        erroneous_predictions = np.concatenate(erroneous_predictions, axis=0)
        return val_predictions, val_true_targets,\
            erroneous_samples, erroneous_targets, erroneous_predictions

    return val_predictions, val_true_targets
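
# A minimal end-to-end sketch of how these utilities fit together. Everything
# below is illustrative and assumes a CUDA device: the toy tensors, the linear
# model, and all hyperparameters are placeholders, not the repo's actual setup.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # toy data: 256 random "images" with 3 channels and 10 classes (placeholder)
    x = torch.randn(256, 3, 32, 32)
    y = torch.randint(0, 10, (256,))
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    model = nn.Sequential(nn.Flatten(), nn.Linear(3*32*32, 10)).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scheduler = ReduceLROnPlateau(optimizer, mode='max')

    losses, snapshots = train(
        model, criterion, optimizer,
        train_loader, n_epochs=2, n_batches=len(train_loader),
        val_iterator=val_loader, validation_step=4,
        n_validation_batches=len(val_loader),
        saving_step=None, lr_scheduler=scheduler
    )
    predictions, targets = predict(model, val_loader)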