/probreg/l2dist_regs.py

https://github.com/neka-nat/probreg · Python · 195 lines · 157 code · 20 blank · 18 comment · 7 complexity · f0b634b59abdc89de319293ec21c61a6 MD5 · raw file

  1. from __future__ import print_function
  2. from __future__ import division
  3. from collections import namedtuple
  4. import logging
  5. import numpy as np
  6. from scipy.optimize import minimize
  7. import open3d as o3
  8. from . import features as ft
  9. from . import cost_functions as cf
  10. from .log import log
  11. class L2DistRegistration(object):
  12. """L2 distance registration class
  13. This algorithm expresses point clouds as mixture gaussian distributions and
  14. performs registration by minimizing the distance between two distributions.
  15. Args:
  16. source (numpy.ndarray): Source point cloud data.
  17. feature_gen (probreg.features.Feature): Generator of mixture gaussian distribution.
  18. cost_fn (probreg.cost_functions.CostFunction): Cost function to caliculate L2 distance.
  19. sigma (float, optional): Scaling parameter for L2 distance.
  20. delta (float, optional): Annealing parameter for optimization.
  21. use_estimated_sigma (float, optional): If this flag is True,
  22. sigma estimates from the source point cloud.
  23. """
  24. def __init__(self, source, feature_gen, cost_fn,
  25. sigma=1.0, delta=0.9,
  26. use_estimated_sigma=True):
  27. self._source = source
  28. self._feature_gen = feature_gen
  29. self._cost_fn = cost_fn
  30. self._sigma = sigma
  31. self._delta = delta
  32. self._use_estimated_sigma = use_estimated_sigma
  33. self._callbacks = []
  34. if not self._source is None and self._use_estimated_sigma:
  35. self._estimate_sigma(self._source)
  36. def set_source(self, source):
  37. self._source = source
  38. if self._use_estimated_sigma:
  39. self._estimate_sigma(self._source)
  40. def set_callbacks(self, callbacks):
  41. self._callbacks.extend(callbacks)
  42. def _estimate_sigma(self, data):
  43. ndata, dim = data.shape
  44. data_hat = data - np.mean(data, axis=0)
  45. self._sigma = np.power(np.linalg.det(np.dot(data_hat.T, data_hat) / (ndata - 1)), 1.0 / (2.0 * dim))
  46. def _annealing(self):
  47. self._sigma *= self._delta
  48. def optimization_cb(self, x):
  49. tf_result = self._cost_fn.to_transformation(x)
  50. for c in self._callbacks:
  51. c(tf_result)
  52. def registration(self, target, maxiter=1, tol=1.0e-3,
  53. opt_maxiter=50, opt_tol=1.0e-3):
  54. f = None
  55. x_ini = self._cost_fn.initial()
  56. for _ in range(maxiter):
  57. self._feature_gen.init()
  58. mu_source, phi_source = self._feature_gen.compute(self._source)
  59. mu_target, phi_target = self._feature_gen.compute(target)
  60. args = (mu_source, phi_source,
  61. mu_target, phi_target, self._sigma)
  62. res = minimize(self._cost_fn,
  63. x_ini,
  64. args=args,
  65. method='BFGS', jac=True,
  66. tol=opt_tol,
  67. options={'maxiter': opt_maxiter,
  68. 'disp': log.level == logging.DEBUG},
  69. callback=self.optimization_cb)
  70. self._annealing()
  71. self._feature_gen.annealing()
  72. if not f is None and abs(res.fun - f) < tol:
  73. break
  74. f = res.fun
  75. x_ini = res.x
  76. return self._cost_fn.to_transformation(res.x)
  77. class RigidGMMReg(L2DistRegistration):
  78. def __init__(self, source, sigma=1.0, delta=0.9,
  79. n_gmm_components=800, use_estimated_sigma=True):
  80. n_gmm_components = min(n_gmm_components, int(source.shape[0] * 0.8))
  81. super(RigidGMMReg, self).__init__(source, ft.GMM(n_gmm_components),
  82. cf.RigidCostFunction(),
  83. sigma, delta,
  84. use_estimated_sigma)
  85. class TPSGMMReg(L2DistRegistration):
  86. def __init__(self, source, sigma=1.0, delta=0.9,
  87. n_gmm_components=800, alpha=1.0, beta=0.1,
  88. use_estimated_sigma=True):
  89. n_gmm_components = min(n_gmm_components, int(source.shape[0] * 0.8))
  90. super(TPSGMMReg, self).__init__(source, ft.GMM(n_gmm_components),
  91. cf.TPSCostFunction([], alpha, beta),
  92. sigma, delta,
  93. use_estimated_sigma)
  94. self._feature_gen.init()
  95. control_pts, _ = self._feature_gen.compute(source)
  96. self._cost_fn._control_pts = control_pts
  97. class RigidSVR(L2DistRegistration):
  98. def __init__(self, source, sigma=1.0, delta=0.9,
  99. gamma=0.5, nu=0.1, use_estimated_sigma=True):
  100. super(RigidSVR, self).__init__(source,
  101. ft.OneClassSVM(source.shape[1],
  102. sigma, gamma, nu),
  103. cf.RigidCostFunction(),
  104. sigma, delta,
  105. use_estimated_sigma)
  106. def _estimate_sigma(self, data):
  107. super(RigidSVR, self)._estimate_sigma(data)
  108. self._feature_gen._sigma = self._sigma
  109. self._feature_gen._gamma = 1.0 / (2.0 * np.square(self._sigma))
  110. class TPSSVR(L2DistRegistration):
  111. def __init__(self, source, sigma=1.0, delta=0.9,
  112. gamma=0.5, nu=0.1, alpha=1.0, beta=0.1,
  113. use_estimated_sigma=True):
  114. super(TPSSVR, self).__init__(source,
  115. ft.OneClassSVM(source.shape[1],
  116. sigma, gamma, nu),
  117. cf.TPSCostFunction([], alpha, beta),
  118. sigma, delta,
  119. use_estimated_sigma)
  120. self._feature_gen.init()
  121. control_pts, _ = self._feature_gen.compute(source)
  122. self._cost_fn._control_pts = control_pts
  123. def _estimate_sigma(self, data):
  124. super(TPSSVR, self)._estimate_sigma(data)
  125. self._feature_gen._sigma = self._sigma
  126. self._feature_gen._gamma = 1.0 / (2.0 * np.square(self._sigma))
  127. def registration_gmmreg(source, target, tf_type_name='rigid',
  128. callbacks=[], **kargs):
  129. """GMMReg.
  130. Args:
  131. source (numpy.ndarray): Source point cloud data.
  132. target (numpy.ndarray): Target point cloud data.
  133. tf_type_name (str, optional): Transformation type('rigid', 'nonrigid')
  134. callback (:obj:`list` of :obj:`function`, optional): Called after each iteration.
  135. `callback(probreg.Transformation)`
  136. """
  137. cv = lambda x: np.asarray(x.points if isinstance(x, o3.geometry.PointCloud) else x)
  138. if tf_type_name == 'rigid':
  139. gmmreg = RigidGMMReg(cv(source), **kargs)
  140. elif tf_type_name == 'nonrigid':
  141. gmmreg = TPSGMMReg(cv(source), **kargs)
  142. else:
  143. raise ValueError('Unknown transform type %s' % tf_type_name)
  144. gmmreg.set_callbacks(callbacks)
  145. return gmmreg.registration(cv(target))
  146. def registration_svr(source, target, tf_type_name='rigid',
  147. maxiter=1, tol=1.0e-3,
  148. opt_maxiter=50, opt_tol=1.0e-3,
  149. callbacks=[], **kargs):
  150. """Support Vector Registration.
  151. Args:
  152. source (numpy.ndarray): Source point cloud data.
  153. target (numpy.ndarray): Target point cloud data.
  154. tf_type_name (str, optional): Transformation type('rigid', 'nonrigid')
  155. maxitr (int, optional): Maximum number of iterations for outer loop.
  156. tol (float, optional): Tolerance for termination of outer loop.
  157. opt_maxitr (int, optional): Maximum number of iterations for inner loop.
  158. opt_tol (float, optional): Tolerance for termination of inner loop.
  159. callback (:obj:`list` of :obj:`function`, optional): Called after each iteration.
  160. `callback(probreg.Transformation)`
  161. """
  162. cv = lambda x: np.asarray(x.points if isinstance(x, o3.geometry.PointCloud) else x)
  163. if tf_type_name == 'rigid':
  164. svr = RigidSVR(cv(source), **kargs)
  165. elif tf_type_name == 'nonrigid':
  166. svr = TPSSVR(cv(source), **kargs)
  167. else:
  168. raise ValueError('Unknown transform type %s' % tf_type_name)
  169. svr.set_callbacks(callbacks)
  170. return svr.registration(cv(target), maxiter, tol, opt_maxiter, opt_tol)