/supervisely_lib/nn/hosted/inference_single_image.py

https://github.com/supervisely/supervisely
Python | 106 lines | 79 code | 22 blank | 5 comment | 7 complexity | 3eff760a1b853a37cbe69f7ba5c85b04 MD5 | raw file
  1. # coding: utf-8
  2. from copy import deepcopy
  3. import json
  4. from supervisely_lib import logger
  5. from supervisely_lib.annotation.obj_class_collection import ObjClassCollection
  6. from supervisely_lib.annotation.tag_meta_collection import TagMetaCollection
  7. from supervisely_lib.imaging.image import read as sly_image_read
  8. from supervisely_lib.io.json import load_json_file
  9. from supervisely_lib.nn.hosted.constants import MODEL, SETTINGS, INPUT_SIZE, HEIGHT, WIDTH
  10. from supervisely_lib.nn.hosted.legacy.inference_config import maybe_convert_from_v1_inference_task_config, \
  11. maybe_convert_from_deploy_task_config
  12. from supervisely_lib.nn.config import update_recursively
  13. from supervisely_lib.project.project_meta import ProjectMeta
  14. from supervisely_lib.task.paths import TaskPaths
  15. from supervisely_lib.task.progress import Progress
  16. from supervisely_lib.worker_api.interfaces import SingleImageInferenceInterface
  17. GPU_DEVICE = 'gpu_device'
  18. class SingleImageInferenceBase(SingleImageInferenceInterface):
  19. def __init__(self, task_model_config=None, _load_model_weights=True):
  20. logger.info('Starting base single image inference applier init.')
  21. task_model_config = self._load_task_model_config() if task_model_config is None else deepcopy(task_model_config)
  22. self._config = update_recursively(self.get_default_config(), task_model_config)
  23. # Only validate after merging task config with the defaults.
  24. self._validate_model_config(self._config)
  25. self._load_train_config()
  26. if _load_model_weights:
  27. self._construct_and_fill_model()
  28. logger.info('Base single image inference applier init done.')
  29. def _construct_and_fill_model(self):
  30. progress_dummy = Progress('Building model:', 1)
  31. progress_dummy.iter_done_report()
  32. def _validate_model_config(self, config):
  33. pass
  34. def inference(self, image, ann):
  35. raise NotImplementedError()
  36. def inference_image_file(self, image_file, ann):
  37. image = sly_image_read(image_file)
  38. return self.inference(image, ann)
  39. @staticmethod
  40. def get_default_config():
  41. return {}
  42. @property
  43. def class_title_to_idx_key(self):
  44. return 'class_title_to_idx'
  45. @property
  46. def train_classes_key(self):
  47. return 'classes'
  48. @property
  49. def model_out_meta(self):
  50. return self._model_out_meta
  51. def get_out_meta(self):
  52. return self._model_out_meta
  53. def _model_out_tags(self):
  54. return TagMetaCollection() # Empty by default
  55. def _load_raw_model_config_json(self):
  56. try:
  57. with open(TaskPaths.MODEL_CONFIG_PATH) as fin:
  58. self.train_config = json.load(fin)
  59. except FileNotFoundError:
  60. raise RuntimeError('Unable to run inference, config from training was not found.')
  61. @staticmethod
  62. def _load_task_model_config():
  63. raw_task_config = load_json_file(TaskPaths.TASK_CONFIG_PATH)
  64. raw_task_config = maybe_convert_from_deploy_task_config(raw_task_config)
  65. task_config = maybe_convert_from_v1_inference_task_config(raw_task_config)
  66. return task_config[MODEL]
  67. def _load_train_config(self):
  68. self._load_raw_model_config_json()
  69. self.class_title_to_idx = self.train_config[self.class_title_to_idx_key]
  70. logger.info('Read model internal class mapping', extra={'class_mapping': self.class_title_to_idx})
  71. train_classes = ObjClassCollection.from_json(self.train_config[self.train_classes_key])
  72. logger.info('Read model out classes', extra={'classes': train_classes.to_json()})
  73. # TODO: Factor out meta constructing from _load_train_config method.
  74. self._model_out_meta = ProjectMeta(obj_classes=train_classes, tag_metas=self._model_out_tags())
  75. # Make a separate [index] --> [class] map that excludes the 'special' classes that should not be in the`
  76. # final output.
  77. self.out_class_mapping = {idx: train_classes.get(title) for title, idx in self.class_title_to_idx.items() if
  78. train_classes.has_key(title)}
  79. def _determine_model_input_size(self):
  80. src_size = self.train_config[SETTINGS][INPUT_SIZE]
  81. self.input_size = (src_size[HEIGHT], src_size[WIDTH])
  82. logger.info('Model input size is read (for auto-rescale).', extra={INPUT_SIZE: {
  83. WIDTH: self.input_size[1], HEIGHT: self.input_size[0]
  84. }})