/deploy/src/segmentation/lib/src/selector.cpp

https://github.com/PRBonn/bonnetal · C++ · 68 lines · 36 code · 8 blank · 24 comment · 9 complexity · 85ebedc9927952290ece257591122dcc MD5 · raw file

  1. /* Copyright (c) 2019 Andres Milioto, Cyrill Stachniss, University of Bonn.
  2. *
  3. * This file is part of Bonnetal, and covered by the provided LICENSE file.
  4. *
  5. */
  6. // selective network library (conditional build)
  7. #include <selector.hpp>
  8. // Only to be used with segmentation
  9. namespace bonnetal {
  10. namespace segmentation {
  11. /**
  12. * @brief Makes a network with the desired backend, checking that it exists,
  13. * it is implemented, and that it was compiled.
  14. *
  15. * @param backend "pytorch, tensorrt"
  16. * @return std::unique_ptr<Net>
  17. */
  18. std::unique_ptr<Net> make_net(const std::string& path,
  19. const std::string& backend) {
  20. // these are the options
  21. std::vector<std::string> options = {"pytorch", "tensorrt"};
  22. // check that backend exists
  23. std::string lc_backend(backend); // get a copy
  24. std::transform(lc_backend.begin(), lc_backend.end(), lc_backend.begin(),
  25. ::tolower); // lowercase to allow user to be sloppy
  26. if (std::find(options.begin(), options.end(), lc_backend) == options.end()) {
  27. // not found
  28. std::cerr << "Backend must be one of the following: " << std::endl;
  29. for (auto& b : options) {
  30. std::cerr << b << std::endl;
  31. }
  32. throw std::runtime_error("Choose a valid backend");
  33. }
  34. // make a network
  35. std::unique_ptr<Net> network;
  36. // Select backend
  37. if (lc_backend == "pytorch") {
  38. #ifdef TORCH_FOUND
  39. // generate net with tf backend
  40. network = std::unique_ptr<Net>(new NetPytorch(path));
  41. #else
  42. // complain
  43. throw std::runtime_error("'pytorch' backend implemented but not built.");
  44. #endif
  45. } else if (lc_backend == "tensorrt") {
  46. #ifdef TENSORRT_FOUND
  47. // generate net with tf backend
  48. network = std::unique_ptr<Net>(new NetTensorRT(path));
  49. #else
  50. // complain
  51. throw std::runtime_error("'tensorrt' backend implemented but not built.");
  52. #endif
  53. } else {
  54. // Should't get here but just in case my logic fails (it mostly does)
  55. throw std::runtime_error(backend + " backend not implemented");
  56. }
  57. return network;
  58. }
  59. } // namespace segmentation
  60. } // namespace bonnetal