/blaze/thirdparty/onnx/onnx-1.2.2/onnx/backend/base.py

Source: https://github.com/alibaba/x-deeplearning

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple
from typing import Text, Sequence, Any, Type, Tuple, NewType, Optional, Dict

import six
import numpy  # type: ignore

import onnx.checker
import onnx.onnx_cpp2py_export.checker as c_checker
from onnx import ModelProto, NodeProto, IR_VERSION


class DeviceType(object):
    _Type = NewType('_Type', int)
    CPU = _Type(0)  # type: _Type
    CUDA = _Type(1)  # type: _Type


class Device(object):
    '''
    Describes device type and device id
    syntax: device_type:device_id(optional)
    example: 'CPU', 'CUDA', 'CUDA:1'
    '''

    def __init__(self, device):  # type: (Text) -> None
        options = device.split(':')
        self.type = getattr(DeviceType, options[0])
        self.device_id = 0
        if len(options) > 1:
            self.device_id = int(options[1])
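
# Illustrative usage (not part of the original file), following the docstring
# above: the device string is split on ':' and the optional id defaults to 0.
#
#     >>> dev = Device('CUDA:1')
#     >>> dev.type == DeviceType.CUDA, dev.device_id
#     (True, 1)
#     >>> Device('CPU').device_id
#     0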


def namedtupledict(typename, field_names, *args, **kwargs):  # type: (Text, Sequence[Text], *Any, **Any) -> Type[Tuple[Any, ...]]
    field_names_map = {n: i for i, n in enumerate(field_names)}
    # Some output names are invalid python identifiers, e.g. "0"
    kwargs.setdefault(str('rename'), True)
    data = namedtuple(typename, field_names, *args, **kwargs)  # type: ignore

    def getitem(self, key):  # type: (Any, Any) -> Any
        if isinstance(key, six.string_types):
            key = field_names_map[key]
        return super(type(self), self).__getitem__(key)  # type: ignore

    data.__getitem__ = getitem
    return data
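
# Illustrative usage (not part of the original file): the returned namedtuple
# subclass accepts both positional and string keys in __getitem__, so backend
# outputs can be addressed by name as well as by index.
#
#     >>> Outputs = namedtupledict('Outputs', ['Y'])
#     >>> out = Outputs(Y=numpy.array([1.0]))
#     >>> out[0] is out['Y']
#     True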


class BackendRep(object):
    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
        pass
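
# Note (not part of the original file): BackendRep.run() is intentionally a
# stub. A concrete backend's prepare() returns a BackendRep subclass whose
# run() executes the prepared model on the given inputs and returns the
# outputs as a tuple (typically one that also supports name-based lookup,
# e.g. built with namedtupledict above).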


class Backend(object):
    @classmethod
    def prepare(cls,
                model,  # type: ModelProto
                device='CPU',  # type: Text
                **kwargs  # type: Any
                ):  # type: (...) -> Optional[BackendRep]
        # TODO Remove Optional from return type
        onnx.checker.check_model(model)
        return None

    @classmethod
    def run_model(cls,
                  model,  # type: ModelProto
                  inputs,  # type: Any
                  device='CPU',  # type: Text
                  **kwargs  # type: Any
                  ):  # type: (...) -> Tuple[Any, ...]
        backend = cls.prepare(model, device, **kwargs)
        assert backend is not None
        return backend.run(inputs)

    @classmethod
    def run_node(cls,
                 node,  # type: NodeProto
                 inputs,  # type: Any
                 device='CPU',  # type: Text
                 outputs_info=None,  # type: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]]
                 **kwargs  # type: Dict[Text, Any]
                 ):  # type: (...) -> Optional[Tuple[Any, ...]]
        '''Simply run one operator and return the results.

        Args:
            outputs_info: a list of tuples containing the element type and
                shape of each output. The first element of each tuple is the
                dtype and the second element is the shape. More use cases can
                be found in
                https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
        '''
        # TODO Remove Optional from return type
        if 'opset_version' in kwargs:
            special_context = c_checker.CheckerContext()
            special_context.ir_version = IR_VERSION
            special_context.opset_imports = {'': kwargs['opset_version']}  # type: ignore
            onnx.checker.check_node(node, special_context)
        else:
            onnx.checker.check_node(node)
        return None
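
    # Note (illustrative, not part of the original file): passing an
    # 'opset_version' keyword, e.g. SomeBackend.run_node(node, inputs,
    # opset_version=7), validates the node against that opset through a
    # dedicated CheckerContext; otherwise the default checker context is used.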

    @classmethod
    def supports_device(cls, device):  # type: (Text) -> bool
        """
        Checks whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.
        """
        return True
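

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how a concrete backend
# is typically layered on top of these base classes. _EchoRep and _EchoBackend
# are hypothetical names; a real backend would check and compile `model` in
# prepare() and execute it in run(), rather than echoing the inputs back.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _EchoRep(BackendRep):
        def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
            # A real rep would feed `inputs` through the prepared model.
            return tuple(inputs)

    class _EchoBackend(Backend):
        @classmethod
        def prepare(cls, model, device='CPU', **kwargs):
            # type: (ModelProto, Text, **Any) -> BackendRep
            # A real backend would call onnx.checker.check_model(model) here
            # (as the base prepare() does) and compile the graph for `device`.
            return _EchoRep()

    # run_model() is the usual entry point: prepare() builds a BackendRep,
    # then run() executes it on the given inputs.
    print(_EchoBackend.run_model(ModelProto(), inputs=[numpy.zeros(2)]))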