PageRenderTime 64ms CodeModel.GetById 32ms RepoModel.GetById 0ms app.codeStats 0ms

/src/SSVM with Cutting Plane Algorithm/svm_struct_learn_mex.c

https://github.com/XiananHao/Players-with-Best-Chemistry
C | 448 lines | 359 code | 50 blank | 39 comment | 73 complexity | ad1ad24ddbe872e18c8adfe77a77a22c MD5 | raw file
  1. /***********************************************************************/
  2. /* */
  3. /* svm_struct_main.c */
  4. /* */
  5. /* Command line interface to the alignment learning module of the */
  6. /* Support Vector Machine. */
  7. /* */
  8. /* Author: Thorsten Joachims */
  9. /* Date: 03.07.04 */
  10. /* */
  11. /* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
  12. /* */
  13. /* This software is available for non-commercial use only. It must */
  14. /* not be modified and distributed without prior permission of the */
  15. /* author. The author is not responsible for implications from the */
  16. /* use of this software. */
  17. /* */
  18. /***********************************************************************/
  19. #ifdef __cplusplus
  20. extern "C" {
  21. #endif
  22. #include "svm_light/svm_common.h"
  23. #include "svm_light/svm_learn.h"
  24. #ifdef __cplusplus
  25. }
  26. #endif
  27. # include "svm_struct/svm_struct_learn.h"
  28. # include "svm_struct/svm_struct_common.h"
  29. # include "svm_struct_api.h"
  30. #include <stdio.h>
  31. #include <string.h>
  32. #include <assert.h>
  33. void read_input_parameters (int, char **,
  34. long *, long *,
  35. STRUCT_LEARN_PARM *, LEARN_PARM *, KERNEL_PARM *,
  36. int *);
  37. void arg_split (char *string, int *argc, char ***argv) ;
  38. void init_qp_solver() ;
  39. void free_qp_solver() ;
  40. /** ------------------------------------------------------------------
  41. ** @brief MEX entry point
  42. **/
  43. void
  44. mexFunction (int nout, mxArray ** out, int nin, mxArray const ** in)
  45. {
  46. SAMPLE sample; /* training sample */
  47. LEARN_PARM learn_parm;
  48. KERNEL_PARM kernel_parm;
  49. STRUCT_LEARN_PARM struct_parm;
  50. STRUCTMODEL structmodel;
  51. int alg_type;
  52. enum {IN_ARGS=0, IN_SPARM} ;
  53. enum {OUT_W=0} ;
  54. char arg [1024 + 1] ;
  55. int argc ;
  56. char ** argv ;
  57. mxArray const * sparm_array;
  58. mxArray const * patterns_array ;
  59. mxArray const * labels_array ;
  60. mxArray const * kernelFn_array ;
  61. int numExamples, ei ;
  62. mxArray * model_array;
  63. /* SVM-light is not fully reentrant, so we need to run this patch first */
  64. init_qp_solver() ;
  65. verbosity = 0 ;
  66. kernel_cache_statistic = 0 ;
  67. if (nin != 2) {
  68. mexErrMsgTxt("Two arguments required") ;
  69. }
  70. /* Parse ARGS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  71. if (! uIsString(in[IN_ARGS], -1)) {
  72. mexErrMsgTxt("ARGS must be a string") ;
  73. }
  74. mxGetString(in[IN_ARGS], arg, sizeof(arg) / sizeof(char)) ;
  75. arg_split (arg, &argc, &argv) ;
  76. svm_struct_learn_api_init(argc+1, argv-1) ;
  77. read_input_parameters (argc+1,argv-1,
  78. &verbosity, &struct_verbosity,
  79. &struct_parm, &learn_parm,
  80. &kernel_parm, &alg_type ) ;
  81. if (kernel_parm.kernel_type != LINEAR &&
  82. kernel_parm.kernel_type != CUSTOM) {
  83. mexErrMsgTxt ("Only LINEAR or CUSTOM kerneles are supported") ;
  84. }
  85. /* Parse SPARM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  86. sparm_array = in [IN_SPARM] ;
  87. // jk remove
  88. if (! sparm_array) {
  89. mexErrMsgTxt("SPARM must be a structure") ;
  90. }
  91. struct_parm.mex = sparm_array ;
  92. patterns_array = mxGetField(sparm_array, 0, "patterns") ;
  93. if (! patterns_array ||
  94. ! mxIsCell(patterns_array)) {
  95. mexErrMsgTxt("SPARM.PATTERNS must be a cell array") ;
  96. }
  97. numExamples = mxGetNumberOfElements(patterns_array) ;
  98. labels_array = mxGetField(sparm_array, 0, "labels") ;
  99. if (! labels_array ||
  100. ! mxIsCell(labels_array) ||
  101. ! mxGetNumberOfElements(labels_array) == numExamples) {
  102. mexErrMsgTxt("SPARM.LABELS must be a cell array "
  103. "with the same number of elements of "
  104. "SPARM.PATTERNS") ;
  105. }
  106. sample.n = numExamples ;
  107. sample.examples = (EXAMPLE *) my_malloc (sizeof(EXAMPLE) * numExamples) ;
  108. for (ei = 0 ; ei < numExamples ; ++ ei) {
  109. sample.examples[ei].x.mex = mxGetCell(patterns_array, ei) ;
  110. sample.examples[ei].y.mex = mxGetCell(labels_array, ei) ;
  111. sample.examples[ei].y.isOwner = 0 ;
  112. }
  113. if (struct_verbosity >= 1) {
  114. mexPrintf("There are %d training examples\n", numExamples) ;
  115. }
  116. kernelFn_array = mxGetField(sparm_array, 0, "kernelFn") ;
  117. if (! kernelFn_array && kernel_parm.kernel_type == CUSTOM) {
  118. mexErrMsgTxt("SPARM.KERNELFN must be defined for CUSTOM kernels") ;
  119. }
  120. if (kernelFn_array) {
  121. MexKernelInfo * info ;
  122. if (mxGetClassID(kernelFn_array) != mxFUNCTION_CLASS) {
  123. mexErrMsgTxt("SPARM.KERNELFN must be a valid function handle") ;
  124. }
  125. info = (MexKernelInfo*) kernel_parm.custom ;
  126. info -> structParm = sparm_array ;
  127. info -> kernelFn = kernelFn_array ;
  128. }
  129. /* Learning ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  130. switch (alg_type) {
  131. case 0:
  132. svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG) ;
  133. break ;
  134. case 1:
  135. svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
  136. break ;
  137. case 2:
  138. svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
  139. break ;
  140. case 3:
  141. svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
  142. break ;
  143. case 4:
  144. svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
  145. break ;
  146. case 9:
  147. svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
  148. break ;
  149. default:
  150. mexErrMsgTxt("Unknown algorithm type") ;
  151. }
  152. /* Write output ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  153. /* Warning: The model contains references to the original data 'docs'.
  154. If you want to free the original data, and only keep the model, you
  155. have to make a deep copy of 'model'. */
  156. // jk change
  157. model_array = newMxArrayEncapsulatingSmodel (&structmodel) ;
  158. out[OUT_W] = mxDuplicateArray (model_array) ;
  159. destroyMxArrayEncapsulatingSmodel (model_array) ;
  160. free_struct_sample (sample) ;
  161. free_struct_model (structmodel) ;
  162. svm_struct_learn_api_exit () ;
  163. free_qp_solver () ;
  164. }
  165. /** ------------------------------------------------------------------
  166. ** @brief Parse argument string
  167. **/
  168. void
  169. read_input_parameters (int argc,char *argv[],
  170. long *verbosity,long *struct_verbosity,
  171. STRUCT_LEARN_PARM *struct_parm,
  172. LEARN_PARM *learn_parm, KERNEL_PARM *kernel_parm,
  173. int *alg_type)
  174. {
  175. long i ;
  176. (*alg_type)=DEFAULT_ALG_TYPE;
  177. /* SVM struct options */
  178. (*struct_verbosity)=1;
  179. struct_parm->C=-0.01;
  180. struct_parm->slack_norm=1;
  181. struct_parm->epsilon=DEFAULT_EPS;
  182. struct_parm->custom_argc=0;
  183. struct_parm->loss_function=DEFAULT_LOSS_FCT;
  184. struct_parm->loss_type=DEFAULT_RESCALING;
  185. struct_parm->newconstretrain=100;
  186. struct_parm->ccache_size=5;
  187. struct_parm->batch_size=100;
  188. /* SVM light options */
  189. (*verbosity)=0;
  190. strcpy (learn_parm->predfile, "trans_predictions");
  191. strcpy (learn_parm->alphafile, "");
  192. learn_parm->biased_hyperplane=1;
  193. learn_parm->remove_inconsistent=0;
  194. learn_parm->skip_final_opt_check=0;
  195. learn_parm->svm_maxqpsize=10;
  196. learn_parm->svm_newvarsinqp=0;
  197. learn_parm->svm_iter_to_shrink=-9999;
  198. learn_parm->maxiter=100000;
  199. learn_parm->kernel_cache_size=40;
  200. learn_parm->svm_c=99999999; /* overridden by struct_parm->C */
  201. learn_parm->eps=0.001; /* overridden by struct_parm->epsilon */
  202. learn_parm->transduction_posratio=-1.0;
  203. learn_parm->svm_costratio=1.0;
  204. learn_parm->svm_costratio_unlab=1.0;
  205. learn_parm->svm_unlabbound=1E-5;
  206. learn_parm->epsilon_crit=0.001;
  207. learn_parm->epsilon_a=1E-10; /* changed from 1e-15 */
  208. learn_parm->compute_loo=0;
  209. learn_parm->rho=1.0;
  210. learn_parm->xa_depth=0;
  211. kernel_parm->kernel_type=0;
  212. kernel_parm->poly_degree=3;
  213. kernel_parm->rbf_gamma=1.0;
  214. kernel_parm->coef_lin=1;
  215. kernel_parm->coef_const=1;
  216. strcpy (kernel_parm->custom,"empty");
  217. /* Parse -x options, delegat --x ones */
  218. for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
  219. switch ((argv[i])[1])
  220. {
  221. case 'a': i++; strcpy(learn_parm->alphafile,argv[i]); break;
  222. case 'c': i++; struct_parm->C=atof(argv[i]); break;
  223. case 'p': i++; struct_parm->slack_norm=atol(argv[i]); break;
  224. case 'e': i++; struct_parm->epsilon=atof(argv[i]); break;
  225. case 'k': i++; struct_parm->newconstretrain=atol(argv[i]); break;
  226. case 'h': i++; learn_parm->svm_iter_to_shrink=atol(argv[i]); break;
  227. case '#': i++; learn_parm->maxiter=atol(argv[i]); break;
  228. case 'm': i++; learn_parm->kernel_cache_size=atol(argv[i]); break;
  229. case 'w': i++; (*alg_type)=atol(argv[i]); break;
  230. case 'o': i++; struct_parm->loss_type=atol(argv[i]); break;
  231. case 'n': i++; learn_parm->svm_newvarsinqp=atol(argv[i]); break;
  232. case 'q': i++; learn_parm->svm_maxqpsize=atol(argv[i]); break;
  233. case 'l': i++; struct_parm->loss_function=atol(argv[i]); break;
  234. case 'f': i++; struct_parm->ccache_size=atol(argv[i]); break;
  235. case 'b': i++; struct_parm->batch_size=atof(argv[i]); break;
  236. case 't': i++; kernel_parm->kernel_type=atol(argv[i]); break;
  237. case 'd': i++; kernel_parm->poly_degree=atol(argv[i]); break;
  238. case 'g': i++; kernel_parm->rbf_gamma=atof(argv[i]); break;
  239. case 's': i++; kernel_parm->coef_lin=atof(argv[i]); break;
  240. case 'r': i++; kernel_parm->coef_const=atof(argv[i]); break;
  241. case 'u': i++; strcpy(kernel_parm->custom,argv[i]); break;
  242. case 'v': i++; (*struct_verbosity)=atol(argv[i]); break;
  243. case 'y': i++; (*verbosity)=atol(argv[i]); break;
  244. case '-':
  245. strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);
  246. i++;
  247. strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);
  248. break;
  249. default:
  250. {
  251. char msg [1024+1] ;
  252. #ifndef WIN
  253. snprintf(msg, sizeof(msg)/sizeof(char),
  254. "Unrecognized option '%s'",argv[i]) ;
  255. #else
  256. sprintf(msg, sizeof(msg)/sizeof(char),
  257. "Unrecognized option '%s'",argv[i]) ;
  258. #endif
  259. mexErrMsgTxt(msg) ;
  260. }
  261. }
  262. }
  263. /* whatever is left is an error */
  264. if (i < argc) {
  265. char msg [1024+1] ;
  266. #ifndef WIN
  267. snprintf(msg, sizeof(msg)/sizeof(char),
  268. "Unrecognized argument '%s'", argv[i]) ;
  269. #else
  270. sprintf(msg, sizeof(msg)/sizeof(char),
  271. "Unrecognized argument '%s'", argv[i]) ;
  272. #endif
  273. mexErrMsgTxt(msg) ;
  274. }
  275. /* Check parameter validity */
  276. if(learn_parm->svm_iter_to_shrink == -9999) {
  277. learn_parm->svm_iter_to_shrink=100;
  278. }
  279. if((learn_parm->skip_final_opt_check)
  280. && (kernel_parm->kernel_type == LINEAR)) {
  281. mexWarnMsgTxt("It does not make sense to skip the final optimality check for linear kernels.");
  282. learn_parm->skip_final_opt_check=0;
  283. }
  284. if((learn_parm->skip_final_opt_check)
  285. && (learn_parm->remove_inconsistent)) {
  286. mexErrMsgTxt("It is necessary to do the final optimality check when removing inconsistent examples.");
  287. }
  288. if((learn_parm->svm_maxqpsize<2)) {
  289. char msg [1025] ;
  290. #ifndef WIN
  291. snprintf(msg, sizeof(msg)/sizeof(char),
  292. "Maximum size of QP-subproblems not in valid range: %ld [2..]",learn_parm->svm_maxqpsize) ;
  293. #else
  294. sprintf(msg, sizeof(msg)/sizeof(char),
  295. "Maximum size of QP-subproblems not in valid range: %ld [2..]",learn_parm->svm_maxqpsize) ;
  296. #endif
  297. mexErrMsgTxt(msg) ;
  298. }
  299. if((learn_parm->svm_maxqpsize<learn_parm->svm_newvarsinqp)) {
  300. char msg [1025] ;
  301. #ifndef WIN
  302. snprintf(msg, sizeof(msg)/sizeof(char),
  303. "Maximum size of QP-subproblems [%ld] must be larger than the number of"
  304. " new variables [%ld] entering the working set in each iteration.",
  305. learn_parm->svm_maxqpsize, learn_parm->svm_newvarsinqp) ;
  306. #else
  307. sprintf(msg, sizeof(msg)/sizeof(char),
  308. "Maximum size of QP-subproblems [%ld] must be larger than the number of"
  309. " new variables [%ld] entering the working set in each iteration.",
  310. learn_parm->svm_maxqpsize, learn_parm->svm_newvarsinqp) ;
  311. #endif
  312. mexErrMsgTxt(msg) ;
  313. }
  314. if(learn_parm->svm_iter_to_shrink<1) {
  315. char msg [1025] ;
  316. #ifndef WIN
  317. snprintf(msg, sizeof(msg)/sizeof(char),
  318. "Maximum number of iterations for shrinking not in valid range: %ld [1,..]",
  319. learn_parm->svm_iter_to_shrink);
  320. #else
  321. sprintf(msg, sizeof(msg)/sizeof(char),
  322. "Maximum number of iterations for shrinking not in valid range: %ld [1,..]",
  323. learn_parm->svm_iter_to_shrink);
  324. #endif
  325. mexErrMsgTxt(msg) ;
  326. }
  327. if(struct_parm->C<0) {
  328. mexErrMsgTxt("You have to specify a value for the parameter '-c' (C>0)!");
  329. }
  330. if(((*alg_type) < 0) || (((*alg_type) > 5) && ((*alg_type) != 9))) {
  331. mexErrMsgTxt("Algorithm type must be either '0', '1', '2', '3', '4', or '9'!");
  332. }
  333. if(learn_parm->transduction_posratio>1) {
  334. mexErrMsgTxt("The fraction of unlabeled examples to classify as positives must "
  335. "be less than 1.0 !!!");
  336. }
  337. if(learn_parm->svm_costratio<=0) {
  338. mexErrMsgTxt("The COSTRATIO parameter must be greater than zero!");
  339. }
  340. if(struct_parm->epsilon<=0) {
  341. mexErrMsgTxt("The epsilon parameter must be greater than zero!");
  342. }
  343. if((struct_parm->ccache_size<=0) && ((*alg_type) == 4)) {
  344. mexErrMsgTxt("The cache size must be at least 1!");
  345. }
  346. if(((struct_parm->batch_size<=0) || (struct_parm->batch_size>100))
  347. && ((*alg_type) == 4)) {
  348. mexErrMsgTxt("The batch size must be in the interval ]0,100]!");
  349. }
  350. if((struct_parm->slack_norm<1) || (struct_parm->slack_norm>2)) {
  351. mexErrMsgTxt("The norm of the slacks must be either 1 (L1-norm) or 2 (L2-norm)!");
  352. }
  353. if((struct_parm->loss_type != SLACK_RESCALING)
  354. && (struct_parm->loss_type != MARGIN_RESCALING)) {
  355. mexErrMsgTxt("The loss type must be either 1 (slack rescaling) or 2 (margin rescaling)!");
  356. }
  357. if(learn_parm->rho<0) {
  358. mexErrMsgTxt("The parameter rho for xi/alpha-estimates and leave-one-out pruning must"
  359. " be greater than zero (typically 1.0 or 2.0, see T. Joachims, Estimating the"
  360. " Generalization Performance of an SVM Efficiently, ICML, 2000.)!");
  361. }
  362. if((learn_parm->xa_depth<0) || (learn_parm->xa_depth>100)) {
  363. mexErrMsgTxt("The parameter depth for ext. xi/alpha-estimates must be in [0..100] (zero"
  364. "for switching to the conventional xa/estimates described in T. Joachims,"
  365. "Estimating the Generalization Performance of an SVM Efficiently, ICML, 2000.)") ;
  366. }
  367. parse_struct_parameters (struct_parm) ;
  368. }
  369. void
  370. arg_split (char *string, int *argc, char ***argv)
  371. {
  372. size_t size;
  373. char *d, *p;
  374. for (size = 1, p = string; *p; p++) {
  375. if (isspace((int) *p)) {
  376. size++;
  377. }
  378. }
  379. size++; /* leave space for final NULL pointer. */
  380. *argv = (char **) my_malloc(((size * sizeof(char *)) + (p - string) + 1));
  381. for (*argc = 0, p = string, d = ((char *) *argv) + size*sizeof(char *);
  382. *p != 0; ) {
  383. (*argv)[*argc] = NULL;
  384. while (*p && isspace((int) *p)) p++;
  385. if (*argc == 0 && *p == '#') {
  386. break;
  387. }
  388. if (*p) {
  389. char *s = p;
  390. (*argv)[(*argc)++] = d;
  391. while (*p && !isspace((int) *p)) p++;
  392. memcpy(d, s, p-s);
  393. d += p-s;
  394. *d++ = 0;
  395. while (*p && isspace((int) *p)) p++;
  396. }
  397. }
  398. }