/src/selector.cpp

https://github.com/PRBonn/rangenet_lib · C++ · 57 lines · 28 code · 8 blank · 21 comment · 6 complexity · fe193ac3fed46295b0678046cfad9637 MD5 · raw file

  1. /* Copyright (c) 2019 Andres Milioto, Cyrill Stachniss, University of Bonn.
  2. *
  3. * This file is part of rangenet_lib, and covered by the provided LICENSE file.
  4. *
  5. */
  6. // selective network library (conditional build)
  7. #include <selector.hpp>
  8. // Only to be used with segmentation
  9. namespace rangenet {
  10. namespace segmentation {
  11. /**
  12. * @brief Makes a network with the desired backend, checking that it exists,
  13. * it is implemented, and that it was compiled.
  14. *
  15. * @param backend "pytorch, tensorrt"
  16. * @return std::unique_ptr<Net>
  17. */
  18. std::unique_ptr<Net> make_net(const std::string& path,
  19. const std::string& backend) {
  20. // these are the options
  21. std::vector<std::string> options = {"tensorrt"};
  22. // check that backend exists
  23. std::string lc_backend(backend); // get a copy
  24. std::transform(lc_backend.begin(), lc_backend.end(), lc_backend.begin(),
  25. ::tolower); // lowercase to allow user to be sloppy
  26. if (std::find(options.begin(), options.end(), lc_backend) == options.end()) {
  27. // not found
  28. std::cerr << "Backend must be one of the following: " << std::endl;
  29. for (auto& b : options) {
  30. std::cerr << b << std::endl;
  31. }
  32. throw std::runtime_error("Choose a valid backend");
  33. }
  34. // make a network
  35. std::unique_ptr<Net> network;
  36. // Select backend
  37. if (lc_backend == "tensorrt") {
  38. #ifdef TENSORRT_FOUND
  39. // generate net with tf backend
  40. network = std::unique_ptr<Net>(new NetTensorRT(path));
  41. #endif
  42. } else {
  43. // Should't get here but just in case my logic fails (it mostly does)
  44. throw std::runtime_error(backend + " backend not implemented");
  45. }
  46. return network;
  47. }
  48. } // namespace segmentation
  49. } // namespace rangenet