PageRenderTime 94ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 0ms

/Framework/Algorithms/src/CrossCorrelate.cpp

https://github.com/mantidproject/mantid
C++ | 266 lines | 192 code | 34 blank | 40 comment | 20 complexity | 6982abefe82f4f5f074dc6bb47078dfd MD5 | raw file
  1. // Mantid Repository : https://github.com/mantidproject/mantid
  2. //
  3. // Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
  4. // NScD Oak Ridge National Laboratory, European Spallation Source,
  5. // Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
  6. // SPDX - License - Identifier: GPL - 3.0 +
  7. //----------------------------------------------------------------------
  8. // Includes
  9. //----------------------------------------------------------------------
  10. #include "MantidAlgorithms/CrossCorrelate.h"
  11. #include "MantidAPI/HistogramValidator.h"
  12. #include "MantidAPI/NumericAxis.h"
  13. #include "MantidAPI/RawCountValidator.h"
  14. #include "MantidAPI/SpectraAxis.h"
  15. #include "MantidAPI/TextAxis.h"
  16. #include "MantidAPI/WorkspaceUnitValidator.h"
  17. #include "MantidDataObjects/Workspace2D.h"
  18. #include "MantidDataObjects/WorkspaceCreation.h"
  19. #include "MantidHistogramData/Histogram.h"
  20. #include "MantidKernel/BoundedValidator.h"
  21. #include "MantidKernel/CompositeValidator.h"
  22. #include "MantidKernel/UnitFactory.h"
  23. #include "MantidKernel/VectorHelper.h"
  24. #include <boost/iterator/counting_iterator.hpp>
  25. #include <numeric>
  26. #include <sstream>
  27. namespace {
  28. struct Variances {
  29. double y;
  30. double e;
  31. };
  32. Variances subtractMean(std::vector<double> &signal, std::vector<double> &error) {
  33. double mean = std::accumulate(signal.cbegin(), signal.cend(), 0.0);
  34. double errorMeanSquared =
  35. std::accumulate(error.cbegin(), error.cend(), 0.0, Mantid::Kernel::VectorHelper::SumSquares<double>());
  36. const auto n = signal.size();
  37. mean /= static_cast<double>(n);
  38. errorMeanSquared /= static_cast<double>(n * n);
  39. double variance = 0.0, errorVariance = 0.0;
  40. auto itY = signal.begin();
  41. auto itE = error.begin();
  42. for (; itY != signal.end(); ++itY, ++itE) {
  43. (*itY) -= mean; // Now the vector is (y[i]-refMean)
  44. (*itE) = (*itE) * (*itE) + errorMeanSquared; // New error squared
  45. const double t = (*itY) * (*itY); //(y[i]-refMean)^2
  46. variance += t; // Sum previous term
  47. errorVariance += 4.0 * t * (*itE); // Error squared
  48. }
  49. return {variance, errorVariance};
  50. }
  51. } // namespace
  52. namespace Mantid::Algorithms {
  53. // Register the class into the algorithm factory
  54. DECLARE_ALGORITHM(CrossCorrelate)
  55. using namespace Kernel;
  56. using namespace API;
  57. using namespace DataObjects;
  58. using namespace HistogramData;
  59. /// Initialisation method.
  60. void CrossCorrelate::init() {
  61. auto wsValidator = std::make_shared<CompositeValidator>();
  62. wsValidator->add<API::WorkspaceUnitValidator>("dSpacing");
  63. wsValidator->add<API::HistogramValidator>();
  64. wsValidator->add<API::RawCountValidator>();
  65. // Input and output workspaces
  66. declareProperty(
  67. std::make_unique<WorkspaceProperty<MatrixWorkspace>>("InputWorkspace", "", Direction::Input, wsValidator),
  68. "A 2D workspace with X values of d-spacing");
  69. declareProperty(std::make_unique<WorkspaceProperty<MatrixWorkspace>>("OutputWorkspace", "", Direction::Output),
  70. "The name of the output workspace");
  71. auto mustBePositive = std::make_shared<BoundedValidator<int>>();
  72. mustBePositive->setLower(0);
  73. // Reference spectra against which cross correlation is performed
  74. declareProperty("ReferenceSpectra", 0, mustBePositive,
  75. "The Workspace Index of the spectra to correlate all other "
  76. "spectra against. ");
  77. // Spectra in the range [min to max] will be cross correlated to referenceSpectra.
  78. declareProperty("WorkspaceIndexMin", 0, mustBePositive,
  79. "The workspace index of the first member of the range of "
  80. "spectra to cross-correlate against.");
  81. declareProperty("WorkspaceIndexMax", 0, mustBePositive,
  82. " The workspace index of the last member of the range of "
  83. "spectra to cross-correlate against.");
  84. // Only the data in the range X_min, X_max will be used
  85. declareProperty("XMin", 0.0, "The starting point of the region to be cross correlated.");
  86. declareProperty("XMax", 0.0, "The ending point of the region to be cross correlated.");
  87. // max is .1
  88. declareProperty("MaxDSpaceShift", EMPTY_DBL(), "Optional float for maximum shift to calculate (in d-spacing)");
  89. }
  90. /** Executes the algorithm
  91. *
  92. * @throw runtime_error Thrown if algorithm cannot execute
  93. */
  94. void CrossCorrelate::exec() {
  95. MatrixWorkspace_const_sptr inputWS = getProperty("InputWorkspace");
  96. double maxDSpaceShift = getProperty("MaxDSpaceShift");
  97. int referenceSpectra = getProperty("ReferenceSpectra");
  98. double xmin = getProperty("XMin");
  99. double xmax = getProperty("XMax");
  100. const int wsIndexMin = getProperty("WorkspaceIndexMin");
  101. const int wsIndexMax = getProperty("WorkspaceIndexMax");
  102. const auto index_ref = static_cast<size_t>(referenceSpectra);
  103. if (wsIndexMin >= wsIndexMax)
  104. throw std::runtime_error("Must specify WorkspaceIndexMin<WorkspaceIndexMax");
  105. // Get the number of spectra in range wsIndexMin to wsIndexMax
  106. int numSpectra = 1 + wsIndexMax - wsIndexMin;
  107. // Indexes of all spectra in range
  108. std::vector<size_t> indexes(boost::make_counting_iterator(wsIndexMin), boost::make_counting_iterator(wsIndexMax + 1));
  109. if (numSpectra == 0) {
  110. std::ostringstream message;
  111. message << "No spectra in range between" << wsIndexMin << " and " << wsIndexMax;
  112. throw std::runtime_error(message.str());
  113. }
  114. // Output messageage information
  115. g_log.information() << "There are " << numSpectra << " spectra in the range\n";
  116. // checdataIndex that the data range specified madataIndexes sense
  117. if (xmin >= xmax)
  118. throw std::runtime_error("Must specify xmin < xmax, " + std::to_string(xmin) + " vs " + std::to_string(xmax));
  119. // TadataIndexe a copy of the referenceSpectra spectrum
  120. auto &referenceSpectraE = inputWS->e(index_ref);
  121. auto &referenceSpectraX = inputWS->x(index_ref);
  122. auto &referenceSpectraY = inputWS->y(index_ref);
  123. // Now checdataIndex if the range between x_min and x_max is valid
  124. using std::placeholders::_1;
  125. auto rangeStart =
  126. std::find_if(referenceSpectraX.cbegin(), referenceSpectraX.cend(), std::bind(std::greater<double>(), _1, xmin));
  127. if (rangeStart == referenceSpectraX.cend())
  128. throw std::runtime_error("No data above XMin");
  129. auto rangeEnd = std::find_if(rangeStart, referenceSpectraX.cend(), std::bind(std::greater<double>(), _1, xmax));
  130. if (rangeStart == rangeEnd)
  131. throw std::runtime_error("Range is not valid");
  132. MantidVec::difference_type rangeStartCorrection = std::distance(referenceSpectraX.cbegin(), rangeStart);
  133. MantidVec::difference_type rangeEndCorrection = std::distance(referenceSpectraX.cbegin(), rangeEnd);
  134. const std::vector<double> referenceXVector(rangeStart, rangeEnd);
  135. std::vector<double> referenceYVector(referenceSpectraY.cbegin() + rangeStartCorrection,
  136. referenceSpectraY.cbegin() + (rangeEndCorrection - 1));
  137. std::vector<double> referenceEVector(referenceSpectraE.cbegin() + rangeStartCorrection,
  138. referenceSpectraE.cbegin() + (rangeEndCorrection - 1));
  139. g_log.information() << "min max " << referenceXVector.front() << " " << referenceXVector.back() << '\n';
  140. // Now start the real stuff
  141. // Create a 2DWorkspace that will hold the result
  142. auto numReferenceY = static_cast<int>(referenceYVector.size());
  143. // max the shift
  144. int shiftCorrection = 0;
  145. if (maxDSpaceShift != EMPTY_DBL()) {
  146. if (xmax - xmin < maxDSpaceShift)
  147. g_log.warning() << "maxDSpaceShift(" << std::to_string(maxDSpaceShift)
  148. << ") is larger than specified range of xmin(" << xmin << ") to xmax(" << xmax
  149. << "), please make it smaller or removed it entirely!"
  150. << "\n";
  151. // convert dspacing to bins, where maxDSpaceShift is at least 0.1
  152. const auto maxBins = std::max(0.0 + maxDSpaceShift * 2, 0.1) / inputWS->getDimension(0)->getBinWidth();
  153. // calc range based on max bins
  154. shiftCorrection = (int)std::max(0.0, abs((-numReferenceY + 2) - (numReferenceY - 2)) - maxBins) / 2;
  155. }
  156. const int numPoints = 2 * (numReferenceY - shiftCorrection) - 3;
  157. if (numPoints < 1)
  158. throw std::runtime_error("Range is not valid");
  159. MatrixWorkspace_sptr out = create<HistoWorkspace>(*inputWS, numSpectra, Points(numPoints));
  160. const auto referenceVariance = subtractMean(referenceYVector, referenceEVector);
  161. const double referenceNorm = 1.0 / sqrt(referenceVariance.y);
  162. double referenceNormE = 0.5 * pow(referenceNorm, 3) * sqrt(referenceVariance.e);
  163. // Now copy the other spectra
  164. bool isDistribution = inputWS->isDistribution();
  165. auto &outX = out->mutableX(0);
  166. for (int i = 0; i < static_cast<int>(outX.size()); ++i) {
  167. outX[i] = static_cast<double>(i - (numReferenceY - shiftCorrection) + 2);
  168. }
  169. // Initialise the progress reporting object
  170. m_progress = std::make_unique<Progress>(this, 0.0, 1.0, numSpectra);
  171. PARALLEL_FOR_IF(Kernel::threadSafe(*inputWS, *out))
  172. for (int currentSpecIndex = 0; currentSpecIndex < numSpectra; ++currentSpecIndex) // Now loop on all spectra
  173. {
  174. PARALLEL_START_INTERRUPT_REGION
  175. size_t wsIndex = indexes[currentSpecIndex]; // Get the ws index from the table
  176. // Copy spectra info from input Workspace
  177. out->getSpectrum(currentSpecIndex).copyInfoFrom(inputWS->getSpectrum(wsIndex));
  178. out->setSharedX(currentSpecIndex, out->sharedX(0));
  179. // Get temp referenceSpectras
  180. const auto &inputXVector = inputWS->x(wsIndex);
  181. const auto &inputYVector = inputWS->y(wsIndex);
  182. const auto &inputEVector = inputWS->e(wsIndex);
  183. // Copy Y,E data of spec(currentSpecIndex) to temp vector
  184. // Now rebin on the grid of referenceSpectra
  185. std::vector<double> tempY(numReferenceY);
  186. std::vector<double> tempE(numReferenceY);
  187. VectorHelper::rebin(inputXVector.rawData(), inputYVector.rawData(), inputEVector.rawData(), referenceXVector, tempY,
  188. tempE, isDistribution);
  189. const auto tempVar = subtractMean(tempY, tempE);
  190. // Calculate the normalisation constant
  191. const double tempNorm = 1.0 / sqrt(tempVar.y);
  192. const double tempNormE = 0.5 * pow(tempNorm, 3) * sqrt(tempVar.e);
  193. const double normalisation = referenceNorm * tempNorm;
  194. const double normalisationE2 = pow((referenceNorm * tempNormE), 2) + pow((tempNorm * referenceNormE), 2);
  195. // Get referenceSpectr to the ouput spectrum
  196. auto &outY = out->mutableY(currentSpecIndex);
  197. auto &outE = out->mutableE(currentSpecIndex);
  198. for (int dataIndex = -numReferenceY + 2 + shiftCorrection; dataIndex <= numReferenceY - 2 - shiftCorrection;
  199. ++dataIndex) {
  200. const int dataIndexP = abs(dataIndex);
  201. double val = 0, err2 = 0, x, y, xE, yE;
  202. for (int j = numReferenceY - 1 - dataIndexP; j >= 0; --j) {
  203. if (dataIndex >= 0) {
  204. x = referenceYVector[j];
  205. y = tempY[j + dataIndexP];
  206. xE = referenceEVector[j];
  207. yE = tempE[j + dataIndexP];
  208. } else {
  209. x = tempY[j];
  210. y = referenceYVector[j + dataIndexP];
  211. xE = tempE[j];
  212. yE = referenceEVector[j + dataIndexP];
  213. }
  214. val += (x * y);
  215. err2 += x * x * yE + y * y * xE;
  216. }
  217. outY[dataIndex + numReferenceY - shiftCorrection - 2] = (val * normalisation);
  218. outE[dataIndex + numReferenceY - shiftCorrection - 2] =
  219. sqrt(val * val * normalisationE2 + normalisation * normalisation * err2);
  220. }
  221. // Update progress information
  222. m_progress->report();
  223. PARALLEL_END_INTERRUPT_REGION
  224. }
  225. PARALLEL_CHECK_INTERRUPT_REGION
  226. out->getAxis(0)->unit() = UnitFactory::Instance().create("Label");
  227. Unit_sptr unit = out->getAxis(0)->unit();
  228. std::shared_ptr<Units::Label> label = std::dynamic_pointer_cast<Units::Label>(unit);
  229. label->setLabel("Bins of Shift", "\\mathbb{Z}");
  230. setProperty("OutputWorkspace", out);
  231. }
  232. } // namespace Mantid::Algorithms