PageRenderTime 27ms CodeModel.GetById 0ms RepoModel.GetById 0ms app.codeStats 1ms

/Modules/Registration/Common/test/itkImageRegistrationMethodTest_15.cxx

https://github.com/paniwani/ITK
C++ | 408 lines | 224 code | 82 blank | 102 comment | 16 complexity | 911590dfc55144377f75fe9b5db665a7 MD5 | raw file
  1. /*=========================================================================
  2. *
  3. * Copyright Insight Software Consortium
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0.txt
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. *=========================================================================*/
  18. #include "itkImageRegistrationMethod.h"
  19. #include "itkAffineTransform.h"
  20. #include "itkMattesMutualInformationImageToImageMetric.h"
  21. #include "itkBSplineInterpolateImageFunction.h"
  22. #include "itkGradientDescentOptimizer.h"
  23. #include "itkTextOutput.h"
  24. #include "itkImageRegionIterator.h"
  25. #include "itkCommandIterationUpdate.h"
  26. #include "vnl/vnl_sample.h"
  27. namespace
  28. {
  29. double F( itk::Vector<double,3> & v );
  30. }
  31. /**
  32. * This program test one instantiation of the itk::ImageRegistrationMethod class
  33. *
  34. * This file tests the combination of:
  35. * - MattesMutualInformation
  36. * - AffineTransform
  37. * - GradientDescentOptimizer
  38. * - BSplineInterpolateImageFunction
  39. *
  40. * The test image pattern consists of a 3D gaussian in the middle
  41. * with some directional pattern on the outside.
  42. * One image is scaled and shifted relative to the other.
  43. *
  44. * Notes:
  45. * =======
  46. * This example performs an affine registration
  47. * between a moving (source) and fixed (target) image using mutual information.
  48. * It uses a simple steepest descent optimizer to find the
  49. * best affine transform to register the moving image onto the fixed
  50. * image.
  51. *
  52. * The mutual information value and its derivatives are estimated
  53. * using spatial sampling.
  54. *
  55. * The registration uses a simple stochastic gradient ascent scheme. Steps
  56. * are repeatedly taken that are proportional to the approximate
  57. * deriviative of the mutual information with respect to the affine
  58. * transform parameters. The stepsize is governed by the LearningRate
  59. * parameter.
  60. *
  61. * Since the parameters of the linear part is different in magnitude
  62. * to the parameters in the offset part, scaling is required
  63. * to improve convergence. The scaling can set via the optimizer.
  64. *
  65. * In the optimizer's scale transform set the scaling for
  66. * all the translation parameters to TranslationScale^{-2}.
  67. * Set the scale for all other parameters to 1.0.
  68. *
  69. * Note: the optimization performance can be improved by
  70. * setting the image origin to center of mass of the image.
  71. *
  72. */
  73. int itkImageRegistrationMethodTest_15(int, char* [] )
  74. {
  75. itk::OutputWindow::SetInstance(itk::TextOutput::New().GetPointer());
  76. /*==================================================*/
  77. /**
  78. * Debugging vnl_sample
  79. */
  80. std::cout << "Debugging vnl_sample" << std::endl;
  81. #if VXL_STDLIB_HAS_DRAND48
  82. std::cout << "vxl stdlib has drand48" << std::endl;
  83. #else
  84. std::cout << "vxl stdlib does not have drand48" << std::endl;
  85. #endif
  86. std::cout << std::endl;
  87. std::cout << "printout 10 numbers with default seeds" << std::endl;
  88. for( int p = 0; p < 10; p++ )
  89. {
  90. double value = vnl_sample_uniform( 0, 100 );
  91. std::cout << p << "\t" << value << std::endl;
  92. }
  93. std::cout << "printout 10 numbers with seed 171219" << std::endl;
  94. vnl_sample_reseed( 171219 );
  95. for( int p = 0; p < 10; p++ )
  96. {
  97. double value = vnl_sample_uniform( 0, 100 );
  98. std::cout << p << "\t" << value << std::endl;
  99. }
  100. /*==================================================*/
  101. bool pass = true;
  102. const unsigned int dimension = 3;
  103. unsigned int j;
  104. typedef float PixelType;
  105. // Fixed Image Type
  106. typedef itk::Image<PixelType,dimension> FixedImageType;
  107. // Moving Image Type
  108. typedef itk::Image<PixelType,dimension> MovingImageType;
  109. // Transform Type
  110. typedef itk::AffineTransform< double,dimension > TransformType;
  111. // Optimizer Type
  112. typedef itk::GradientDescentOptimizer OptimizerType;
  113. // Metric Type
  114. typedef itk::MattesMutualInformationImageToImageMetric<
  115. FixedImageType,
  116. MovingImageType > MetricType;
  117. // Interpolation technique
  118. typedef itk:: BSplineInterpolateImageFunction<
  119. MovingImageType,
  120. double > InterpolatorType;
  121. // Registration Method
  122. typedef itk::ImageRegistrationMethod<
  123. FixedImageType,
  124. MovingImageType > RegistrationType;
  125. MetricType::Pointer metric = MetricType::New();
  126. TransformType::Pointer transform = TransformType::New();
  127. OptimizerType::Pointer optimizer = OptimizerType::New();
  128. FixedImageType::Pointer fixedImage = FixedImageType::New();
  129. MovingImageType::Pointer movingImage = MovingImageType::New();
  130. InterpolatorType::Pointer interpolator = InterpolatorType::New();
  131. RegistrationType::Pointer registration = RegistrationType::New();
  132. /*********************************************************
  133. * Set up the two input images.
  134. * One image scaled and shifted with respect to the other.
  135. **********************************************************/
  136. double displacement[dimension] = {3,1,1};
  137. double scale[dimension] = { 0.90, 1.0, 1.0 };
  138. FixedImageType::SizeType size = {{100,100,40}};
  139. FixedImageType::IndexType index = {{0,0,0}};
  140. FixedImageType::RegionType region;
  141. region.SetSize( size );
  142. region.SetIndex( index );
  143. fixedImage->SetLargestPossibleRegion( region );
  144. fixedImage->SetBufferedRegion( region );
  145. fixedImage->SetRequestedRegion( region );
  146. fixedImage->Allocate();
  147. movingImage->SetLargestPossibleRegion( region );
  148. movingImage->SetBufferedRegion( region );
  149. movingImage->SetRequestedRegion( region );
  150. movingImage->Allocate();
  151. typedef itk::ImageRegionIterator<MovingImageType> MovingImageIterator;
  152. typedef itk::ImageRegionIterator<FixedImageType> FixedImageIterator;
  153. itk::Point<double,dimension> center;
  154. for ( j = 0; j < dimension; j++ )
  155. {
  156. center[j] = 0.5 * (double)region.GetSize()[j];
  157. }
  158. itk::Point<double,dimension> p;
  159. itk::Vector<double,dimension> d;
  160. MovingImageIterator mIter( movingImage, region );
  161. FixedImageIterator fIter( fixedImage, region );
  162. while( !mIter.IsAtEnd() )
  163. {
  164. for ( j = 0; j < dimension; j++ )
  165. {
  166. p[j] = mIter.GetIndex()[j];
  167. }
  168. d = p - center;
  169. fIter.Set( (PixelType) F(d) );
  170. for ( j = 0; j < dimension; j++ )
  171. {
  172. d[j] = d[j] * scale[j] + displacement[j];
  173. }
  174. mIter.Set( (PixelType) F(d) );
  175. ++fIter;
  176. ++mIter;
  177. }
  178. // set the image origin to be center of the image
  179. double transCenter[dimension];
  180. for ( j = 0; j < dimension; j++ )
  181. {
  182. transCenter[j] = -0.5 * double(size[j]);
  183. }
  184. movingImage->SetOrigin( transCenter );
  185. fixedImage->SetOrigin( transCenter );
  186. /******************************************************************
  187. * Set up the optimizer.
  188. ******************************************************************/
  189. // set the translation scale
  190. typedef OptimizerType::ScalesType ScalesType;
  191. ScalesType parametersScales( transform->GetNumberOfParameters() );
  192. parametersScales.Fill( 1.0 );
  193. for ( j = 9; j < 12; j++ )
  194. {
  195. parametersScales[j] = 0.0001;
  196. }
  197. optimizer->SetScales( parametersScales );
  198. optimizer->MaximizeOff();
  199. /******************************************************************
  200. * Set up the optimizer observer
  201. ******************************************************************/
  202. typedef itk::CommandIterationUpdate< OptimizerType > CommandIterationType;
  203. CommandIterationType::Pointer iterationCommand =
  204. CommandIterationType::New();
  205. iterationCommand->SetOptimizer( optimizer );
  206. /******************************************************************
  207. * Set up the metric.
  208. ******************************************************************/
  209. metric->SetNumberOfSpatialSamples( static_cast<unsigned long>(
  210. 0.01 * fixedImage->GetBufferedRegion().GetNumberOfPixels() ) );
  211. metric->SetNumberOfHistogramBins( 50 );
  212. for( unsigned int jj = 0; jj < dimension; jj++ )
  213. {
  214. size[jj] -= 4;
  215. index[jj] += 2;
  216. }
  217. region.SetSize( size );
  218. region.SetIndex( index );
  219. metric->SetFixedImageRegion( region );
  220. /******************************************************************
  221. * Set up the registrator.
  222. ******************************************************************/
  223. // connect up the components
  224. registration->SetMetric( metric );
  225. registration->SetOptimizer( optimizer );
  226. registration->SetTransform( transform );
  227. registration->SetFixedImage( fixedImage );
  228. registration->SetMovingImage( movingImage );
  229. registration->SetInterpolator( interpolator );
  230. // set initial parameters to identity
  231. RegistrationType::ParametersType initialParameters(
  232. transform->GetNumberOfParameters() );
  233. initialParameters.Fill( 0.0 );
  234. initialParameters[0] = 1.0;
  235. initialParameters[4] = 1.0;
  236. initialParameters[8] = 1.0;
  237. /***********************************************************
  238. * Run the registration
  239. ************************************************************/
  240. const unsigned int numberOfLoops = 2;
  241. unsigned int iter[numberOfLoops] = { 50, 0 };
  242. double rates[numberOfLoops] = { 1e-3, 5e-4 };
  243. for ( j = 0; j < numberOfLoops; j++ )
  244. {
  245. try
  246. {
  247. optimizer->SetNumberOfIterations( iter[j] );
  248. optimizer->SetLearningRate( rates[j] );
  249. registration->SetInitialTransformParameters( initialParameters );
  250. registration->Update();
  251. initialParameters = registration->GetLastTransformParameters();
  252. }
  253. catch( itk::ExceptionObject & e )
  254. {
  255. std::cout << "Registration failed" << std::endl;
  256. std::cout << "Reason " << e.GetDescription() << std::endl;
  257. return EXIT_FAILURE;
  258. }
  259. }
  260. /***********************************************************
  261. * Check the results
  262. ************************************************************/
  263. RegistrationType::ParametersType solution =
  264. registration->GetLastTransformParameters();
  265. std::cout << "Solution is: " << solution << std::endl;
  266. RegistrationType::ParametersType trueParameters(
  267. transform->GetNumberOfParameters() );
  268. trueParameters.Fill( 0.0 );
  269. trueParameters[ 0] = 1/scale[0];
  270. trueParameters[ 4] = 1/scale[1];
  271. trueParameters[ 8] = 1/scale[2];
  272. trueParameters[ 9] = - displacement[0]/scale[0];
  273. trueParameters[10] = - displacement[1]/scale[1];
  274. trueParameters[11] = - displacement[2]/scale[2];
  275. std::cout << "True solution is: " << trueParameters << std::endl;
  276. for( j = 0; j < 9; j++ )
  277. {
  278. if( vnl_math_abs( solution[j] - trueParameters[j] ) > 0.025 )
  279. {
  280. pass = false;
  281. }
  282. }
  283. for( j = 9; j < 12; j++ )
  284. {
  285. if( vnl_math_abs( solution[j] - trueParameters[j] ) > 1.0 )
  286. {
  287. pass = false;
  288. }
  289. }
  290. if( !pass )
  291. {
  292. std::cout << "Test failed." << std::endl;
  293. return EXIT_FAILURE;
  294. }
  295. std::cout << "Test passed." << std::endl;
  296. return EXIT_SUCCESS;
  297. }
  298. namespace
  299. {
  300. /**
  301. * This function defines the test image pattern.
  302. * The pattern is a 3D gaussian in the middle
  303. * and some directional pattern on the outside.
  304. */
  305. double F( itk::Vector<double,3> & v )
  306. {
  307. double x = v[0];
  308. double y = v[1];
  309. double z = v[2];
  310. const double s = 50;
  311. double value = 200.0 * vcl_exp( - ( x*x + y*y + z*z )/(s*s) );
  312. x -= 8; y += 3; z += 0;
  313. double r = vcl_sqrt( x*x + y*y + z*z );
  314. if( r > 35 )
  315. {
  316. value = 2 * ( vnl_math_abs( x ) +
  317. 0.8 * vnl_math_abs( y ) +
  318. 0.5 * vnl_math_abs( z ) );
  319. }
  320. if( r < 4 )
  321. {
  322. value = 400;
  323. }
  324. return value;
  325. }
  326. }