/matlab/srvf_fa_groupwise_reparam.cc

https://bitbucket.org/tetonedge/libsrvf · C++ · 209 lines · 155 code · 26 blank · 28 comment · 39 complexity · 36f21bb1ea1bb32baa9c2bafe8d23e44 MD5 · raw file

  1. /*
  2. * libsrvf
  3. * =======
  4. *
  5. * A shape analysis library using the square root velocity framework.
  6. *
  7. * Copyright (C) 2012 FSU Statistical Shape Analysis and Modeling Group
  8. *
  9. * This program is free software: you can redistribute it and/or modify
  10. * it under the terms of the GNU General Public License as published by
  11. * the Free Software Foundation, either version 3 of the License, or
  12. * (at your option) any later version.
  13. *
  14. * This program is distributed in the hope that it will be useful,
  15. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  16. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  17. * GNU General Public License for more details.
  18. *
  19. * You should have received a copy of the GNU General Public License
  20. * along with this program. If not, see <http://www.gnu.org/licenses/>
  21. */
  22. #include <srvf/srvf.h>
  23. #include <srvf/functions.h>
  24. #include <mex.h>
  25. #include <vector>
  26. void do_usage()
  27. {
  28. mexPrintf(
  29. "USAGE: [Gm,TGm,Gs,TGs] = %s(Qm,Tm,Qs,Ts)\n"
  30. "Inputs:\n"
  31. "\tQm, Tm : sample points and parameters of the mean SRVF\n"
  32. "\tQs, Ts : sample points and parameters of the other SRVFs\n"
  33. "Outputs:\n"
  34. "\tGm,TGm = the reparametrization for Qm\n"
  35. "\tGs,TGs = the reparametrizations for the Qs\n",
  36. mexFunctionName()
  37. );
  38. }
  39. extern "C"
  40. {
  41. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
  42. {
  43. size_t nfuncs;
  44. std::vector<srvf::Srvf> Qs;
  45. srvf::Srvf Qm;
  46. mxArray *sampsi_data;
  47. mxArray *paramsi_data;
  48. srvf::Pointset sampsi;
  49. std::vector<double> paramsi;
  50. std::vector<srvf::Plf> Gs;
  51. mxArray *Gi_matrix, *Ti_matrix;
  52. double *Gi_data, *Ti_data;
  53. mwSize Gs_dim;
  54. // Check arguments
  55. if (nrhs != 4 || nlhs != 4)
  56. {
  57. mexPrintf("error: incorrect number of arguments.\n");
  58. do_usage();
  59. return;
  60. }
  61. if (!mxIsCell(prhs[2]) || !mxIsCell(prhs[3]))
  62. {
  63. mexPrintf("error: Qs and Ts must be cell arrays.\n");
  64. do_usage();
  65. return;
  66. }
  67. nfuncs = mxGetNumberOfElements(prhs[2]);
  68. if (mxGetNumberOfElements(prhs[3]) != nfuncs)
  69. {
  70. mexPrintf("error: Qs and Ts must have the same number of elements.\n");
  71. do_usage();
  72. return;
  73. }
  74. // Create the SRVFs
  75. for (size_t i=0; i<nfuncs; ++i)
  76. {
  77. sampsi_data = mxGetCell(prhs[2],i);
  78. paramsi_data = mxGetCell(prhs[3],i);
  79. // More input checking
  80. if (mxGetM(sampsi_data) != 1 || mxGetM(paramsi_data) != 1)
  81. {
  82. mexPrintf("error: elements of Qs and Ts must have 1 row.\n");
  83. return;
  84. }
  85. if (mxGetN(paramsi_data) != mxGetN(sampsi_data)+1)
  86. {
  87. mexPrintf("error: Ts(%d) must have length size(Qs(%d),2)+1.\n", i, i);
  88. return;
  89. }
  90. sampsi = srvf::Pointset(1, mxGetN(sampsi_data), mxGetPr(sampsi_data));
  91. paramsi = std::vector<double>(mxGetPr(paramsi_data),
  92. mxGetPr(paramsi_data)+mxGetN(paramsi_data));
  93. Qs.push_back(srvf::Srvf(sampsi, paramsi));
  94. }
  95. if (mxGetM(prhs[0]) != 1 || mxGetM(prhs[1]) != 1)
  96. {
  97. mexPrintf("error: Qm and Tm must have 1 row\n");
  98. return;
  99. }
  100. if (mxGetN(prhs[0])+1 != mxGetN(prhs[1]))
  101. {
  102. mexPrintf("error: Tm must have length size(Qm,2)+1.\n");
  103. return;
  104. }
  105. sampsi = srvf::Pointset(1,mxGetN(prhs[0]),mxGetPr(prhs[0]));
  106. paramsi = std::vector<double> (mxGetPr(prhs[1]),
  107. mxGetPr(prhs[1])+mxGetN(prhs[1]));
  108. Qm = srvf::Srvf(sampsi,paramsi);
  109. // All SRVFs must be unit-norm, constant-speed SRVFs
  110. for (size_t i=0; i<nfuncs; ++i)
  111. {
  112. Qs[i].scale_to_unit_norm();
  113. Qs[i] = srvf::constant_speed_param(Qs[i]);
  114. }
  115. Qm.scale_to_unit_norm();
  116. Qm = srvf::constant_speed_param(Qm);
  117. // Compute the groupwise alignment using libsrvf
  118. Gs = srvf::functions::groupwise_optimal_reparam(Qm,Qs);
  119. // Allocate output variables
  120. Gs_dim = (mwSize)nfuncs;
  121. plhs[2] = mxCreateCellArray(1,&Gs_dim);
  122. plhs[3] = mxCreateCellArray(1,&Gs_dim);
  123. if (!plhs[2] || !plhs[3])
  124. {
  125. mexPrintf("error: mxCreateCellArray() failed.\n");
  126. if (plhs[2]) mxDestroyArray(plhs[2]);
  127. if (plhs[3]) mxDestroyArray(plhs[3]);
  128. return;
  129. }
  130. plhs[0] = mxCreateDoubleMatrix(1, Gs.back().ncp(), mxREAL);
  131. plhs[1] = mxCreateDoubleMatrix(1, Gs.back().ncp(), mxREAL);
  132. if (!plhs[0] || !plhs[1])
  133. {
  134. mexPrintf("error: mxCreateDoubleMatrix() failed.\n");
  135. mxDestroyArray(plhs[2]);
  136. mxDestroyArray(plhs[3]);
  137. if (plhs[0]) mxDestroyArray(plhs[0]);
  138. if (plhs[1]) mxDestroyArray(plhs[1]);
  139. return;
  140. }
  141. for (size_t i=0; i<nfuncs; ++i)
  142. {
  143. Gi_matrix = mxCreateDoubleMatrix(1, Gs[i].ncp(), mxREAL);
  144. Ti_matrix = mxCreateDoubleMatrix(1, Gs[i].ncp(), mxREAL);
  145. if (!Gi_matrix || !Ti_matrix)
  146. {
  147. if (Gi_matrix) mxDestroyArray(Gi_matrix);
  148. if (Ti_matrix) mxDestroyArray(Ti_matrix);
  149. for (size_t j=0; j<i; ++j)
  150. {
  151. mxDestroyArray(mxGetCell(plhs[2],j));
  152. mxDestroyArray(mxGetCell(plhs[3],j));
  153. }
  154. mxDestroyArray(plhs[0]);
  155. mxDestroyArray(plhs[1]);
  156. mxDestroyArray(plhs[2]);
  157. mxDestroyArray(plhs[3]);
  158. mexPrintf("error: mxCreateDoubleMatrix() failed.\n");
  159. return;
  160. }
  161. mxSetCell(plhs[2], i, Gi_matrix);
  162. mxSetCell(plhs[3], i, Ti_matrix);
  163. }
  164. // Copy into plhs
  165. Gi_data = mxGetPr(plhs[0]);
  166. Ti_data = mxGetPr(plhs[1]);
  167. for (size_t i=0; i<Gs.back().ncp(); ++i)
  168. {
  169. Gi_data[i] = Gs.back().samps()[i][0];
  170. Ti_data[i] = Gs.back().params()[i];
  171. }
  172. for (size_t i=0; i<nfuncs; ++i)
  173. {
  174. Gi_data = mxGetPr(mxGetCell(plhs[2],i));
  175. Ti_data = mxGetPr(mxGetCell(plhs[3],i));
  176. for (size_t j=0; j<Gs[i].ncp(); ++j)
  177. {
  178. Gi_data[j] = Gs[i].samps()[j][0];
  179. Ti_data[j] = Gs[i].params()[j];
  180. }
  181. }
  182. }
  183. } // extern "C"