/models/crowdhuman/input.py

https://github.com/TuSimple/simpledet
Python | 194 lines | 124 code | 35 blank | 35 comment | 13 complexity | 49886d3237047a862d3a3ef36b34b70e MD5 | raw file
  1. from __future__ import division
  2. from __future__ import print_function
  3. import numpy as np
  4. import copy
  5. from core.detection_input import AnchorTarget2D
  6. from operator_py.cython.bbox import bbox_overlaps_cython
  7. from operator_py.cython.bbox_self import bbox_selfoverlaps_cython
  8. class PyramidAnchorTarget2DBase(AnchorTarget2D):
  9. """
  10. input: image_meta: tuple(h, w, scale)
  11. gt_bbox, ndarry(max_num_gt, 5)
  12. output: anchor_label, ndarray(num_anchor * h * w)
  13. anchor_bbox_target, ndarray(num_anchor * h * w, 4)
  14. anchor_bbox_weight, ndarray(num_anchor * h * w, 4)
  15. """
  16. def _assign_label_to_anchor(self, valid_anchor, gt_bbox, neg_thr, pos_thr, min_pos_thr):
  17. num_anchor = valid_anchor.shape[0]
  18. cls_label = np.full(shape=(num_anchor,), fill_value=-1, dtype=np.float32)
  19. ignore_label = -2
  20. if len(gt_bbox) > 0:
  21. # num_anchor x num_gt
  22. valid_ind = np.where(gt_bbox[:, 4] != ignore_label)[0]
  23. ignore_ind = np.where(gt_bbox[:, 4] == ignore_label)[0]
  24. valid_bbox = gt_bbox[valid_ind]
  25. overlaps = bbox_overlaps_cython(valid_anchor.astype(np.float32, copy=False), valid_bbox.astype(np.float32, copy=False))
  26. if ignore_ind.shape[0] != 0:
  27. ignore_bbox = gt_bbox[ignore_ind]
  28. ignore_overlaps = bbox_selfoverlaps_cython(valid_anchor.astype(np.float32, copy=False), ignore_bbox.astype(np.float32, copy=False))
  29. ignore_max_overlaps = np.max(ignore_overlaps, axis=1)
  30. else:
  31. ignore_max_overlaps = np.zeros((num_anchor, ))
  32. max_overlaps = overlaps.max(axis=1)
  33. argmax_overlaps = overlaps.argmax(axis=1)
  34. gt_max_overlaps = overlaps.max(axis=0)
  35. # TODO: speed up this
  36. # TODO: fix potentially assigning wrong anchors as positive
  37. # A correct implementation is given as
  38. # gt_argmax_overlaps = np.where((overlaps.transpose() == gt_max_overlaps[:, None]) &
  39. # (overlaps.transpose() >= min_pos_thr))[1]
  40. gt_argmax_overlaps = np.where((overlaps == gt_max_overlaps) &
  41. (overlaps >= min_pos_thr))[0]
  42. # anchor class
  43. cls_label[(max_overlaps < neg_thr)
  44. & (ignore_max_overlaps < neg_thr)
  45. ] = 0
  46. # fg label: for each gt, anchor with highest overlap
  47. cls_label[gt_argmax_overlaps] = 1
  48. # fg label: above threshold IoU
  49. cls_label[max_overlaps >= pos_thr] = 1
  50. else:
  51. cls_label[:] = 0
  52. argmax_overlaps = np.zeros(shape=(num_anchor, ))
  53. return cls_label, argmax_overlaps
  54. def apply(self, input_record):
  55. p = self.p
  56. im_info = input_record["im_info"]
  57. gt_bbox = input_record["gt_bbox"]
  58. assert isinstance(gt_bbox, np.ndarray)
  59. assert gt_bbox.dtype == np.float32
  60. valid = np.where(gt_bbox[:, 4] != -1)[0]
  61. gt_bbox = gt_bbox[valid]
  62. valid_index, valid_anchor = self._gather_valid_anchor(im_info)
  63. cls_label, anchor_label = \
  64. self._assign_label_to_anchor(valid_anchor, gt_bbox,
  65. p.assign.neg_thr, p.assign.pos_thr, p.assign.min_pos_thr)
  66. self._sample_anchor(cls_label, p.sample.image_anchor, p.sample.pos_fraction)
  67. # need to choose valid gtbbox to align index with anchor label
  68. valid_ind = np.where(gt_bbox[:, 4] == 1)[0]
  69. gt_bbox = gt_bbox[valid_ind]
  70. reg_target, reg_weight = self._cal_anchor_target(cls_label, valid_anchor, gt_bbox, anchor_label)
  71. cls_label, reg_target, reg_weight = \
  72. self._scatter_valid_anchor(valid_index, cls_label, reg_target, reg_weight)
  73. """
  74. cls_label: (all_anchor,)
  75. reg_target: (all_anchor, 4)
  76. reg_weight: (all_anchor, 4)
  77. """
  78. input_record["rpn_cls_label"] = cls_label
  79. input_record["rpn_reg_target"] = reg_target
  80. input_record["rpn_reg_weight"] = reg_weight
  81. return input_record["rpn_cls_label"], \
  82. input_record["rpn_reg_target"], \
  83. input_record["rpn_reg_weight"]
  84. class PyramidAnchorTarget2D(PyramidAnchorTarget2DBase):
  85. """
  86. input: image_meta: tuple(h, w, scale)
  87. gt_bbox, ndarry(max_num_gt, 4)
  88. output: anchor_label, ndarray(num_anchor * h * w)
  89. anchor_bbox_target, ndarray(num_anchor * 4, h * w)
  90. anchor_bbox_weight, ndarray(num_anchor * 4, h * w)
  91. """
  92. def __init__(self, pAnchor):
  93. super().__init__(pAnchor)
  94. self.pyramid_levels = len(self.p.generate.stride)
  95. self.p_list = [copy.deepcopy(self.p) for _ in range(self.pyramid_levels)]
  96. pyramid_stride = self.p.generate.stride
  97. pyramid_short = self.p.generate.short
  98. pyramid_long = self.p.generate.long
  99. for i in range(self.pyramid_levels):
  100. self.p_list[i].generate.stride = pyramid_stride[i]
  101. self.p_list[i].generate.short = pyramid_short[i]
  102. self.p_list[i].generate.long = pyramid_long[i]
  103. # generate anchors for multi-leval feature map
  104. self.anchor_target_2d_list = [PyramidAnchorTarget2DBase(p) for p in self.p_list]
  105. self.anchor_target_2d = PyramidAnchorTarget2DBase(self.p_list[0])
  106. self.anchor_target_2d.v_all_anchor = self.v_all_anchor
  107. self.anchor_target_2d.h_all_anchor = self.h_all_anchor
  108. @property
  109. def v_all_anchor(self):
  110. anchors_list = [anchor_target_2d.v_all_anchor for anchor_target_2d in self.anchor_target_2d_list]
  111. anchors = np.concatenate(anchors_list)
  112. return anchors
  113. @property
  114. def h_all_anchor(self):
  115. anchors_list = [anchor_target_2d.h_all_anchor for anchor_target_2d in self.anchor_target_2d_list]
  116. anchors = np.concatenate(anchors_list)
  117. return anchors
  118. def apply(self, input_record):
  119. anchor_size = [0] + [x.h_all_anchor.shape[0] for x in self.anchor_target_2d_list]
  120. anchor_size = np.cumsum(anchor_size)
  121. cls_label, reg_target, reg_weight = \
  122. self.anchor_target_2d.apply(input_record)
  123. im_info = input_record["im_info"]
  124. h, w = im_info[:2]
  125. cls_label_list = []
  126. reg_target_list = []
  127. reg_weight_list = []
  128. for i in range(self.pyramid_levels):
  129. p = self.anchor_target_2d_list[i].p
  130. cls_label_level = cls_label[anchor_size[i]:anchor_size[i + 1]]
  131. reg_target_level = reg_target[anchor_size[i]:anchor_size[i + 1]]
  132. reg_weight_level = reg_weight[anchor_size[i]:anchor_size[i + 1]]
  133. """
  134. label: (h * w * A) -> (A * h * w)
  135. bbox_target: (h * w * A, 4) -> (A * 4, h * w)
  136. bbox_weight: (h * w * A, 4) -> (A * 4, h * w)
  137. """
  138. if h >= w:
  139. fh, fw = p.generate.long, p.generate.short
  140. else:
  141. fh, fw = p.generate.short, p.generate.long
  142. cls_label_level = cls_label_level.reshape((fh, fw, -1)).transpose(2, 0, 1)
  143. reg_target_level = reg_target_level.reshape((fh, fw, -1)).transpose(2, 0, 1)
  144. reg_weight_level = reg_weight_level.reshape((fh, fw, -1)).transpose(2, 0, 1)
  145. cls_label_level = cls_label_level.reshape(-1, fh * fw)
  146. reg_target_level = reg_target_level.reshape(-1, fh * fw)
  147. reg_weight_level = reg_weight_level.reshape(-1, fh * fw)
  148. cls_label_list.append(cls_label_level)
  149. reg_target_list.append(reg_target_level)
  150. reg_weight_list.append(reg_weight_level)
  151. cls_label = np.concatenate(cls_label_list, axis=1).reshape(-1)
  152. reg_target = np.concatenate(reg_target_list, axis=1)
  153. reg_weight = np.concatenate(reg_weight_list, axis=1)
  154. input_record["rpn_cls_label"] = cls_label
  155. input_record["rpn_reg_target"] = reg_target
  156. input_record["rpn_reg_weight"] = reg_weight
  157. return input_record["rpn_cls_label"], \
  158. input_record["rpn_reg_target"], \
  159. input_record["rpn_reg_weight"]