/src/backends/caffe/caffeinputconns.h


/**
 * DeepDetect
 * Copyright (c) 2014-2016 Emmanuel Benazera
 * Author: Emmanuel Benazera <beniz@droidnik.fr>
 *
 * This file is part of deepdetect.
 *
 * deepdetect is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * deepdetect is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
 */
#ifndef CAFFEINPUTCONNS_H
#define CAFFEINPUTCONNS_H

#include "imginputfileconn.h"
#include "csvinputfileconn.h"
#include "csvtsinputfileconn.h"
#include "txtinputfileconn.h"
#include "svminputfileconn.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include "caffe/llogging.h"
#include "caffe/caffe.hpp"
#include "caffe/util/db.hpp"
#pragma GCC diagnostic pop
#include "utils/fileops.hpp"

namespace dd
{
  /**
   * \brief high-level data structure shared among Caffe-compatible connectors
   *        of DeepDetect
   */
  class CaffeInputInterface
  {
  public:
    CaffeInputInterface()
    {
    }
    CaffeInputInterface(const CaffeInputInterface &cii)
        : _db(cii._db), _dv(cii._dv), _dv_test(cii._dv_test),
          _flat1dconv(cii._flat1dconv), _has_mean_file(cii._has_mean_file),
          _mean_values(cii._mean_values), _sparse(cii._sparse),
          _embed(cii._embed), _sequence_txt(cii._sequence_txt),
          _max_embed_id(cii._max_embed_id), _segmentation(cii._segmentation),
          _bbox(cii._bbox), _multi_label(cii._multi_label), _ctc(cii._ctc),
          _autoencoder(cii._autoencoder), _alphabet_size(cii._alphabet_size),
          _root_folder(cii._root_folder), _dbfullname(cii._dbfullname),
          _test_dbfullname(cii._test_dbfullname), _timesteps(cii._timesteps),
          _datadim(cii._datadim), _ntargets(cii._ntargets)
    {
    }
    ~CaffeInputInterface()
    {
    }
    /**
     * \brief when using a db, this provides a batch iterator over the db
     *        data, used when measuring the output of the net
     * @param num the size of the data 'batch' to get from the db
     * @param has_mean_file flag that tells whether the mean of images in the
     *        training set is removed from each image.
     * @return a vector of Caffe Datum
     * @see ImgCaffeInputFileConn
     */
    std::vector<caffe::Datum> get_dv_test(const int &num,
                                          const bool &has_mean_file)
    {
      (void)has_mean_file;
      return std::vector<caffe::Datum>(num);
    }

    std::vector<caffe::SparseDatum> get_dv_test_sparse(const int &num)
    {
      return std::vector<caffe::SparseDatum>(num);
    }

    void reset_dv_test()
    {
    }

    // write class weights to binary proto
    void write_class_weights(const std::string &model_repo,
                             const APIData &ad_mllib);

    bool _db = false; /**< whether to use a db. */
    std::vector<caffe::Datum>
        _dv; /**< main input datum vector, used for training or prediction */
    std::vector<caffe::Datum> _dv_test; /**< test input datum vector, when
                                           applicable in training mode */
    std::vector<caffe::SparseDatum> _dv_sparse;
    std::vector<caffe::SparseDatum> _dv_test_sparse;
    bool _flat1dconv = false;    /**< whether a 1D convolution model. */
    bool _has_mean_file = false; /**< image model mean.binaryproto. */
    std::vector<float>
        _mean_values;     /**< mean image values across a dataset. */
    bool _sparse = false; /**< whether to use sparse representation. */
    bool _embed
        = false; /**< whether model is using an input embedding layer. */
    int _sequence_txt = -1;     /**< sequence of txt input connector. */
    int _max_embed_id = -1;     /**< in embeddings, the max index. */
    bool _segmentation = false; /**< whether it is a segmentation service. */
    bool _bbox = false; /**< whether it is an object detection service. */
    bool _multi_label = false; /**< multi label setup */
    bool _ctc = false;         /**< whether it is a CTC / OCR service. */
    bool _autoencoder = false; /**< whether an autoencoder. */
    int _alphabet_size = 0;    /**< for sequence to sequence models. */
    std::string _root_folder;  /**< root folder for image list layer. */
    std::unordered_map<std::string, std::pair<int, int>>
        _imgs_size; /**< image sizes, used in detection. */
    std::string _dbfullname = "train.lmdb";
    std::string _test_dbfullname = "test.lmdb";
    int _timesteps = -1; // default length for csv timeseries
    int _datadim = -1;   // default size of vector data for timeseries
    int _ntargets = -1;  // number of outputs for timeseries
  };
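
  /* Usage sketch (comment only, not part of the build): concrete connectors
     below override get_dv_test()/reset_dv_test(), and the caller typically
     drains the test set in fixed-size batches when measuring the net's
     output. Assuming a fully transform()'ed connector `conn` of a concrete
     type such as ImgCaffeInputFileConn:

       conn.reset_dv_test();
       std::vector<caffe::Datum> batch;
       while (!(batch = conn.get_dv_test(32, conn._has_mean_file)).empty())
         {
           // feed `batch` to the net's memory data layer and accumulate
           // measures; the batch size of 32 is an arbitrary example
         }
  */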
  /**
   * \brief Caffe image connector, supports both files and the building of a
   *        database for training
   */
  class ImgCaffeInputFileConn : public ImgInputFileConn,
                                public CaffeInputInterface
  {
  public:
    ImgCaffeInputFileConn() : ImgInputFileConn()
    {
      reset_dv_test();
    }
    ImgCaffeInputFileConn(const ImgCaffeInputFileConn &i)
        : ImgInputFileConn(i), CaffeInputInterface(i)
    { /* _db = true;*/
    }
    ~ImgCaffeInputFileConn()
    {
    }

    // size of each element in Caffe jargon
    int channels() const
    {
      if (_bw)
        return 1;
      else
        return 3; // RGB
    }
    int height() const
    {
      return _height;
    }
    int width() const
    {
      return _width;
    }
    int batch_size() const
    {
      if (_db_batchsize > 0)
        return _db_batchsize;
      else if (!_dv.empty())
        return _dv.size();
      else
        return ImgInputFileConn::batch_size();
    }
    int test_batch_size() const
    {
      if (_db_testbatchsize > 0)
        return _db_testbatchsize;
      else if (!_dv_test.empty())
        return _dv_test.size();
      else
        return ImgInputFileConn::test_batch_size();
    }

    void init(const APIData &ad)
    {
      ImgInputFileConn::init(ad);
      if (ad.has("db"))
        _db = ad.get("db").get<bool>();
      if (ad.has("multi_label"))
        _multi_label = ad.get("multi_label").get<bool>();
      if (ad.has("root_folder"))
        _root_folder = ad.get("root_folder").get<std::string>();
      if (ad.has("segmentation"))
        _segmentation = ad.get("segmentation").get<bool>();
      if (ad.has("bbox"))
        _bbox = ad.get("bbox").get<bool>();
      if (ad.has("ctc"))
        _ctc = ad.get("ctc").get<bool>();
    }
    void transform(const APIData &ad)
    {
      // in prediction mode, convert the images to Datum, a Caffe data
      // structure
      if (!_train)
        {
          // if no img height x width, we assume 224x224 (works if user is
          // lucky, i.e. the best we can do)
          if (_width == -1)
            _width = 224;
          if (_height == -1)
            _height = 224;
          if (ad.has("has_mean_file"))
            _has_mean_file = ad.get("has_mean_file").get<bool>();
          APIData ad_input = ad.getobj("parameters").getobj("input");
          if (ad_input.has("segmentation"))
            _segmentation = ad_input.get("segmentation").get<bool>();
          if (ad_input.has("bbox"))
            _bbox = ad_input.get("bbox").get<bool>();
          if (ad_input.has("multi_label"))
            _multi_label = ad_input.get("multi_label").get<bool>();
          if (ad.has("root_folder"))
            _root_folder = ad.get("root_folder").get<std::string>();
          try
            {
              ImgInputFileConn::transform(ad);
            }
          catch (InputConnectorBadParamException &e)
            {
              throw;
            }
          float *mean = nullptr;
          if (_data_mean.count() == 0 && _has_mean_file)
            {
              std::string meanfullname = _model_repo + "/" + _meanfname;
              caffe::BlobProto blob_proto;
              caffe::ReadProtoFromBinaryFile(meanfullname.c_str(),
                                             &blob_proto);
              _data_mean.FromProto(blob_proto);
              mean = _data_mean.mutable_cpu_data();
            }
          if (!_db_fname.empty())
            {
              _test_dbfullname = _db_fname;
              _db = true;
              return; // done
            }
          else
            _db = false;
          for (int i = 0; i < (int)this->_images.size(); i++)
            {
              caffe::Datum datum;
              caffe::CVMatToDatum(this->_images.at(i), &datum);
              if (!_test_labels.empty())
                datum.set_label(_test_labels.at(i));
              if (_data_mean.count() != 0)
                {
                  int height = datum.height();
                  int width = datum.width();
                  for (int c = 0; c < datum.channels(); ++c)
                    for (int h = 0; h < height; ++h)
                      for (int w = 0; w < width; ++w)
                        {
                          int data_index = (c * height + h) * width + w;
                          float datum_element = static_cast<float>(
                              static_cast<uint8_t>(datum.data()[data_index]));
                          datum.add_float_data(datum_element
                                               - mean[data_index]);
                        }
                  datum.clear_data();
                }
              else if (_has_mean_scalar)
                {
                  int height = datum.height();
                  int width = datum.width();
                  for (int c = 0; c < datum.channels(); ++c)
                    for (int h = 0; h < height; ++h)
                      for (int w = 0; w < width; ++w)
                        {
                          int data_index = (c * height + h) * width + w;
                          float datum_element = static_cast<float>(
                              static_cast<uint8_t>(datum.data()[data_index]));
                          datum.add_float_data(datum_element - _mean[c]);
                        }
                  datum.clear_data();
                }
              _dv_test.push_back(datum);
              _imgs_size.insert(std::pair<std::string, std::pair<int, int>>(
                  this->_ids.at(i), this->_images_size.at(i)));
            }
          if (!ad.has("chain"))
            {
              this->_images.clear();
              this->_images_size.clear();
            }
        }
      else
        {
          _shuffle = true;
          int db_height = 0;
          int db_width = 0;
          APIData ad_mllib;
          if (ad.has("parameters")) // hotplug of parameters, overriding the
                                    // defaults
            {
              APIData ad_param = ad.getobj("parameters");
              if (ad_param.has("input"))
                {
                  APIData ad_input = ad_param.getobj("input");
                  fillup_parameters(ad_param.getobj("input"));
                  if (ad_input.has("db"))
                    _db = ad_input.get("db").get<bool>();
                  if (ad_input.has("segmentation"))
                    _segmentation = ad_input.get("segmentation").get<bool>();
                  if (ad_input.has("bbox"))
                    _bbox = ad_input.get("bbox").get<bool>();
                  if (ad_input.has("multi_label"))
                    _multi_label = ad_input.get("multi_label").get<bool>();
                  if (ad_input.has("root_folder"))
                    _root_folder
                        = ad_input.get("root_folder").get<std::string>();
                  if (ad_input.has("align"))
                    _align = ad_input.get("align").get<bool>();
                  if (ad_input.has("db_height"))
                    db_height = ad_input.get("db_height").get<int>();
                  if (ad_input.has("db_width"))
                    db_width = ad_input.get("db_width").get<int>();
                  if (ad.has("autoencoder"))
                    _autoencoder = ad.get("autoencoder").get<bool>();
                }
              ad_mllib = ad_param.getobj("mllib");
            }
          if (_segmentation)
            {
              try
                {
                  get_data(ad);
                }
              catch (InputConnectorBadParamException &ex)
                {
                  throw ex;
                }
              if (!fileops::file_exists(_uris.at(0)))
                throw InputConnectorBadParamException(
                    "segmentation input train file " + _uris.at(0)
                    + " not found");
              if (_uris.size() > 1)
                {
                  if (!fileops::file_exists(_uris.at(1)))
                    throw InputConnectorBadParamException(
                        "segmentation input test file " + _uris.at(1)
                        + " not found");
                }
              // class weights if any
              write_class_weights(_model_repo, ad_mllib);
              // TODO: if test split (+ optional shuffle)
              APIData sourcead;
              sourcead.add("source_train", _uris.at(0));
              if (_uris.size() > 1)
                sourcead.add("source_test", _uris.at(1));
              const_cast<APIData &>(ad).add("source", sourcead);
            }
          else if (_bbox)
            {
              try
                {
                  get_data(ad);
                }
              catch (InputConnectorBadParamException &ex)
                {
                  throw ex;
                }
              if (!fileops::file_exists(_uris.at(0)))
                throw InputConnectorBadParamException(
                    "object detection input train file " + _uris.at(0)
                    + " not found");
              if (_uris.size() > 1)
                {
                  if (!fileops::file_exists(_uris.at(1)))
                    throw InputConnectorBadParamException(
                        "object detection input test file " + _uris.at(1)
                        + " not found");
                }
              // - create lmdbs
              _dbfullname = _model_repo + "/" + _dbfullname;
              _test_dbfullname = _model_repo + "/" + _test_dbfullname;
              objects_to_db(_uris, db_height, db_width, _dbfullname,
                            _test_dbfullname);
              // data object with db files location
              APIData dbad;
              dbad.add("train_db", _dbfullname);
              if (_test_split > 0.0 || _uris.size() > 1)
                dbad.add("test_db", _test_dbfullname);
              const_cast<APIData &>(ad).add("db", dbad);
            }
          else if (_ctc)
            {
              _dbfullname = _model_repo + "/train";
              _test_dbfullname = _model_repo + "/test.h5";
              try
                {
                  get_data(ad);
                }
              catch (InputConnectorBadParamException
                     &ex) // in case the db is in the net config
                {
                  throw ex;
                }
                // read images list and create dbs
#ifdef USE_HDF5
              images_to_hdf5(_uris, _dbfullname, _test_dbfullname);
#endif // USE_HDF5
              // enrich data object with db files location
              APIData dbad;
              dbad.add("train_db", _model_repo + "/training.txt");
              if (_uris.size() > 1 || _test_split > 0.0)
                dbad.add("test_db", _model_repo + "/testing.txt");
              const_cast<APIData &>(ad).add("db", dbad);
            }
          else // more complicated: since images can be heavy, a db is built
               // so that iterating over it is less costly than iterating
               // over the filesystem, unless an image data layer is used
               // (e.g. multi-class image training)
            {
              _dbfullname = _model_repo + "/" + _dbfullname;
              _test_dbfullname = _model_repo + "/" + _test_dbfullname;
              try
                {
                  get_data(ad);
                }
              catch (InputConnectorBadParamException
                     &ex) // in case the db is in the net config
                {
                  // The API defines no data as a user error (bad param).
                  // However, Caffe does allow specifying the input database
                  // in the net's definition, which makes it difficult to
                  // enforce the API here. So for now, this check is kept
                  // disabled.
                  /*if (!fileops::file_exists(_model_repo + "/" + _dbname))
                    throw ex;*/
                  return;
                }
              if (!this->_db)
                {
                  // create test db for image data layer (no db of images)
                  create_test_db_for_imagedatalayer(
                      _uris.at(1), _model_repo + "/" + _test_dbname);
                  return;
                }
              // create db
              // check whether the indicated uri is a folder
              bool dir_images = true;
              fileops::file_exists(_uris.at(0), dir_images);
              if (!this->_unchanged_data)
                images_to_db(_uris, _model_repo + "/" + _dbname,
                             _model_repo + "/" + _test_dbname, dir_images);
              else
                images_to_db(_uris, _model_repo + "/" + _dbname,
                             _model_repo + "/" + _test_dbname, dir_images,
                             "lmdb", false, "");
              // compute the mean of images; not necessarily used, depends on
              // the net, see has_mean_file
              if (!this->_unchanged_data)
                compute_images_mean(_model_repo + "/" + _dbname,
                                    _model_repo + "/" + _meanfname);
              // class weights if any
              write_class_weights(_model_repo, ad_mllib);
              // enrich data object with db files location
              APIData dbad;
              dbad.add("train_db", _dbfullname);
              if (_test_split > 0.0)
                dbad.add("test_db", _test_dbfullname);
              dbad.add("meanfile", _model_repo + "/" + _meanfname);
              const_cast<APIData &>(ad).add("db", dbad);
            }
        }
    }
    std::vector<caffe::Datum> get_dv_test(const int &num,
                                          const bool &has_mean_file)
    {
      if (_segmentation && _train)
        {
          return get_dv_test_segmentation(num, has_mean_file);
        }
      else if (!_train && _db_fname.empty())
        {
          int i = 0;
          std::vector<caffe::Datum> dv;
          while (_dt_vit != _dv_test.end() && i < num)
            {
              dv.push_back((*_dt_vit));
              ++i;
              ++_dt_vit;
            }
          return dv;
        }
      else
        return get_dv_test_db(num, has_mean_file);
    }

    std::vector<caffe::Datum> get_dv_test_db(const int &num,
                                             const bool &has_mean_file);

    std::vector<caffe::Datum>
    get_dv_test_segmentation(const int &num, const bool &has_mean_file);

    void reset_dv_test();

  private:
    void create_test_db_for_imagedatalayer(
        const std::string &test_lst, const std::string &testdbname,
        const std::string &backend = "lmdb", // lmdb, leveldb
        const bool &encoded = true, // save the encoded image in datum
        const std::string &encode_type = ""); // 'png', 'jpg', ...

    int images_to_db(const std::vector<std::string> &rpaths,
                     const std::string &traindbname,
                     const std::string &testdbname,
                     const bool &folders = true,
                     const std::string &backend = "lmdb", // lmdb, leveldb
                     const bool &encoded
                     = true, // save the encoded image in datum
                     const std::string &encode_type = ""); // 'png', 'jpg', ...

    void
    write_image_to_db(const std::string &dbfullname,
                      const std::vector<std::pair<std::string, int>> &lfiles,
                      const std::string &backend, const bool &encoded,
                      const std::string &encode_type);

    void write_image_to_db_multilabel(
        const std::string &dbfullname,
        const std::vector<std::pair<std::string, std::vector<float>>> &lfiles,
        const std::string &backend, const bool &encoded,
        const std::string &encode_type);

#ifdef USE_HDF5
    void images_to_hdf5(const std::vector<std::string> &img_lists,
                        const std::string &traindbname,
                        const std::string &testdbname);

    void write_images_to_hdf5(const std::string &inputfilename,
                              const std::string &dbfullbame,
                              const std::string &dblistfilename,
                              std::unordered_map<uint32_t, int> &alphabet,
                              int &max_ocr_length, const bool &train_db);
#endif // USE_HDF5

    int objects_to_db(const std::vector<std::string> &rfolders,
                      const int &db_height, const int &db_width,
                      const std::string &traindbname,
                      const std::string &testdbname,
                      const bool &encoded = true,
                      const std::string &encode_type = "",
                      const std::string &backend = "lmdb");

    void write_objects_to_db(
        const std::string &dbfullname, const int &db_height,
        const int &db_width,
        const std::vector<std::pair<std::string, std::string>> &lines,
        const bool &encoded, const std::string &encode_type,
        const std::string &backend, const bool &train);

    int compute_images_mean(const std::string &dbname,
                            const std::string &meanfile,
                            const std::string &backend = "lmdb");

    std::string guess_encoding(const std::string &file);

  public:
    int _db_batchsize = -1;
    int _db_testbatchsize = -1;
    std::unique_ptr<caffe::db::DB> _test_db;
    std::unique_ptr<caffe::db::Cursor> _test_db_cursor;
    std::string _dbname = "train";
    std::string _test_dbname = "test";
    std::string _meanfname = "mean.binaryproto";
    std::string _correspname = "corresp.txt";
    caffe::Blob<float> _data_mean; // mean binary image if available.
    std::vector<caffe::Datum>::const_iterator _dt_vit;
    std::vector<std::pair<std::string, std::string>> _segmentation_data_lines;
    int _dt_seg = 0;
    bool _align = false;
  };
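
  /* Prediction-mode sketch (comment only): a rough call sequence for the
     image connector, assuming `ad` is an APIData request that already holds
     the image URIs and the usual "parameters.input" object, and that the
     model repository path has been set by the service:

       ImgCaffeInputFileConn conn;
       conn.init(ad.getobj("parameters").getobj("input"));
       conn._train = false;
       conn.transform(ad); // fills _dv_test with one Datum per input image,
                           // mean-subtracted if a mean file or scalar is set

     Member availability (e.g. _train, _model_repo) comes from the parent
     input connector classes; treat this as an illustration, not the actual
     service code path. */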
  /**
   * \brief Caffe CSV connector
   * \note use 'label_offset' in API to make sure that labels start at 0
   */
  class CSVCaffeInputFileConn;

  class DDCCsv
  {
  public:
    DDCCsv()
    {
    }
    ~DDCCsv()
    {
    }

    int read_file(const std::string &fname);
    int read_db(const std::string &fname);
    int read_mem(const std::string &content);
    int read_dir(const std::string &dir)
    {
      throw InputConnectorBadParamException(
          "uri " + dir + " is a directory, requires a CSV file");
    }

    CSVCaffeInputFileConn *_cifc = nullptr;
    APIData _adconf;
    std::shared_ptr<spdlog::logger> _logger;
  };
  class CSVCaffeInputFileConn : public CSVInputFileConn,
                                public CaffeInputInterface
  {
  public:
    CSVCaffeInputFileConn() : CSVInputFileConn()
    {
      reset_dv_test();
    }
    CSVCaffeInputFileConn(const CSVCaffeInputFileConn &i)
        : CSVInputFileConn(i), CaffeInputInterface(i)
    {
    }
    ~CSVCaffeInputFileConn()
    {
    }

    void init(const APIData &ad)
    {
      CSVInputFileConn::init(ad);
    }

    // size of each element in Caffe jargon
    int channels() const
    {
      if (_channels > 0)
        return _channels;
      return feature_size();
    }
    int height() const
    {
      return 1;
    }
    int width() const
    {
      return 1;
    }
    int batch_size() const
    {
      if (_db_batchsize > 0)
        return _db_batchsize;
      else
        return _dv.size();
    }
    int test_batch_size() const
    {
      if (_db_testbatchsize > 0)
        return _db_testbatchsize;
      else
        return _dv_test.size();
    }

    virtual void add_train_csvline(const std::string &id,
                                   std::vector<double> &vals);
    virtual void add_test_csvline(const std::string &id,
                                  std::vector<double> &vals);
    void transform(const APIData &ad)
    {
      APIData ad_param = ad.getobj("parameters");
      APIData ad_input = ad_param.getobj("input");
      APIData ad_mllib = ad_param.getobj("mllib");
      if (_train && ad_input.has("db") && ad_input.get("db").get<bool>())
        {
          _dbfullname = _model_repo + "/" + _dbfullname;
          _test_dbfullname = _model_repo + "/" + _test_dbfullname;
          fillup_parameters(ad_input);
          get_data(ad);
          _db = true;
          csv_to_db(_model_repo + "/" + _dbname,
                    _model_repo + "/" + _test_dbname, ad_input);
          write_class_weights(_model_repo, ad_mllib);
          // enrich data object with db files location
          APIData dbad;
          dbad.add("train_db", _dbfullname);
          if (_test_split > 0.0)
            dbad.add("test_db", _test_dbfullname);
          const_cast<APIData &>(ad).add("db", dbad);
        }
      else
        {
          try
            {
              CSVInputFileConn::transform(ad);
            }
          catch (std::exception &e)
            {
              throw;
            }
          // transform to datum by filling up float_data
          if (_train)
            {
              auto hit = _csvdata.begin();
              while (hit != _csvdata.end())
                {
                  if (_label.size() == 1)
                    _dv.push_back(to_datum((*hit)._v));
                  else // multi labels or autoencoder
                    {
                      caffe::Datum dat = to_datum((*hit)._v, true);
                      for (size_t i = 0; i < _label_pos.size();
                           i++) // concat labels and slice them out in the
                                // network itself
                        {
                          dat.add_float_data(static_cast<float>(
                              (*hit)._v.at(_label_pos[i])));
                        }
                      dat.set_channels(dat.channels() + _label.size());
                      _dv.push_back(dat);
                    }
                  this->_ids.push_back((*hit)._str);
                  ++hit;
                }
            }
          if (!_train)
            {
              if (!_db_fname.empty())
                {
                  _test_dbfullname = _db_fname;
                  _db = true;
                  return; // done
                }
              _csvdata_test = std::move(_csvdata);
            }
          else
            _csvdata.clear();
          auto hit = _csvdata_test.begin();
          while (hit != _csvdata_test.end())
            {
              // no ids taken on the test set
              if (_label.size() == 1)
                _dv_test.push_back(to_datum((*hit)._v));
              else
                {
                  caffe::Datum dat = to_datum((*hit)._v, true);
                  for (size_t i = 0; i < _label_pos.size(); i++)
                    {
                      dat.add_float_data(
                          static_cast<float>((*hit)._v.at(_label_pos[i])));
                    }
                  dat.set_channels(dat.channels() + _label.size());
                  _dv_test.push_back(dat);
                }
              if (!_train)
                this->_ids.push_back((*hit)._str);
              ++hit;
            }
          _csvdata_test.clear();
        }
      _csvdata_test.clear();
    }
    std::vector<caffe::Datum> get_dv_test(const int &num,
                                          const bool &has_mean_file)
    {
      (void)has_mean_file;
      if (!_db)
        {
          int i = 0;
          std::vector<caffe::Datum> dv;
          while (_dt_vit != _dv_test.end() && i < num)
            {
              dv.push_back((*_dt_vit));
              ++i;
              ++_dt_vit;
            }
          return dv;
        }
      else
        return get_dv_test_db(num);
    }

    std::vector<caffe::Datum> get_dv_test_db(const int &num);

    void reset_dv_test();
    /**
     * \brief turns a vector of values into a Caffe Datum structure
     * @param vf vector of values
     * @param multi_label whether the datum is built for a multi-label /
     *        autoencoder setup, in which case no single label is set here
     *        (the caller appends the label values to float_data)
     * @return datum
     */
    caffe::Datum to_datum(const std::vector<double> &vf,
                          const bool &multi_label = false)
    {
      caffe::Datum datum;
      int datum_channels = vf.size();
      if (!_label.empty())
        datum_channels -= _label.size();
      if (!_id.empty())
        datum_channels--;
      datum.set_channels(datum_channels);
      datum.set_height(1);
      datum.set_width(1);
      auto lit = _columns.begin();
      for (int i = 0; i < (int)vf.size(); i++)
        {
          if (!multi_label && !this->_label.empty() && i == _label_pos[0])
            {
              datum.set_label(
                  static_cast<float>(vf.at(i) + this->_label_offset[0]));
            }
          else if (i == _id_pos)
            {
              ++lit;
              continue;
            }
          else if (std::find(_label_pos.begin(), _label_pos.end(), i)
                   == _label_pos.end()) // XXX: could do a faster lookup
            {
              datum.add_float_data(static_cast<float>(vf.at(i)));
            }
          ++lit;
        }
      return datum;
    }
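
    /* Worked example (comment only): for a hypothetical CSV with columns
       "id,f1,f2,label", _id = "id", _label = { "label" } and
       _label_offset = { 0 }, the row { 7, 0.5, 1.2, 3 } becomes a Datum with
       channels = 2 (id and label columns excluded), height = width = 1,
       float_data = { 0.5, 1.2 } and label = 3. With multi_label = true the
       label column is skipped here and appended to float_data by the caller
       (see transform() above). */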
  private:
    int csv_to_db(const std::string &traindbname,
                  const std::string &testdbname, const APIData &ad_input,
                  const std::string &backend = "lmdb"); // lmdb, leveldb

    void write_csvline_to_db(const std::string &dbfullname,
                             const std::string &testdbfullname,
                             const APIData &ad_input,
                             const std::string &backend = "lmdb");

  public:
    std::vector<caffe::Datum>::const_iterator _dt_vit;
    int _db_batchsize = -1;
    int _db_testbatchsize = -1;
    std::unique_ptr<caffe::db::DB> _test_db;
    std::unique_ptr<caffe::db::Cursor> _test_db_cursor;
    std::string _dbname = "train";
    std::string _test_dbname = "test";
    std::string _correspname = "corresp.txt";

  private:
    std::unique_ptr<caffe::db::Transaction> _txn;
    std::unique_ptr<caffe::db::DB> _tdb;
    std::unique_ptr<caffe::db::Transaction> _ttxn;
    std::unique_ptr<caffe::db::DB> _ttdb;
    int _channels = 0;
  };
  /**
   * \brief Caffe CSV timeseries connector
   */
  class CSVTSCaffeInputFileConn;

  class DDCCsvTS
  {
  public:
    DDCCsvTS()
    {
    }
    ~DDCCsvTS()
    {
    }

    int read_file(const std::string &fname, bool is_test_data = false);
    int read_db(const std::string &fname);
    int read_mem(const std::string &content);
    int read_dir(const std::string &dir);

    DDCsvTS _ddcsvts;
    CSVTSCaffeInputFileConn *_cifc = nullptr;
    APIData _adconf;
    std::shared_ptr<spdlog::logger> _logger;
  };
  class CSVTSCaffeInputFileConn : public CSVTSInputFileConn,
                                  public CaffeInputInterface
  {
  public:
    CSVTSCaffeInputFileConn()
        : CSVTSInputFileConn(), _dv_index(-1), _dv_test_index(-1),
          _continuation(false), _offset(100)
    {
      reset_dv_test();
    }
    CSVTSCaffeInputFileConn(const CSVTSCaffeInputFileConn &i)
        : CSVTSInputFileConn(i), CaffeInputInterface(i),
          _dv_index(i._dv_index), _dv_test_index(i._dv_test_index),
          _continuation(i._continuation), _offset(i._offset)
    {
      this->_datadim = i._datadim;
    }
    ~CSVTSCaffeInputFileConn()
    {
    }

    void init(const APIData &ad)
    {
      fillup_parameters(ad);
    }

    void fillup_parameters(const APIData &ad_input)
    {
      CSVTSInputFileConn::fillup_parameters(ad_input);
      _ntargets = _label.size();
      _offset = _timesteps;
      if (ad_input.has("timesteps"))
        {
          _timesteps = ad_input.get("timesteps").get<int>();
          _offset = _timesteps;
        }
      if (ad_input.has("continuation"))
        _continuation = ad_input.get("continuation").get<bool>();
      if (ad_input.has("offset"))
        _offset = ad_input.get("offset").get<int>();
    }

    // size of each element in Caffe jargon
    int channels() const
    {
      return _timesteps;
    }
    int height() const
    {
      return _datadim;
    }
    int width() const
    {
      return 1;
    }
    int batch_size() const
    {
      if (_db_batchsize > 0)
        return _db_batchsize;
      else if (_dv.size() != 0)
        return _dv.size();
      else
        return 1;
    }
    int test_batch_size() const
    {
      if (_db_testbatchsize > 0)
        return _db_testbatchsize;
      else if (_dv_test.size() != 0)
        return _dv_test.size();
      else
        return 1;
    }
    void push_csv_to_csvts(bool is_test_data = false);
    void set_datadim(bool is_test_data = false);
    void transform(
        const APIData &ad); // calls CSVTSInputFileConn::transform and db stuff
    void reset_dv_test();
    std::vector<caffe::Datum> get_dv_test(const int &num,
                                          const bool &has_mean_file);
    std::vector<caffe::Datum> get_dv_test_db(const int &num);
    int csvts_to_db(const std::string &traindbname,
                    const std::string &testdbname, const APIData &ad_input,
                    const std::string &backend = "lmdb"); // lmdb, leveldb
    void csvts_to_dv(bool is_test_data = false, bool clear_dv_first = false,
                     bool clear_csvts_after = false, bool split_seqs = true,
                     bool first_is_cont = false);
    void dv_to_db(bool is_test_data = false);
    void write_csvts_to_db(const std::string &dbfullname,
                           const std::string &testdbfullname,
                           const APIData &ad_input,
                           const std::string &backend);

    std::vector<caffe::Datum>::const_iterator _dt_vit;
    int _dv_index;
    int _dv_test_index;
    int _db_batchsize = -1;
    int _db_testbatchsize = -1;
    std::unique_ptr<caffe::db::DB> _test_db;
    std::unique_ptr<caffe::db::Cursor> _test_db_cursor;
    std::string _dbname = "train";
    std::string _test_dbname = "test";
    std::string _correspname = "corresp.txt";
    bool _continuation;
    int _offset;

  private:
    std::unique_ptr<caffe::db::Transaction> _txn;
    std::unique_ptr<caffe::db::DB> _tdb;
    std::unique_ptr<caffe::db::Transaction> _ttxn;
    std::unique_ptr<caffe::db::DB> _ttdb;
    int _channels = 0;
  };
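
  /* Shape note (comment only): for timeseries the element shape exposed to
     Caffe is channels() = _timesteps, height() = _datadim, width() = 1,
     i.e. each datum is a window of _timesteps rows of _datadim values. How
     the input and target columns are packed into that window is decided by
     csvts_to_dv()/set_datadim() in the .cc file, so the exact value of
     _datadim for a given CSV is not fixed by this header. */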
  /**
   * \brief Caffe text connector
   */
  class TxtCaffeInputFileConn : public TxtInputFileConn,
                                public CaffeInputInterface
  {
  public:
    TxtCaffeInputFileConn() : TxtInputFileConn()
    {
      reset_dv_test();
    }
    TxtCaffeInputFileConn(const TxtCaffeInputFileConn &i)
        : TxtInputFileConn(i), CaffeInputInterface(i)
    {
    }
    ~TxtCaffeInputFileConn()
    {
    }

    void init(const APIData &ad)
    {
      TxtInputFileConn::init(ad);
      if (_characters)
        _flat1dconv = true;
      if (ad.has("sparse") && ad.get("sparse").get<bool>())
        _sparse = true;
      if (ad.has("embedding") && ad.get("embedding").get<bool>())
        _embed = true;
      _sequence_txt = _sequence;
      _max_embed_id = _alphabet.size() + 1; // +1 as offset to null index
    }

    int channels() const
    {
      if (_characters)
        return 1;
      if (_embed)
        {
          if (!_characters)
            return _sequence;
          else
            return 1;
        }
      if (_channels > 0)
        return _channels;
      return feature_size();
    }
    int height() const
    {
      if (_characters)
        return _sequence;
      else
        return 1;
    }
    int width() const
    {
      if (_characters && !_embed)
        return _alphabet.size();
      return 1;
    }
    int batch_size() const
    {
      if (_db_batchsize > 0)
        return _db_batchsize;
      else if (!_sparse)
        return _dv.size();
      else
        return _dv_sparse.size();
    }
    int test_batch_size() const
    {
      if (_db_testbatchsize > 0)
        return _db_testbatchsize;
      else if (!_sparse)
        return _dv_test.size();
      else
        return _dv_test_sparse.size();
    }

    int txt_to_db(const std::string &traindbname,
                  const std::string &testdbname,
                  const std::string &backend = "lmdb");
    void write_txt_to_db(const std::string &dbname,
                         std::vector<TxtEntry<double> *> &txt,
                         const std::string &backend = "lmdb");
    void write_sparse_txt_to_db(const std::string &dbname,
                                std::vector<TxtEntry<double> *> &txt,
                                const std::string &backend = "lmdb");
    void transform(const APIData &ad)
    {
      APIData ad_param = ad.getobj("parameters");
      APIData ad_input = ad_param.getobj("input");
      APIData ad_mllib = ad_param.getobj("mllib");
      if (ad_input.has("db") && ad_input.get("db").get<bool>())
        _db = true;
      if (ad_input.has("embedding") && ad_input.get("embedding").get<bool>())
        {
          _embed = true;
        }
      // transform to one-hot vector datum
      if (_train && _db)
        {
          _dbfullname = _model_repo + "/" + _dbfullname;
          _test_dbfullname = _model_repo + "/" + _test_dbfullname;
          // std::string dbfullname = _model_repo + "/" + _dbname + ".lmdb";
          if (!fileops::file_exists(
                  _dbfullname)) // if no existing db, preprocess from txt files
            TxtInputFileConn::transform(ad);
          txt_to_db(_model_repo + "/" + _dbname,
                    _model_repo + "/" + _test_dbname);
          write_class_weights(_model_repo, ad_mllib);
          // enrich data object with db files location
          APIData dbad;
          dbad.add("train_db", _dbfullname);
          if (_test_split > 0.0)
            dbad.add("test_db", _test_dbfullname);
          const_cast<APIData &>(ad).add("db", dbad);
        }
      else
        {
          TxtInputFileConn::transform(ad);
          if (_train)
            {
              auto hit = _txt.begin();
              while (hit != _txt.end())
                {
                  if (!_sparse)
                    {
                      if (_characters)
                        _dv.push_back(std::move(to_datum<TxtCharEntry>(
                            static_cast<TxtCharEntry *>((*hit)))));
                      else
                        _dv.push_back(std::move(to_datum<TxtBowEntry>(
                            static_cast<TxtBowEntry *>((*hit)))));
                    }
                  else
                    {
                      if (_characters)
                        {
                          // TODO
                        }
                      else
                        _dv_sparse.push_back(std::move(to_sparse_datum(
                            static_cast<TxtBowEntry *>((*hit)))));
                    }
                  this->_ids.push_back((*hit)->_uri);
                  ++hit;
                }
            }
          if (!_train)
            {
              if (!_db_fname.empty())
                {
                  _test_dbfullname = _db_fname;
                  _db = true;
                  return; // done
                }
              _test_txt = std::move(_txt);
            }
          int n = 0;
          auto hit = _test_txt.begin();
          while (hit != _test_txt.end())
            {
              if (!_sparse)
                {
                  if (_characters)
                    _dv_test.push_back(std::move(to_datum<TxtCharEntry>(
                        static_cast<TxtCharEntry *>((*hit)))));
                  else
                    _dv_test.push_back(std::move(to_datum<TxtBowEntry>(
                        static_cast<TxtBowEntry *>((*hit)))));
                }
              else
                {
                  if (_characters)
                    {
                      // TODO
                    }
                  else
                    _dv_test_sparse.push_back(std::move(
                        to_sparse_datum(static_cast<TxtBowEntry *>((*hit)))));
                }
              if (!_train)
                this->_ids.push_back(std::to_string(n));
              ++hit;
              ++n;
            }
        }
    }
    std::vector<caffe::Datum> get_dv_test_db(const int &num);
    std::vector<caffe::SparseDatum> get_dv_test_sparse_db(const int &num);

    std::vector<caffe::Datum> get_dv_test(const int &num,
                                          const bool &has_mean_file)
    {
      (void)has_mean_file;
      if (!_db)
        {
          int i = 0;
          std::vector<caffe::Datum> dv;
          while (_dt_vit != _dv_test.end() && i < num)
            {
              dv.push_back((*_dt_vit));
              ++i;
              ++_dt_vit;
            }
          return dv;
        }
      else
        return get_dv_test_db(num);
    }

    std::vector<caffe::SparseDatum> get_dv_test_sparse(const int &num)
    {
      if (!_db)
        {
          int i = 0;
          std::vector<caffe::SparseDatum> dv;
          while (_dt_vit_sparse != _dv_test_sparse.end() && i < num)
            {
              dv.push_back((*_dt_vit_sparse));
              ++i;
              ++_dt_vit_sparse;
            }
          return dv;
        }
      else
        return get_dv_test_sparse_db(num);
    }

    void reset_dv_test()
    {
      if (!_sparse)
        _dt_vit = _dv_test.begin();
      else
        _dt_vit_sparse = _dv_test_sparse.begin();
      _test_db_cursor = std::unique_ptr<caffe::db::Cursor>();
      _test_db = std::unique_ptr<caffe::db::DB>();
    }
    template <class TEntry> caffe::Datum to_datum(TEntry *tbe)
    {
      caffe::Datum datum;
      int datum_channels;
      if (_characters)
        datum_channels = 1;
      else if (_embed && !_characters)
        datum_channels = _sequence;
      else
        datum_channels = _vocab.size(); // XXX: may be very large
      datum.set_channels(datum_channels);
      datum.set_height(1);
      datum.set_width(1);
      datum.set_label(tbe->_target);
      if (!_characters)
        {
          std::unordered_map<std::string, Word>::const_iterator wit;
          if (!_embed)
            {
              for (int i = 0; i < datum_channels;
                   i++) // XXX: expected to be slow
                datum.add_float_data(0.0);
              tbe->reset();
              while (tbe->has_elt())
                {
                  std::string key;
                  double val;
                  tbe->get_next_elt(key, val);
                  if ((wit = _vocab.find(key)) != _vocab.end())
                    datum.set_float_data(_vocab[key]._pos,
                                         static_cast<float>(val));
                }
            }
          else
            {
              tbe->reset();
              int i = 0;
              while (tbe->has_elt())
                {
                  std::string key;
                  double val;
                  tbe->get_next_elt(key, val);
                  if ((wit = _vocab.find(key)) != _vocab.end())
                    datum.add_float_data(
                        static_cast<float>(_vocab[key]._pos));
                  ++i;
                  if (i == _sequence) // tmp limit on sequence length
                    break;
                }
              while (datum.float_data_size() < _sequence)
                datum.add_float_data(0.0);
            }
        }
      else // character-level features
        {
          tbe->reset();
          std::vector<int> vals;
          std::unordered_map<uint32_t, int>::const_iterator whit;
          while (tbe->has_elt())
            {
              std::string key;
              double val = -1.0;
              tbe->get_next_elt(key, val);
              uint32_t c = std::strtoul(key.c_str(), 0, 10);
              if ((whit = _alphabet.find(c)) != _alphabet.end())
                vals.push_back((*whit).second);
              else
                vals.push_back(-1);
            }
          /*if (vals.size() > _sequence)
            std::cerr << "more characters than sequence / " << vals.size()
                      << " / sequence=" << _sequence << std::endl;*/
          if (!_embed)
            {
              for (int c = 0; c < _sequence; c++)
                {
                  std::vector<float> v(_alphabet.size(), 0.0);
                  if (c < (int)vals.size() && vals[c] != -1)
                    v[vals[c]] = 1.0;
                  for (float f : v)
                    datum.add_float_data(f);
                }
              datum.set_height(_sequence);
              datum.set_width(_alphabet.size());
            }
          else
            {
              for (int c = 0; c < _sequence; c++)
                {
                  double val = 0.0;
                  if (c < (int)vals.size() && vals[c] != -1)
                    val = static_cast<float>(
                        vals[c] + 1.0); // +1 as offset to null index
                  datum.add_float_data(val);
                }
              datum.set_height(_sequence);
              datum.set_width(1);
            }
        }
      return datum;
    }
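
    /* Shape note (comment only), derived from to_datum() above:
       - character one-hot (_characters && !_embed): height = _sequence,
         width = _alphabet.size(); e.g. with _sequence = 1014 and a 69-symbol
         alphabet, each datum holds 1014 x 69 float values (figures are an
         example, not defaults enforced here);
       - character embedding (_characters && _embed): height = _sequence,
         width = 1, values are alphabet indices shifted by +1 (0 = null);
       - word-level BoW (!_characters && !_embed): channels = _vocab.size(),
         one slot per vocabulary word;
       - word-level embedding: channels = _sequence, values are word indices.
    */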
    caffe::SparseDatum to_sparse_datum(TxtBowEntry *tbe)
    {
      caffe::SparseDatum datum;
      datum.set_label(tbe->_target);
      std::unordered_map<std::string, Word>::const_iterator wit;
      tbe->reset();
      int nwords = 0;
      while (tbe->has_elt())
        {
          std::string key;
          double val;
          tbe->get_next_elt(key, val);
          if ((wit = _vocab.find(key)) != _vocab.end())
            {
              int word_pos = _vocab[key]._pos;
              datum.add_data(static_cast<float>(val));
              datum.add_indices(word_pos);
              ++nwords;
            }
        }
      datum.set_nnz(nwords);
      datum.set_size(_vocab.size());
      return datum;
    }

    std::vector<caffe::Datum>::const_iterator _dt_vit;
    std::vector<caffe::SparseDatum>::const_iterator _dt_vit_sparse;

  public:
    int _db_batchsize = -1;
    int _db_testbatchsize = -1;
    std::unique_ptr<caffe::db::DB> _test_db;
    std::unique_ptr<caffe::db::Cursor> _test_db_cursor;
    std::string _dbname = "train";
    std::string _test_dbname = "test";
    int _channels = 0;
  };
  /**
   * \brief Caffe SVM connector
   */
  class SVMCaffeInputFileConn : public SVMInputFileConn,
                                public CaffeInputInterface
  {
  public:
    SVMCaffeInputFileConn() : SVMInputFileConn()
    {
      _sparse = true;
      reset_dv_test();
    }
    SVMCaffeInputFileConn(const SVMCaffeInputFileConn &i)
        : SVMInputFileConn(i), CaffeInputInterface(i)
    {
    }
    ~SVMCaffeInputFileConn()
    {
    }

    void init(const APIData &ad)
    {
      SVMInputFileConn::init(ad);
    }

    int channels() const
    {
      if (_channels > 0)
        return _channels;
      else
        return feature_size();
    }
    int height() const
    {
      return 1;
    }
    int width() const
    {
      return 1;
    }
    int batch_size() const
    {
      if (_db_batchsize > 0)
        return _db_batchsize;
      else
        return _dv_sparse.size();
    }
    int test_batch_size() const
    {
      if (_db_testbatchsize > 0)
        return _db_testbatchsize;
      else
        return _dv_test_sparse.size();
    }

    virtual void add_train_svmline(const int &label,
                                   const std::unordered_map<int, double> &vals,
                                   const int &count);
    virtual void add_test_svmline(const int &label,
                                  const std::unordered_map<int, double> &vals,
                                  const int &count);
    void transform(const APIData &ad)
    {
      APIData ad_param = ad.getobj("parameters");
      APIData ad_input = ad_param.getobj("input");
      APIData ad_mllib = ad_param.getobj("mllib");
      if (_train && ad_input.has("db") && ad_input.get("db").get<bool>())
        {
          _dbfullname = _model_repo + "/" + _dbfullname;
          _test_dbfullname = _model_repo + "/" + _test_dbfullname;
          fillup_parameters(ad_input);
          get_data(ad);
          _db = true;
          svm_to_db(_model_repo + "/" + _dbname,
                    _model_repo + "/" + _test_dbname, ad_input);
          write_class_weights(_model_repo, ad_mllib);
          // enrich data object with db files location
          APIData dbad;
          dbad.add("train_db", _dbfullname);
          if (_test_split > 0.0)
            dbad.add("test_db", _test_dbfullname);
          const_cast<APIData &>(ad).add("db", dbad);
          serialize_vocab();
        }
      else
        {
          _test_dbfullname = "";
          try
            {
              SVMInputFileConn::transform(ad);
            }
          catch (std::exception &e)
            {
              throw;
            }
          if (_train)
            {
              write_class_weights(_model_repo, ad_mllib);
              int n = 0;
              auto hit = _svmdata.begin();
              while (hit != _svmdata.end())
                {
                  _dv_sparse.push_back(to_sparse_datum((*hit)));
                  this->_ids.push_back(std::to_string(n));
                  ++n;
                  ++hit;
                }
            }
          if (!_train)
            {
              if (!_db_fname.empty())
                {
                  _test_dbfullname = _db_fname;
                  _db = true;
                  return; // done
                }
              _svmdata_test = std::move(_svmdata);
            }
          else
            _svmdata.clear();
          int n = 0;
          auto hit = _svmdata_test.begin();
          while (hit != _svmdata_test.end())
            {
              _dv_test_sparse.push_back(to_sparse_datum((*hit)));
              if (!_train)
                this->_ids.push_back(std::to_string(n));
              ++n;
              ++hit;
            }
        }
    }
    caffe::SparseDatum to_sparse_datum(const SVMline &svml)
    {
      caffe::SparseDatum datum;
      datum.set_label(svml._label);
      auto hit = svml._v.begin();
      int nelts = 0;
      while (hit != svml._v.end())
        {
          datum.add_data(static_cast<float>((*hit).second));
          datum.add_indices((*hit).first);
          ++nelts;
          ++hit;
        }
      datum.set_nnz(nelts);
      datum.set_size(channels());
      return datum;
    }
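
    /* Worked example (comment only): an SVM-format line such as
       "2 4:0.5 9:1.25", parsed into an SVMline with _label = 2 and
       _v = { {4, 0.5}, {9, 1.25} }, becomes a SparseDatum with
       data = { 0.5, 1.25 }, indices = { 4, 9 }, nnz = 2, label = 2 and
       size = channels(), i.e. the feature dimension. Note that the
       iteration order of the unordered_map is unspecified, so data/indices
       may come out in any order. */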
    std::vector<caffe::SparseDatum> get_dv_test_sparse_db(const int &num);

    std::vector<caffe::SparseDatum> get_dv_test_sparse(const int &num)
    {
      if (_test_dbfullname.empty())
        {
          int i = 0;
          std::vector<caffe::SparseDatum> dv;
          while (_dt_vit != _dv_test_sparse.end() && i < num)
            {
              dv.push_back((*_dt_vit));
              ++i;
              ++_dt_vit;
            }
          return dv;
        }
      else
        return get_dv_test_sparse_db(num);
    }

    void reset_dv_test();

  private:
    int svm_to_db(const std::string &traindbname,
                  const std::string &testdbname, const APIData &ad_input,
                  const std::string &backend = "lmdb"); // lmdb, leveldb
    void write_svmline_to_db(const std::string &dbfullname,
                             const std::string &testdbfullname,
                             const APIData &ad_input,
                             const std::string &backend = "lmdb");

  public:
    std::vector<caffe::SparseDatum>::const_iterator _dt_vit;
    int _db_batchsize = -1;
    int _db_testbatchsize = -1;
    std::unique_ptr<caffe::db::DB> _test_db;
    std::unique_ptr<caffe::db::Cursor> _test_db_cursor;
    std::string _dbname = "train";
    std::string _test_dbname = "test";

  private:
    std::unique_ptr<caffe::db::Transaction> _txn;
    std::unique_ptr<caffe::db::DB> _tdb;
    std::unique_ptr<caffe::db::Transaction> _ttxn;
    std::unique_ptr<caffe::db::DB> _ttdb;
    int _channels = 0;
  };
}

#endif