/tags/dev-20000116/FreeSpeech/VQ/src/Cell.cc

# · C++ · 418 lines · 184 code · 33 blank · 201 comment · 51 complexity · 5147e4bd2c60a49ebb219b5aaf13a24a MD5 · raw file

  1. #include "Cell.h"
  2. #include <string>
  3. #include "ObjectParser.h"
  4. #include "misc.h"
  5. #include <algo.h>
  6. DECLARE_TYPE(Cell)
  7. void Cell::recursiveSplit (const vector<pair<int, float *> > &data, int level)
  8. {
  9. if (level <= 0)
  10. {
  11. cerr << "LEAF: " << data.size() << endl;
  12. return;
  13. }
  14. int dim;
  15. float thresh;
  16. split(data, dim, thresh);
  17. vector<pair<int, float *> > firstData;
  18. vector<pair<int, float *> > secondData;
  19. for (int i=0;i<data.size();i++)
  20. if (data[i].second[dim] < thresh)
  21. {
  22. //cerr << i << "(" << data[i].second[0] << "," << data[i].second[1] << ") goes to first\n";
  23. firstData.insert(firstData.end(), data[i]);
  24. } else {
  25. //cerr << i << "(" << data[i].second[0] << "," << data[i].second[1] << ") goes to second\n";
  26. secondData.insert(secondData.end(), data[i]);
  27. }
  28. splitDimension = dim;
  29. threshold = thresh;
  30. //cout << dimension << endl;
  31. first = new Cell (dimension, numberClasses);
  32. second = new Cell (dimension, numberClasses);
  33. terminal = false;
  34. first->recursiveSplit(firstData, level-1);
  35. second->recursiveSplit(secondData, level-1);
  36. }
  37. void Cell::split(const vector<pair<int, float *> > &data, int &bestDim, float &bestThreshold)
  38. {
  39. bestDim=0;
  40. bestThreshold=0;
  41. float bestMutual = -FLT_MAX;
  42. for (int i=0;i<dimension;i++)
  43. {
  44. float threshold;
  45. float currentMutual;
  46. findThreshold(data, i, threshold, currentMutual);
  47. //cerr << "threshold: " << threshold << " currentMutual: " << currentMutual << endl;
  48. if (currentMutual > bestMutual)
  49. {
  50. bestMutual=currentMutual;
  51. bestDim=i;
  52. bestThreshold=threshold;
  53. }
  54. }
  55. cerr << "bestDim: " << bestDim << " bestThreshold: " << bestThreshold << endl;
  56. //if (some condition on bestMutual) don't perform the split
  57. //splitWithThreshold(data, bestDim, bestThreshold);
  58. }
  59. /*void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &bestThresh, float &bestScore)
  60. {
  61. if (data.size()==0)
  62. {
  63. bestThresh=0;
  64. bestScore = 0;
  65. return;
  66. }
  67. float min_value = FLT_MAX, max_value = -FLT_MAX;
  68. int min_ind = 0, max_ind = 0;
  69. int i,k;
  70. for (i=0;i<data.size();i++)
  71. {
  72. if (data[i].second[dim] > max_value)
  73. {
  74. max_value = data[i].second[dim];
  75. max_ind=i;
  76. }
  77. if (data[i].second[dim] < min_value)
  78. {
  79. min_value = data[i].second[dim];
  80. min_ind=i;
  81. }
  82. }
  83. bestThresh = 0;
  84. bestScore = -FLT_MAX;
  85. float thresh;
  86. float score;
  87. for (thresh = min_value; thresh < max_value; thresh += (max_value-min_value)/15.0)
  88. {
  89. int sumAi = 0, sumBi = 0;
  90. vector<int> Ai (numberClasses, 0);
  91. vector<int> Bi (numberClasses, 0);
  92. for (k=0;k<data.size();k++)
  93. {
  94. if (data[k].second[dim] >= thresh)
  95. {
  96. sumAi++;
  97. Ai[data[k].first]++;
  98. } else {
  99. sumBi++;
  100. Bi[data[k].first]++;
  101. }
  102. }
  103. double weight = double(sumAi)/data.size();
  104. //cerr << "weight: " << weight << " sumAi = " << sumAi << endl;
  105. score = 0.0;
  106. for (i = 0;i<numberClasses;i++)
  107. {
  108. //cerr << "A[" << i << "] = " << Ai[i] << "\t" << "Ai[i] / sumAi = " << (double( Ai[i] ) / sumAi ) << "\t";
  109. //cerr << "B[" << i << "] = " << Bi[i] << "\t" << "Bi[i] / sumBi = " << (double( Bi[i] ) / sumBi ) << "\t";
  110. if (sumAi)
  111. score -= weight * entropy_funct (double( Ai[i] ) / sumAi );
  112. if (sumBi)
  113. score -= (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
  114. //cerr << "score = " << score << endl;
  115. }
  116. cerr << "got " << score << " for threshold " << thresh << endl;
  117. if (score > bestScore)
  118. {
  119. bestThresh = thresh;
  120. bestScore = score;
  121. }
  122. }
  123. }*/
  124. /*
  125. void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &bestThresh, float &bestScore)
  126. {
  127. float sum = 0, s2 = 0;
  128. int i,k;
  129. for (i=0;i<data.size();i++)
  130. {
  131. sum += data[i].second[dim];
  132. s2+= sqr(data[i].second[dim]);
  133. }
  134. if (data.size()<=1)
  135. {
  136. bestThresh=0;
  137. bestScore=0;
  138. return;
  139. }
  140. sum /= data.size();
  141. s2=sqrt(s2/data.size() - sqr(sum) );
  142. //cerr << "s2 = " << s2 << " N = " << data.size() << endl;
  143. float min_value = sum - 1.5*s2;
  144. float max_value = sum + 1.5*s2;
  145. //thresh=sum/data.size();
  146. //if (data.size()==0) thresh=0;
  147. bestThresh = 0;
  148. bestScore = -FLT_MAX;
  149. float thresh;
  150. float score;
  151. for (thresh = min_value; thresh < max_value; thresh += (max_value-min_value)/15.0)
  152. {
  153. int sumAi = 0, sumBi = 0;
  154. vector<int> Ai (numberClasses, 0);
  155. vector<int> Bi (numberClasses, 0);
  156. for (k=0;k<data.size();k++)
  157. {
  158. if (data[k].second[dim] >= thresh)
  159. {
  160. sumAi++;
  161. Ai[data[k].first]++;
  162. } else {
  163. sumBi++;
  164. Bi[data[k].first]++;
  165. }
  166. }
  167. double weight = double(sumAi)/data.size();
  168. score = - numberClasses * .01*abs(thresh-sum)/s2;
  169. //score = 0;
  170. for (i = 0;i<numberClasses;i++)
  171. {
  172. score += - weight * entropy_funct (double( Ai[i] ) / sumAi )
  173. - (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
  174. }
  175. if (score > bestScore)
  176. {
  177. bestThresh = thresh;
  178. bestScore = score;
  179. }
  180. }
  181. }
  182. */
  /**
   * qsort()-style comparator for two floats.
   *
   * qsort() requires a negative, zero or positive result for "less", "equal"
   * and "greater". Returning just (a < b) yields 0 or 1 and therefore cannot
   * express "less", which produces an unspecified ordering.
   *
   * @param a  pointer to the left-hand float
   * @param b  pointer to the right-hand float
   * @return   -1 if *a < *b, 1 if *a > *b, 0 otherwise
   */
  static int float_less(const void *a, const void *b)
  {
    const float lhs = *static_cast<const float *>(a);
    const float rhs = *static_cast<const float *>(b);
    // Branch-free three-way compare; avoids the overflow/truncation issues
    // of returning a difference.
    return (lhs > rhs) - (lhs < rhs);
  }
  187. //find threshold using split at median and mutual information
  188. void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &thresh, float &score)
  189. {
  190. float sum = 0;
  191. int i,k;
  192. /*for (i=0;i<data.size();i++)
  193. sum += data[i].second[dim];
  194. thresh=sum/data.size();*/
  195. if (data.size()==0) thresh=0;
  196. else {
  197. float sorted[data.size()];
  198. for (i=0;i<data.size();i++)
  199. sorted[i] = data[i].second[dim];
  200. //qsort(sorted,data.size(),sizeof(float), float_less);
  201. sort (sorted,sorted+data.size());
  202. thresh=sorted[data.size()/2];
  203. }
  204. int sumAi = 0, sumBi = 0;
  205. vector<int> Ai (numberClasses, 0);
  206. vector<int> Bi (numberClasses, 0);
  207. for (k=0;k<data.size();k++)
  208. {
  209. if (data[k].second[dim] >= thresh)
  210. {
  211. sumAi++;
  212. Ai[data[k].first]++;
  213. } else {
  214. sumBi++;
  215. Bi[data[k].first]++;
  216. }
  217. }
  218. double weight = double(sumAi)/data.size();
  219. score = 0.0;
  220. for (i = 0;i<numberClasses;i++)
  221. {
  222. score += - weight * entropy_funct (double( Ai[i] ) / sumAi )
  223. - (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
  224. }
  225. }
  226. //find threshold using split at average and mutual information
  227. /*void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &thresh, float &score)
  228. {
  229. float sum = 0;
  230. int i,k;
  231. for (i=0;i<data.size();i++)
  232. sum += data[i].second[dim];
  233. thresh=sum/data.size();
  234. if (data.size()==0) thresh=0;
  235. int sumAi = 0, sumBi = 0;
  236. vector<int> Ai (numberClasses, 0);
  237. vector<int> Bi (numberClasses, 0);
  238. for (k=0;k<data.size();k++)
  239. {
  240. if (data[k].second[dim] >= thresh)
  241. {
  242. sumAi++;
  243. Ai[data[k].first]++;
  244. } else {
  245. sumBi++;
  246. Bi[data[k].first]++;
  247. }
  248. }
  249. double weight = double(sumAi)/data.size();
  250. score = 0.0;
  251. for (i = 0;i<numberClasses;i++)
  252. {
  253. score += - weight * entropy_funct (double( Ai[i] ) / sumAi )
  254. - (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
  255. }
  256. }
  257. */
  258. /*void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &thresh, float &score)
  259. {
  260. float sum = 0;
  261. int i,k;
  262. for (i=0;i<data.size();i++)
  263. sum += data[i].second[dim];
  264. thresh=sum/data.size();
  265. if (data.size()==0) thresh=0;
  266. vector<int> scores (numberClasses, 0);
  267. for (i=0;i<data.size();i++)
  268. if (data[i].second[dim] >= thresh)
  269. {
  270. scores[data[i].first]++;
  271. }
  272. else
  273. {
  274. scores[data[i].first]--;
  275. }
  276. score = 0.0;
  277. for (i=0;i<numberClasses;i++)
  278. score += abs(scores[i]);
  279. }*/
  280. int Cell::setNumbering(int start)
  281. {
  282. if (terminal)
  283. {
  284. cellID=start;
  285. //cerr << start << endl;
  286. return start+1;
  287. } else {
  288. return second->setNumbering(first->setNumbering(start));
  289. }
  290. }
  291. int Cell::belongs(float *vect) const
  292. {
  293. if (terminal) return cellID;
  294. if (vect[splitDimension] < threshold)
  295. return first->belongs(vect);
  296. else
  297. return second->belongs(vect);
  298. }
  299. void Cell::calcTemplate (const vector<float *> &features, vector<int> &templ) const
  300. {
  301. for (vector<float *>::const_iterator feature = features.begin();
  302. feature < features.end(); feature++)
  303. {
  304. //cerr << "(" << (*feature)[0] << "," << (*feature)[1] << "): " << belongs(*feature) << endl;
  305. templ[belongs(*feature)]++;
  306. }
  307. }
  308. void Cell::printOn(ostream &out) const
  309. {
  310. out << "<Cell " << endl;
  311. out << "<dimension " << dimension << ">" << endl;
  312. out << "<numberClasses " << numberClasses << ">" << endl;
  313. out << "<terminal " << terminal << ">" << endl;
  314. if (terminal)
  315. {
  316. out << "<cellID " << cellID << ">" << endl;
  317. } else {
  318. out << "<threshold " << threshold << ">" << endl;
  319. out << "<splitDimension " << splitDimension << ">" << endl;
  320. out << "<first " << *first << ">" << endl;;
  321. out << "<second " << *second << ">" << endl;;
  322. }
  323. out << ">\n";
  324. }
  325. ostream &operator << (ostream &out, const Cell &cell)
  326. {
  327. cell.printOn(out);
  328. return out;
  329. }
  330. void Cell::readFrom (istream &in)
  331. {
  332. string tag;
  333. while (1)
  334. {
  335. char ch;
  336. in >> ch;
  337. if (ch == '>') break;
  338. in >> tag;
  339. if (tag == "dimension")
  340. in >> dimension;
  341. else if (tag == "numberClasses")
  342. in >> numberClasses;
  343. else if (tag == "terminal")
  344. in >> terminal;
  345. else if (tag == "cellID")
  346. in >> cellID;
  347. else if (tag == "threshold")
  348. in >> threshold;
  349. else if (tag == "splitDimension")
  350. in >> splitDimension;
  351. else if (tag == "first")
  352. {
  353. Cell *tmp = new Cell;
  354. in >> *tmp;
  355. first = tmp;
  356. }
  357. else if (tag == "second")
  358. {
  359. Cell *tmp = new Cell;
  360. in >> *tmp;
  361. second = tmp;
  362. } else
  363. throw ParsingException ("unknown argument: " + tag);
  364. if (!in) throw ParsingException ("Parse error trying to build " + tag);
  365. in >> tag;
  366. if (tag != ">") throw ParsingException ("Parse error: '>' expected ");
  367. }
  368. }
  369. istream &operator >> (istream &in, Cell &cell)
  370. {
  371. if (!isValidType(in, "Cell"))
  372. return in;
  373. cell.readFrom(in);
  374. return in;
  375. }