PageRenderTime 353ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/Modules/Registration/PDEDeformable/test/itkDiffeomorphicDemonsRegistrationFilterTest.cxx

https://github.com/millerjv/ITK
C++ | 427 lines | 314 code | 80 blank | 33 comment | 15 complexity | a163f22d6ee75c284c3b5ec43f0fc058 MD5 | raw file
Possible License(s): Apache-2.0, BSD-3-Clause, BSD-2-Clause
  1. /*=========================================================================
  2. *
  3. * Copyright Insight Software Consortium
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0.txt
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. *=========================================================================*/
  18. #include "itkDiffeomorphicDemonsRegistrationFilter.h"
  19. #include "itkNearestNeighborInterpolateImageFunction.h"
  20. #include "itkVectorCastImageFilter.h"
  21. #include "itkImageFileWriter.h"
  // The following class is used to support callbacks
  // on the filter in the pipeline that follows later.
  //
  // It prints the progress of a registration filter each time its
  // ShowProgress() member is invoked (in this test, from an
  // itk::SimpleMemberCommand observing the filter's ProgressEvent).
  //
  // TRegistration must provide a nested `Pointer` type constructible from
  // `TRegistration *` and the accessors GetProgress(), GetElapsedIterations(),
  // GetMetric() and GetRMSChange().
  template<typename TRegistration>
  class DiffeomorphicDemonsShowProgressObject
  {
  public:
    // Stores `o` in m_Process. explicit: a raw filter pointer should not
    // silently convert into a progress observer.
    explicit DiffeomorphicDemonsShowProgressObject(TRegistration* o)
      : m_Process(o)
    {}

    // Writes one line: progress, elapsed iterations, metric and RMS change.
    // Deliberately non-const: itk::SimpleMemberCommand::SetCallbackFunction
    // takes a pointer to a non-const member function `void (T::*)()`.
    void ShowProgress()
    {
      std::cout << "Progress: " << m_Process->GetProgress() << " ";
      std::cout << "Iter: " << m_Process->GetElapsedIterations() << " ";
      std::cout << "Metric: " << m_Process->GetMetric() << " ";
      std::cout << "RMSChange: " << m_Process->GetRMSChange() << " ";
      std::cout << std::endl;
    }

    // The observed registration filter.
    typename TRegistration::Pointer m_Process;
  };
  40. // Template function to fill in an image with a circle.
  41. template <typename TImage>
  42. void
  43. FillWithCircle(
  44. TImage * image,
  45. double * center,
  46. double radius,
  47. typename TImage::PixelType foregnd,
  48. typename TImage::PixelType backgnd )
  49. {
  50. using Iterator = itk::ImageRegionIteratorWithIndex<TImage>;
  51. Iterator it( image, image->GetBufferedRegion() );
  52. it.GoToBegin();
  53. typename TImage::IndexType index;
  54. double r2 = itk::Math::sqr( radius );
  55. for(; !it.IsAtEnd(); ++it)
  56. {
  57. index = it.GetIndex();
  58. double distance = 0;
  59. for( unsigned int j = 0; j < TImage::ImageDimension; j++ )
  60. {
  61. distance += itk::Math::sqr((double) index[j] - center[j]);
  62. }
  63. if( distance <= r2 ) it.Set( foregnd );
  64. else it.Set( backgnd );
  65. }
  66. }
  67. // Template function to copy image regions
  68. template <typename TImage>
  69. void
  70. CopyImageBuffer(
  71. TImage *input,
  72. TImage *output )
  73. {
  74. using Iterator = itk::ImageRegionIteratorWithIndex<TImage>;
  75. Iterator inIt( input, output->GetBufferedRegion() );
  76. Iterator outIt( output, output->GetBufferedRegion() );
  77. for(; !inIt.IsAtEnd(); ++inIt, ++outIt)
  78. {
  79. outIt.Set( inIt.Get() );
  80. }
  81. }
  82. int itkDiffeomorphicDemonsRegistrationFilterTest(int argc, char * argv [] )
  83. {
  84. if( argc < 9 )
  85. {
  86. std::cerr << "Missing arguments" << std::endl;
  87. std::cerr << "Usage:" << std::endl;
  88. std::cerr << argv[0] << std::endl;
  89. std::cerr << "GradientType [0=Symmetric,1=Fixed,2=WarpedMoving,3=MappedMoving]" << std::endl;
  90. std::cerr << "UseFirstOrderExp [0=No,1=Yes]" << std::endl;
  91. std::cerr << "Intensity Difference Threshold (double)" << std::endl;
  92. std::cerr << "Maximum Update step length (double)" << std::endl;
  93. std::cerr << "Maximum number of iterations (int)" << std::endl;
  94. std::cerr << "Standard deviations (double)" << std::endl;
  95. std::cerr << "Maximum error (double)" << std::endl;
  96. std::cerr << "Maximum kernel width (int)" << std::endl;
  97. return EXIT_FAILURE;
  98. }
  99. using PixelType = unsigned char;
  100. enum {ImageDimension = 2};
  101. using ImageType = itk::Image<PixelType,ImageDimension>;
  102. using VectorType = itk::Vector<float,ImageDimension>;
  103. using FieldType = itk::Image<VectorType,ImageDimension>;
  104. using IndexType = ImageType::IndexType;
  105. using SizeType = ImageType::SizeType;
  106. using RegionType = ImageType::RegionType;
  107. using DirectionType = ImageType::DirectionType;
  108. //--------------------------------------------------------
  109. std::cout << "Generate input images and initial deformation field";
  110. std::cout << std::endl;
  111. ImageType::SizeValueType sizeArray[ImageDimension] = { 128, 128 };
  112. SizeType size;
  113. size.SetSize( sizeArray );
  114. IndexType index;
  115. index.Fill( 0 );
  116. RegionType region;
  117. region.SetSize( size );
  118. region.SetIndex( index );
  119. DirectionType direction;
  120. direction.SetIdentity();
  121. direction(1,1)=-1;
  122. ImageType::Pointer moving = ImageType::New();
  123. ImageType::Pointer fixed = ImageType::New();
  124. FieldType::Pointer initField = FieldType::New();
  125. moving->SetLargestPossibleRegion( region );
  126. moving->SetBufferedRegion( region );
  127. moving->Allocate();
  128. moving->SetDirection(direction);
  129. fixed->SetLargestPossibleRegion( region );
  130. fixed->SetBufferedRegion( region );
  131. fixed->Allocate();
  132. fixed->SetDirection(direction);
  133. initField->SetLargestPossibleRegion( region );
  134. initField->SetBufferedRegion( region );
  135. initField->Allocate();
  136. initField->SetDirection(direction);
  137. double center[ImageDimension];
  138. double radius;
  139. PixelType fgnd = 250;
  140. PixelType bgnd = 15;
  141. // fill moving with circle
  142. center[0] = 64; center[1] = 64; radius = 30;
  143. FillWithCircle<ImageType>( moving, center, radius, fgnd, bgnd );
  144. // fill fixed with circle
  145. center[0] = 62; center[1] = 64; radius = 30;
  146. FillWithCircle<ImageType>( fixed, center, radius, fgnd, bgnd );
  147. // fill initial deformation with zero vectors
  148. VectorType zeroVec;
  149. zeroVec.Fill( 0.0 );
  150. initField->FillBuffer( zeroVec );
  151. using CasterType = itk::VectorCastImageFilter<FieldType,FieldType>;
  152. CasterType::Pointer caster = CasterType::New();
  153. caster->SetInput( initField );
  154. caster->InPlaceOff();
  155. //-------------------------------------------------------------
  156. std::cout << "Run registration and warp moving" << std::endl;
  157. using RegistrationType = itk::DiffeomorphicDemonsRegistrationFilter<
  158. ImageType,ImageType,FieldType>;
  159. RegistrationType::Pointer registrator = RegistrationType::New();
  160. registrator->SetInitialDisplacementField( caster->GetOutput() );
  161. registrator->SetMovingImage( moving );
  162. registrator->SetFixedImage( fixed );
  163. const double intensityDifferenceThreshold = std::stod( argv[3] );
  164. const double maximumUpdateStepLength = std::stod( argv[4] );
  165. const unsigned int numberOfIterations = std::stoi( argv[5] );
  166. const double standardDeviations = std::stod( argv[6] );
  167. const double maximumError = std::stod( argv[7] );
  168. const unsigned int maximumKernelWidth = std::stoi( argv[8] );
  169. registrator->SetIntensityDifferenceThreshold( intensityDifferenceThreshold );
  170. registrator->SetMaximumUpdateStepLength( maximumUpdateStepLength );
  171. registrator->SetNumberOfIterations( numberOfIterations );
  172. registrator->SetStandardDeviations( standardDeviations );
  173. registrator->SetMaximumError( maximumError );
  174. registrator->SetMaximumKernelWidth( maximumKernelWidth );
  175. const int gradientType = std::stoi( argv[1] );
  176. using FunctionType = RegistrationType::DemonsRegistrationFunctionType;
  177. switch( gradientType )
  178. {
  179. case 0:
  180. registrator->SetUseGradientType( FunctionType::Symmetric );
  181. break;
  182. case 1:
  183. registrator->SetUseGradientType( FunctionType::Fixed );
  184. break;
  185. case 2:
  186. registrator->SetUseGradientType( FunctionType::WarpedMoving );
  187. break;
  188. case 3:
  189. registrator->SetUseGradientType( FunctionType::MappedMoving );
  190. break;
  191. }
  192. std::cout << "GradientType = " << registrator->GetUseGradientType() << std::endl;
  193. const int useFirstOrderExponential = std::stoi( argv[2] );
  194. if( useFirstOrderExponential == 0 )
  195. {
  196. registrator->SetUseFirstOrderExp( false );
  197. }
  198. else
  199. {
  200. registrator->SetUseFirstOrderExp( true );
  201. }
  202. // turn on inplace execution
  203. registrator->InPlaceOn();
  204. FunctionType * fptr;
  205. fptr = dynamic_cast<FunctionType *>(
  206. registrator->GetDifferenceFunction().GetPointer() );
  207. fptr->Print( std::cout );
  208. // exercise other member variables
  209. std::cout << "No. Iterations: " << registrator->GetNumberOfIterations() << std::endl;
  210. std::cout << "Max. kernel error: " << registrator->GetMaximumError() << std::endl;
  211. std::cout << "Max. kernel width: " << registrator->GetMaximumKernelWidth() << std::endl;
  212. double v[ImageDimension];
  213. for ( unsigned int j = 0; j < ImageDimension; j++ )
  214. {
  215. v[j] = registrator->GetStandardDeviations()[j];
  216. }
  217. registrator->SetStandardDeviations( v );
  218. using ProgressType = DiffeomorphicDemonsShowProgressObject<RegistrationType>;
  219. ProgressType progressWatch(registrator);
  220. itk::SimpleMemberCommand<ProgressType>::Pointer command;
  221. command = itk::SimpleMemberCommand<ProgressType>::New();
  222. command->SetCallbackFunction(&progressWatch,
  223. &ProgressType::ShowProgress);
  224. registrator->AddObserver( itk::ProgressEvent(), command);
  225. // warp moving image
  226. using WarperType = itk::WarpImageFilter<ImageType,ImageType,FieldType>;
  227. WarperType::Pointer warper = WarperType::New();
  228. using CoordRepType = WarperType::CoordRepType;
  229. using InterpolatorType =
  230. itk::NearestNeighborInterpolateImageFunction<ImageType,CoordRepType>;
  231. InterpolatorType::Pointer interpolator = InterpolatorType::New();
  232. warper->SetInput( moving );
  233. warper->SetDisplacementField( registrator->GetOutput() );
  234. warper->SetInterpolator( interpolator );
  235. warper->SetOutputSpacing( fixed->GetSpacing() );
  236. warper->SetOutputOrigin( fixed->GetOrigin() );
  237. warper->SetOutputDirection( fixed->GetDirection() );
  238. warper->SetEdgePaddingValue( bgnd );
  239. warper->Print( std::cout );
  240. warper->Update();
  241. // ---------------------------------------------------------
  242. std::cout << "Compare warped moving and fixed." << std::endl;
  243. // compare the warp and fixed images
  244. itk::ImageRegionIterator<ImageType> fixedIter( fixed,
  245. fixed->GetBufferedRegion() );
  246. itk::ImageRegionIterator<ImageType> warpedIter( warper->GetOutput(),
  247. fixed->GetBufferedRegion() );
  248. unsigned int numPixelsDifferent = 0;
  249. while( !fixedIter.IsAtEnd() )
  250. {
  251. if( fixedIter.Get() != warpedIter.Get() )
  252. {
  253. numPixelsDifferent++;
  254. }
  255. ++fixedIter;
  256. ++warpedIter;
  257. }
  258. using WriterType = itk::ImageFileWriter< ImageType >;
  259. WriterType::Pointer writer1 = WriterType::New();
  260. WriterType::Pointer writer2 = WriterType::New();
  261. WriterType::Pointer writer3 = WriterType::New();
  262. writer1->SetFileName("fixedImage.mha");
  263. writer2->SetFileName("movingImage.mha");
  264. writer3->SetFileName("registeredImage.mha");
  265. writer1->SetInput( fixed );
  266. writer2->SetInput( moving );
  267. writer3->SetInput( warper->GetOutput() );
  268. writer1->Update();
  269. writer2->Update();
  270. writer3->Update();
  271. std::cout << "Number of pixels different: " << numPixelsDifferent;
  272. std::cout << std::endl;
  273. const unsigned int maximumNumberOfDifferentPixels = std::stoi( argv[9] );
  274. if( numPixelsDifferent > maximumNumberOfDifferentPixels )
  275. {
  276. std::cout << "Test failed - too many pixels different." << std::endl;
  277. return EXIT_FAILURE;
  278. }
  279. registrator->Print( std::cout );
  280. // -----------------------------------------------------------
  281. std::cout << "Test running registrator without initial deformation field.";
  282. std::cout << std::endl;
  283. bool passed = true;
  284. try
  285. {
  286. registrator->SetInput( nullptr );
  287. registrator->SetNumberOfIterations( 2 );
  288. registrator->Update();
  289. }
  290. catch( itk::ExceptionObject& err )
  291. {
  292. std::cout << "Unexpected error." << std::endl;
  293. std::cout << err << std::endl;
  294. passed = false;
  295. }
  296. if ( !passed )
  297. {
  298. std::cout << "Test failed" << std::endl;
  299. return EXIT_FAILURE;
  300. }
  301. //--------------------------------------------------------------
  302. std::cout << "Test exception handling." << std::endl;
  303. std::cout << "Test nullptr moving image. " << std::endl;
  304. passed = false;
  305. try
  306. {
  307. registrator->SetInput( caster->GetOutput() );
  308. registrator->SetMovingImage( nullptr );
  309. registrator->Update();
  310. }
  311. catch( itk::ExceptionObject & err )
  312. {
  313. std::cout << "Caught expected error." << std::endl;
  314. std::cout << err << std::endl;
  315. passed = true;
  316. }
  317. if ( !passed )
  318. {
  319. std::cout << "Test failed" << std::endl;
  320. return EXIT_FAILURE;
  321. }
  322. registrator->SetMovingImage( moving );
  323. registrator->ResetPipeline();
  324. std::cout << "Test nullptr moving image interpolator. " << std::endl;
  325. passed = false;
  326. try
  327. {
  328. fptr = dynamic_cast<FunctionType *>(
  329. registrator->GetDifferenceFunction().GetPointer() );
  330. fptr->SetMovingImageInterpolator( nullptr );
  331. registrator->SetInput( initField );
  332. registrator->Update();
  333. }
  334. catch( itk::ExceptionObject & err )
  335. {
  336. std::cout << "Caught expected error." << std::endl;
  337. std::cout << err << std::endl;
  338. passed = true;
  339. }
  340. if ( !passed )
  341. {
  342. std::cout << "Test failed" << std::endl;
  343. return EXIT_FAILURE;
  344. }
  345. std::cout << "Test passed" << std::endl;
  346. return EXIT_SUCCESS;
  347. }