PageRenderTime 45ms CodeModel.GetById 33ms app.highlight 10ms RepoModel.GetById 1ms app.codeStats 0ms

/matlab/srvf_fa_groupwise_reparam.cc

https://bitbucket.org/tetonedge/libsrvf
C++ | 209 lines | 155 code | 26 blank | 28 comment | 39 complexity | 36f21bb1ea1bb32baa9c2bafe8d23e44 MD5 | raw file
Possible License(s): GPL-3.0
  1/*
  2 * libsrvf
  3 * =======
  4 *
  5 * A shape analysis library using the square root velocity framework.
  6 *
  7 * Copyright (C) 2012   FSU Statistical Shape Analysis and Modeling Group
  8 * 
  9 * This program is free software: you can redistribute it and/or modify
 10 * it under the terms of the GNU General Public License as published by
 11 * the Free Software Foundation, either version 3 of the License, or
 12 * (at your option) any later version.
 13 * 
 14 * This program is distributed in the hope that it will be useful,
 15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 17 * GNU General Public License for more details.
 18 * 
 19 * You should have received a copy of the GNU General Public License
 20 * along with this program.  If not, see <http://www.gnu.org/licenses/>
 21 */
 22#include <srvf/srvf.h>
 23#include <srvf/functions.h>
 24
 25#include <mex.h>
 26#include <vector>
 27
 28
 29void do_usage()
 30{
 31  mexPrintf(
 32    "USAGE: [Gm,TGm,Gs,TGs] = %s(Qm,Tm,Qs,Ts)\n"
 33    "Inputs:\n"
 34    "\tQm, Tm : sample points and parameters of the mean SRVF\n"
 35    "\tQs, Ts : sample points and parameters of the other SRVFs\n"
 36    "Outputs:\n"
 37    "\tGm,TGm = the reparametrization for Qm\n"
 38    "\tGs,TGs = the reparametrizations for the Qs\n",
 39    mexFunctionName()
 40  );
 41}
 42
 43
 44extern "C"
 45{
 46void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
 47{
 48  size_t nfuncs;
 49
 50  std::vector<srvf::Srvf> Qs;
 51  srvf::Srvf Qm;
 52  mxArray *sampsi_data;
 53  mxArray *paramsi_data;
 54  srvf::Pointset sampsi;
 55  std::vector<double> paramsi;
 56
 57  std::vector<srvf::Plf> Gs;
 58  mxArray *Gi_matrix, *Ti_matrix;
 59  double *Gi_data, *Ti_data;
 60  mwSize Gs_dim;
 61
 62
 63  // Check arguments
 64  if (nrhs != 4 || nlhs != 4)
 65  {
 66    mexPrintf("error: incorrect number of arguments.\n");
 67    do_usage();
 68    return;
 69  }
 70  if (!mxIsCell(prhs[2]) || !mxIsCell(prhs[3]))
 71  {
 72    mexPrintf("error: Qs and Ts must be cell arrays.\n");
 73    do_usage();
 74    return;
 75  }
 76  nfuncs = mxGetNumberOfElements(prhs[2]);
 77  if (mxGetNumberOfElements(prhs[3]) != nfuncs)
 78  {
 79    mexPrintf("error: Qs and Ts must have the same number of elements.\n");
 80    do_usage();
 81    return;
 82  }
 83
 84
 85  // Create the SRVFs
 86  for (size_t i=0; i<nfuncs; ++i)
 87  {
 88    sampsi_data = mxGetCell(prhs[2],i);
 89    paramsi_data = mxGetCell(prhs[3],i);
 90
 91    // More input checking
 92    if (mxGetM(sampsi_data) != 1 || mxGetM(paramsi_data) != 1)
 93    {
 94      mexPrintf("error: elements of Qs and Ts must have 1 row.\n");
 95      return;
 96    }
 97    if (mxGetN(paramsi_data) != mxGetN(sampsi_data)+1)
 98    {
 99      mexPrintf("error: Ts(%d) must have length size(Qs(%d),2)+1.\n", i, i);
100      return;
101    }
102
103    sampsi = srvf::Pointset(1, mxGetN(sampsi_data), mxGetPr(sampsi_data));
104    paramsi = std::vector<double>(mxGetPr(paramsi_data), 
105      mxGetPr(paramsi_data)+mxGetN(paramsi_data));
106
107    Qs.push_back(srvf::Srvf(sampsi, paramsi));
108  }
109  if (mxGetM(prhs[0]) != 1 || mxGetM(prhs[1]) != 1)
110  {
111    mexPrintf("error: Qm and Tm must have 1 row\n");
112    return;
113  }
114  if (mxGetN(prhs[0])+1 != mxGetN(prhs[1]))
115  {
116    mexPrintf("error: Tm must have length size(Qm,2)+1.\n");
117    return;
118  }
119  sampsi = srvf::Pointset(1,mxGetN(prhs[0]),mxGetPr(prhs[0]));
120  paramsi = std::vector<double> (mxGetPr(prhs[1]), 
121    mxGetPr(prhs[1])+mxGetN(prhs[1]));
122  Qm = srvf::Srvf(sampsi,paramsi);
123
124
125  // All SRVFs must be unit-norm, constant-speed SRVFs
126  for (size_t i=0; i<nfuncs; ++i)
127  {
128    Qs[i].scale_to_unit_norm();
129    Qs[i] = srvf::constant_speed_param(Qs[i]);
130  }
131  Qm.scale_to_unit_norm();
132  Qm = srvf::constant_speed_param(Qm);
133
134
135  // Compute the groupwise alignment using libsrvf
136  Gs = srvf::functions::groupwise_optimal_reparam(Qm,Qs);
137
138
139  // Allocate output variables
140  Gs_dim = (mwSize)nfuncs;
141  plhs[2] = mxCreateCellArray(1,&Gs_dim);
142  plhs[3] = mxCreateCellArray(1,&Gs_dim);
143  if (!plhs[2] || !plhs[3])
144  {
145    mexPrintf("error: mxCreateCellArray() failed.\n");
146    if (plhs[2]) mxDestroyArray(plhs[2]);
147    if (plhs[3]) mxDestroyArray(plhs[3]);
148    return;
149  }
150
151  plhs[0] = mxCreateDoubleMatrix(1, Gs.back().ncp(), mxREAL);
152  plhs[1] = mxCreateDoubleMatrix(1, Gs.back().ncp(), mxREAL);
153  if (!plhs[0] || !plhs[1])
154  {
155    mexPrintf("error: mxCreateDoubleMatrix() failed.\n");
156    mxDestroyArray(plhs[2]);
157    mxDestroyArray(plhs[3]);
158    if (plhs[0]) mxDestroyArray(plhs[0]);
159    if (plhs[1]) mxDestroyArray(plhs[1]);
160    return;
161  }
162  for (size_t i=0; i<nfuncs; ++i)
163  {
164    Gi_matrix = mxCreateDoubleMatrix(1, Gs[i].ncp(), mxREAL);
165    Ti_matrix = mxCreateDoubleMatrix(1, Gs[i].ncp(), mxREAL);
166    if (!Gi_matrix || !Ti_matrix)
167    {
168      if (Gi_matrix) mxDestroyArray(Gi_matrix);
169      if (Ti_matrix) mxDestroyArray(Ti_matrix);
170      for (size_t j=0; j<i; ++j)
171      {
172        mxDestroyArray(mxGetCell(plhs[2],j));
173        mxDestroyArray(mxGetCell(plhs[3],j));
174      }
175      mxDestroyArray(plhs[0]);
176      mxDestroyArray(plhs[1]);
177      mxDestroyArray(plhs[2]);
178      mxDestroyArray(plhs[3]);
179
180      mexPrintf("error: mxCreateDoubleMatrix() failed.\n");
181      return;
182    }
183    mxSetCell(plhs[2], i, Gi_matrix);
184    mxSetCell(plhs[3], i, Ti_matrix);
185  }
186
187  
188  // Copy into plhs
189  Gi_data = mxGetPr(plhs[0]);
190  Ti_data = mxGetPr(plhs[1]);
191  for (size_t i=0; i<Gs.back().ncp(); ++i)
192  {
193    Gi_data[i] = Gs.back().samps()[i][0];
194    Ti_data[i] = Gs.back().params()[i];
195  }
196  for (size_t i=0; i<nfuncs; ++i)
197  {
198    Gi_data = mxGetPr(mxGetCell(plhs[2],i));
199    Ti_data = mxGetPr(mxGetCell(plhs[3],i));
200
201    for (size_t j=0; j<Gs[i].ncp(); ++j)
202    {
203      Gi_data[j] = Gs[i].samps()[j][0];
204      Ti_data[j] = Gs[i].params()[j];
205    }
206  }
207
208}
209} // extern "C"