/objectdetection/trial_adapter.py

https://github.com/dataloop-ai/ZazuML
Python | 106 lines | 83 code | 20 blank | 3 comment | 14 complexity | 51237f7c6686a6dd8812938e0474dcee MD5 | raw file
  1. import os
  2. # sys.path.insert(1, os.path.dirname(__file__))
  3. from .model_trainer import ModelTrainer
  4. from predictor.pred_utils import detect, detect_single_image
  5. from copy import deepcopy
  6. import random
  7. import time
  8. import hashlib
  9. import torch
  10. import logging
  11. logger = logging.getLogger(__name__)
  12. def generate_trial_id():
  13. s = str(time.time()) + str(random.randint(1, 1e7))
  14. return hashlib.sha256(s.encode('utf-8')).hexdigest()[:32]
  15. class TrialAdapter(ModelTrainer):
  16. def _unpack_trial_checkpoint(self, trial_checkpoint):
  17. self.hp_values = trial_checkpoint['hp_values'] if 'hp_values' in trial_checkpoint else {}
  18. self.hp_values['hyperparameter_tuner/initial_epoch'] = trial_checkpoint[
  19. 'epoch'] if 'model_fn' and 'epoch' in trial_checkpoint else self.hp_values[
  20. 'hyperparameter_tuner/initial_epoch']
  21. self.annotation_type = trial_checkpoint['annotation_type']
  22. self.model_fn = trial_checkpoint['model_fn']
  23. trial_checkpoint['training_configs'].update(self.hp_values)
  24. self.configs = trial_checkpoint['training_configs']
  25. new_trial_id = self.configs['hyperparameter_tuner/new_trial_id']
  26. past_trial_id = self.configs[
  27. 'hyperparameter_tuner/past_trial_id'] if 'hyperparameter_tuner/past_trial_id' in self.configs else None
  28. self.data_path = trial_checkpoint['home_path']
  29. checkpoint = None
  30. if 'model' in trial_checkpoint:
  31. checkpoint = deepcopy(trial_checkpoint)
  32. for x in ['model_fn', 'training_configs', 'data', 'hp_values', 'epoch']:
  33. checkpoint.pop(x)
  34. # return checkpoint with just
  35. return new_trial_id, past_trial_id, checkpoint
  36. def load(self, checkpoint_path='checkpoint.pt'):
  37. # the only necessary keys for load are ['model_specs']
  38. trial_checkpoint = torch.load(checkpoint_path)
  39. new_trial_id, past_trial_id, checkpoint = self._unpack_trial_checkpoint(trial_checkpoint)
  40. super().load(self.data_path, new_trial_id, past_trial_id, checkpoint)
  41. def train(self):
  42. super().preprocess(augment_policy=self.configs['augment_policy'],
  43. dataset=self.annotation_type,
  44. resize=self.configs['input_size'],
  45. batch=self.configs['batch'])
  46. super().build(model=self.model_fn,
  47. depth=self.configs['depth'],
  48. learning_rate=self.configs['learning_rate'],
  49. ratios=self.configs['anchor_ratios'],
  50. scales=self.configs['anchor_scales'])
  51. super().train(epochs=self.configs['hyperparameter_tuner/epochs'],
  52. init_epoch=self.configs['hyperparameter_tuner/initial_epoch'])
  53. def get_checkpoint_metadata(self):
  54. logger.info('getting best checkpoint')
  55. checkpoint = super().get_best_checkpoint()
  56. logging.info('got best checkpoint')
  57. checkpoint['hp_values'] = self.hp_values
  58. checkpoint['model_fn'] = self.model_fn
  59. checkpoint['training_configs'] = self.configs
  60. checkpoint['data_path'] = self.data_path
  61. checkpoint['annotation_type'] = self.annotation_type
  62. checkpoint['checkpoint_path'] = self.save_best_checkpoint_path
  63. checkpoint.pop('model')
  64. logging.info('checkpoint keys: ' + str(checkpoint.keys()))
  65. return checkpoint
  66. @property
  67. def checkpoint_path(self):
  68. return super().save_best_checkpoint_path
  69. def load_inference(self, checkpoint_path):
  70. if torch.cuda.is_available():
  71. logger.info('cuda available')
  72. self.inference_checkpoint = torch.load(checkpoint_path)
  73. else:
  74. logger.info('run on cpu')
  75. self.inference_checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
  76. return self.inference_checkpoint
  77. def predict_single_image(self, image_path, checkpoint_path='checkpoint.pt'):
  78. if hasattr(self, 'inference_checkpoint'):
  79. return detect_single_image(self.inference_checkpoint, image_path)
  80. else:
  81. self.load_inference(checkpoint_path)
  82. return detect_single_image(self.inference_checkpoint, image_path)
  83. def predict_items(self, items, checkpoint_path, with_upload=True, model_name='object_detection'):
  84. for item in items:
  85. dirname = self.predict_item(item, checkpoint_path, with_upload, model_name)
  86. return dirname