PageRenderTime 465ms CodeModel.GetById 0ms RepoModel.GetById 0ms app.codeStats 0ms

/saliency/src/train.cpp

https://github.com/jgrogers/Door-Sign-Detection
C++ | 275 lines | 252 code | 23 blank | 0 comment | 56 complexity | 5d4c331242b2d8d5ab6065910c2a6090 MD5 | raw file
  1. #include <saliency.h>
  2. #include <boost/program_options.hpp>
  3. #include <boost/filesystem.hpp>
  4. #include <iostream>
  5. #include <Blob.h>
  6. #include <BlobResult.h>
  7. #include <highgui.h>
  8. #include <cv.h>
  9. #include <cvaux.h>
  10. #include <gethog.h>
  11. #include <svm.h>
  12. #include <image.h>
  13. namespace po = boost::program_options;
  14. struct mousedata {
  15. int cntr ;
  16. CvPoint p1;
  17. CvPoint p2;
  18. int num_pts;
  19. bool neg;
  20. } pp ={0,{},{},0,false};
  21. void on_mouse(int event, int x, int y, int flags, void* param) {
  22. switch (event) {
  23. case CV_EVENT_LBUTTONDOWN:
  24. {
  25. pp.p1 = cvPoint(x,y);
  26. pp.p2 = cvPoint(x,y);
  27. pp.num_pts = 1;
  28. pp.neg =false;
  29. break;
  30. }
  31. case CV_EVENT_LBUTTONUP:
  32. {
  33. pp.p2 = cvPoint(x,y);
  34. pp.num_pts = 2;
  35. break;
  36. }
  37. case CV_EVENT_RBUTTONDOWN:
  38. {
  39. pp.p1 = cvPoint(x,y);
  40. pp.p2 = cvPoint(x,y);
  41. pp.num_pts = 1;
  42. pp.neg = true;
  43. break;
  44. }
  45. case CV_EVENT_RBUTTONUP:
  46. {
  47. pp.p2 = cvPoint(x,y);
  48. pp.num_pts = 2;
  49. break;
  50. }
  51. case CV_EVENT_MOUSEMOVE:
  52. {
  53. if (flags & CV_EVENT_FLAG_LBUTTON ||
  54. flags & CV_EVENT_FLAG_RBUTTON) {
  55. pp.p2 = cvPoint(x,y);
  56. }
  57. break;
  58. }
  59. }
  60. }
  61. int main(int argc, char** argv) {
  62. unsigned int uint_opt;
  63. double double_opt;
  64. po::options_description desc("Allowed options");
  65. desc.add_options()
  66. ("help", "produce help message")
  67. ("img", po::value<std::string>(), "Load this file to pull training from")
  68. ("trainfile", po::value<std::string>(), "Train from this training file")
  69. ("load", po::value<std::string>(), "SVM to load, for adding on")
  70. ("save", po::value<std::string>(), "SVM to save")
  71. ("test", po::value<std::string>(), "run test on this directory")
  72. ;
  73. po::variables_map vm;
  74. po::store(po::parse_command_line(argc, argv, desc),vm);
  75. po::notify(vm);
  76. if (vm.count("help")) {
  77. std::cout <<desc<<"\n";
  78. return 1;
  79. }
  80. std::vector< std::vector<float> > pos_descriptors;
  81. std::vector< std::vector<float> > neg_descriptors;
  82. cvNamedWindow("output",1);
  83. cvNamedWindow("Input",1);
  84. cvNamedWindow("NewSign",1);
  85. cvSetMouseCallback("Input", on_mouse);
  86. cv::SVM* mySVM;
  87. if (vm.count("load")) {
  88. mySVM = new cv::SVM;
  89. mySVM->load(vm["load"].as<std::string>().c_str());
  90. }
  91. bool test_mode = false;
  92. if (vm.count("trainfile")) {
  93. char fname [200];
  94. char posneg;
  95. CvRect rect;
  96. FILE* trainfile = fopen (vm["trainfile"].as<std::string>().c_str(), "r");
  97. if (!trainfile) {
  98. printf("Not able to load %s\n",
  99. vm["trainfile"].as<std::string>().c_str());
  100. exit(1);
  101. }
  102. while (!feof(trainfile)) {
  103. int found =
  104. fscanf(trainfile, "%s : %c %u %u %u %u\n",
  105. fname, &posneg, &(rect.x), &(rect.y), &(rect.width), &(rect.height));
  106. if (found != 6) {
  107. printf("Fscanf might not be working\n");
  108. }
  109. IplImage* img_in = cvLoadImage(fname);
  110. IplImage* img_bw = TrainPrepImage(img_in);
  111. CvPoint UL = cvPoint(rect.x,rect.y);
  112. CvSize signsize = cvSize(rect.width,rect.height);
  113. if (!signsize.width ||
  114. !signsize.height) {
  115. printf("too small!\n");
  116. continue;
  117. }
  118. IplImage* newsign = cvCreateImage(signsize,
  119. IPL_DEPTH_8U,
  120. 1);
  121. cvSetImageROI(img_bw, rect);
  122. cvCopy(img_bw, newsign);
  123. cvShowImage("NewSign", newsign);
  124. cvWaitKey(30);
  125. std::vector<float> desc = GetHog(newsign, 4);
  126. if (posneg == 'N' ) neg_descriptors.push_back(desc);
  127. else pos_descriptors.push_back(desc);
  128. cvReleaseImage(&newsign);
  129. cvReleaseImage(&img_in);
  130. cvReleaseImage(&img_bw);
  131. }
  132. if (pos_descriptors.size() && neg_descriptors.size()) {
  133. mySVM = TrainSVM_HOG(pos_descriptors,neg_descriptors);
  134. if (vm.count("save"))
  135. mySVM->save(vm["save"].as<std::string>().c_str(), "mysvm");
  136. else
  137. mySVM->save("testsvm.out", "mysvm");
  138. }
  139. else printf("Need at least one positive and negative examples\n");
  140. }
  141. if (vm.count("img")) {
  142. unsigned int key = -1;
  143. IplImage* img_in = cvLoadImage(vm["img"].as<std::string>().c_str());
  144. IplImage* img_bw = TrainPrepImage(img_in);
  145. cvShowImage("Input", img_in);
  146. do {
  147. IplImage* img = cvCloneImage(img_in);
  148. if (pp.num_pts != 0) {
  149. if (pp.neg) cvRectangle(img, pp.p1, pp.p2, CV_RGB(255,0,0), 1);
  150. else cvRectangle(img, pp.p1, pp.p2, CV_RGB(0,255,0), 1);
  151. }
  152. if (pp.num_pts == 2) {
  153. printf("Got two points!\n");
  154. pp.num_pts = 0;
  155. CvPoint UL = cvPoint(MIN(pp.p1.x, pp.p2.x),
  156. MIN(pp.p1.y, pp.p2.y));
  157. CvSize signsize = cvSize(abs(pp.p2.x-pp.p1.x),
  158. abs(pp.p2.y-pp.p1.y));
  159. if (!signsize.width ||
  160. !signsize.height) {
  161. printf("too small!\n");
  162. continue;
  163. }
  164. IplImage* newsign = cvCreateImage(signsize,
  165. IPL_DEPTH_8U,
  166. 1);
  167. cvSetImageROI(img_bw, cvRect(UL.x, UL.y, signsize.width, signsize.height));
  168. cvCopy(img_bw, newsign);
  169. cvShowImage("NewSign", newsign);
  170. std::vector<float> desc = GetHog(newsign, 4);
  171. if (test_mode) {
  172. float pred = TestSVM_HOG(desc,mySVM);
  173. printf("Pred:%f\n",pred);
  174. }
  175. else {
  176. if (pp.neg ) neg_descriptors.push_back(desc);
  177. else pos_descriptors.push_back(desc);
  178. }
  179. cvReleaseImage(&newsign);
  180. }
  181. cvShowImage("Input", img);
  182. key = cvWaitKey(30);
  183. cvReleaseImage(&img);
  184. if ((char)key == 'd') {
  185. if (pos_descriptors.size() && neg_descriptors.size()) {
  186. mySVM = TrainSVM_HOG(pos_descriptors,neg_descriptors);
  187. mySVM->save("testsvm.out", "mysvm");
  188. test_mode = true;
  189. printf("Moved into test mode\n");
  190. }
  191. else printf("Need at least one positive and negative examples\n");
  192. key = 'a';
  193. }
  194. }while ((char)key != 'q');
  195. cvReleaseImage(&img_in);
  196. cvReleaseImage(&img_bw);
  197. }
  198. if (vm.count("test")){
  199. std::string thepath = vm["test"].as<std::string>();
  200. if (!boost::filesystem::is_directory(thepath)) {
  201. printf("Give me a better path, not %s\n",vm["test"].as<std::string>().c_str());
  202. return 1;
  203. }
  204. unsigned int key = -1;
  205. for (boost::filesystem::directory_iterator itr(thepath);
  206. itr != boost::filesystem::directory_iterator();
  207. ++itr) {
  208. if (!boost::filesystem::is_regular_file(itr->status())){
  209. continue;
  210. }
  211. std::string full_name = thepath + itr->path().filename();
  212. printf ("Loading %s\n",
  213. full_name.c_str());
  214. IplImage* img_in = cvLoadImage(full_name.c_str());
  215. IplImage* img_bw = TrainPrepImage(img_in);
  216. do {
  217. IplImage* img = cvCloneImage(img_in);
  218. if (pp.num_pts != 0) {
  219. if (pp.neg) cvRectangle(img, pp.p1, pp.p2, CV_RGB(255,0,0), 1);
  220. else cvRectangle(img, pp.p1, pp.p2, CV_RGB(0,255,0), 1);
  221. }
  222. if (pp.num_pts == 2) {
  223. pp.num_pts = 0;
  224. CvPoint UL = cvPoint(MIN(pp.p1.x, pp.p2.x),
  225. MIN(pp.p1.y, pp.p2.y));
  226. CvSize signsize = cvSize(abs(pp.p2.x-pp.p1.x),
  227. abs(pp.p2.y-pp.p1.y));
  228. if (!signsize.width ||
  229. !signsize.height) {
  230. printf("too small!\n");
  231. continue;
  232. }
  233. IplImage* newsign = cvCreateImage(signsize,
  234. IPL_DEPTH_8U,
  235. 1);
  236. cvSetImageROI(img_bw, cvRect(UL.x, UL.y, signsize.width, signsize.height));
  237. cvCopy(img_bw, newsign);
  238. cvShowImage("NewSign", newsign);
  239. std::vector<float> desc = GetHog(newsign, 4);
  240. float pred = TestSVM_HOG(desc,mySVM);
  241. printf("Pred:%f\n",pred);
  242. cvReleaseImage(&newsign);
  243. }
  244. cvShowImage("Input", img);
  245. key = cvWaitKey(30);
  246. cvReleaseImage(&img);
  247. }while((char)key != 'd' && (char) key != 'q');
  248. if ((char)key == 'q') break;
  249. }
  250. }
  251. return -1;
  252. }