
/tensorflow_serving/model_servers/main.cc

https://gitlab.com/vectorci/serving
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// gRPC server implementation of
// tensorflow_serving/apis/prediction_service.proto.
//
// It brings up a standard server to serve a single TensorFlow model using
// command line flags, or multiple models via a config file.
//
// ModelServer prioritizes easy invocation over flexibility, and thus serves a
// statically configured set of models. New versions of these models are
// loaded and managed over time by the AspiredVersionsManager at:
// tensorflow_serving/core/aspired_versions_manager.h
// using the EagerLoadPolicy at:
// tensorflow_serving/core/eager_load_policy.h
//
// ModelServer has built-in inter-request batching support, using the
// BatchingSession at:
// tensorflow_serving/batching/batching_session.h
//
// To serve a single model, run with:
//     $path_to_binary/tensorflow_model_server \
//         --model_base_path=[/tmp/my_model | gs://gcs_address]
// To specify the model name (default "default"): --model_name=my_name
// To specify the port (default 8500): --port=my_port
// To enable batching (default disabled): --enable_batching
// To log to stderr (default disabled): --alsologtostderr
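//
// Illustrative client sketch (not part of this file): a minimal C++ gRPC
// client for the Predict RPC exposed by this server. The model name
// "my_name", the input key "x", and the localhost port are placeholders, and
// the PredictRequest field names (model_spec, inputs) are assumptions based
// on tensorflow_serving/apis/predict.proto.
//
//   #include "grpc++/client_context.h"
//   #include "grpc++/create_channel.h"
//   #include "grpc++/security/credentials.h"
//   #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
//
//   int main() {
//     auto channel = grpc::CreateChannel("localhost:8500",
//                                        grpc::InsecureChannelCredentials());
//     auto stub = tensorflow::serving::PredictionService::NewStub(channel);
//     tensorflow::serving::PredictRequest request;
//     request.mutable_model_spec()->set_name("my_name");
//     // Fill (*request.mutable_inputs())["x"] with a TensorProto input.
//     tensorflow::serving::PredictResponse response;
//     grpc::ClientContext context;
//     grpc::Status status = stub->Predict(&context, request, &response);
//     return status.ok() ? 0 : 1;
//   }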
#include <unistd.h>
#include <iostream>
#include <memory>
#include <utility>

#include "google/protobuf/wrappers.pb.h"
#include "grpc++/security/server_credentials.h"
#include "grpc++/server.h"
#include "grpc++/server_builder.h"
#include "grpc++/server_context.h"
#include "grpc++/support/status.h"
#include "grpc++/support/status_code_enum.h"
#include "grpc/grpc.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tensorflow_serving/apis/prediction_service.pb.h"
#include "tensorflow_serving/config/model_server_config.pb.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/predict_impl.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_source_adapter.h"

using tensorflow::serving::BatchingParameters;
using tensorflow::serving::EventBus;
using tensorflow::serving::Loader;
using tensorflow::serving::ModelServerConfig;
using tensorflow::serving::ServableState;
using tensorflow::serving::ServableStateMonitor;
using tensorflow::serving::ServerCore;
using tensorflow::serving::SessionBundleSourceAdapter;
using tensorflow::serving::SessionBundleSourceAdapterConfig;
using tensorflow::serving::Target;
using tensorflow::serving::TensorflowPredictImpl;
using tensorflow::serving::UniquePtrWithDeps;
using tensorflow::string;
using grpc::InsecureServerCredentials;
using grpc::Server;
using grpc::ServerAsyncResponseWriter;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::ServerCompletionQueue;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;
namespace {

constexpr char kTensorFlowModelType[] = "tensorflow";

tensorflow::Status CreateSourceAdapter(
    const SessionBundleSourceAdapterConfig& config, const string& model_type,
    std::unique_ptr<ServerCore::ModelServerSourceAdapter>* adapter) {
  CHECK(model_type == kTensorFlowModelType)  // Crash ok
      << "ModelServer supports only TensorFlow models.";
  std::unique_ptr<SessionBundleSourceAdapter> typed_adapter;
  TF_RETURN_IF_ERROR(
      SessionBundleSourceAdapter::Create(config, &typed_adapter));
  *adapter = std::move(typed_adapter);
  return tensorflow::Status::OK();
}

tensorflow::Status CreateServableStateMonitor(
    EventBus<ServableState>* event_bus,
    std::unique_ptr<ServableStateMonitor>* monitor) {
  *monitor = nullptr;
  return tensorflow::Status::OK();
}

tensorflow::Status LoadDynamicModelConfig(
    const ::google::protobuf::Any& any,
    Target<std::unique_ptr<Loader>>* target) {
  CHECK(false)  // Crash ok
      << "ModelServer does not yet support dynamic model config.";
}
ModelServerConfig BuildSingleModelConfig(const string& model_name,
                                         const string& model_base_path) {
  ModelServerConfig config;
  LOG(INFO) << "Building single TensorFlow model file config: "
            << " model_name: " << model_name
            << " model_base_path: " << model_base_path;
  tensorflow::serving::ModelConfig* single_model =
      config.mutable_model_config_list()->add_config();
  single_model->set_name(model_name);
  single_model->set_base_path(model_base_path);
  single_model->set_model_type(kTensorFlowModelType);
  return config;
}

grpc::Status ToGRPCStatus(const tensorflow::Status& status) {
  return grpc::Status(static_cast<grpc::StatusCode>(status.code()),
                      status.error_message());
}
class PredictionServiceImpl final : public PredictionService::Service {
 public:
  explicit PredictionServiceImpl(std::unique_ptr<ServerCore> core)
      : core_(std::move(core)) {}

  grpc::Status Predict(ServerContext* context, const PredictRequest* request,
                       PredictResponse* response) override {
    return ToGRPCStatus(
        TensorflowPredictImpl::Predict(core_.get(), *request, response));
  }

 private:
  std::unique_ptr<ServerCore> core_;
};
void RunServer(int port, std::unique_ptr<ServerCore> core) {
  // "0.0.0.0" makes gRPC listen on all network interfaces.
  const string server_address = "0.0.0.0:" + std::to_string(port);
  PredictionServiceImpl service(std::move(core));
  ServerBuilder builder;
  std::shared_ptr<grpc::ServerCredentials> creds = InsecureServerCredentials();
  builder.AddListeningPort(server_address, creds);
  builder.RegisterService(&service);
  std::unique_ptr<Server> server(builder.BuildAndStart());
  LOG(INFO) << "Running ModelServer at " << server_address << " ...";
  server->Wait();
}
}  // namespace
int main(int argc, char** argv) {
  tensorflow::int32 port = 8500;
  bool enable_batching = false;
  tensorflow::string model_name = "default";
  tensorflow::string model_base_path;
  const bool parse_result = tensorflow::ParseFlags(
      &argc, argv, {tensorflow::Flag("port", &port),
                    tensorflow::Flag("enable_batching", &enable_batching),
                    tensorflow::Flag("model_name", &model_name),
                    tensorflow::Flag("model_base_path", &model_base_path)});
  if (!parse_result || model_base_path.empty()) {
    std::cout << "Usage: model_server"
              << " [--port=8500]"
              << " [--enable_batching]"
              << " [--model_name=my_name]"
              << " --model_base_path=/path/to/export" << std::endl;
    return -1;
  }
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  ModelServerConfig config =
      BuildSingleModelConfig(model_name, model_base_path);
  SessionBundleSourceAdapterConfig source_adapter_config;
  // Batching config
  if (enable_batching) {
    BatchingParameters* batching_parameters =
        source_adapter_config.mutable_config()->mutable_batching_parameters();
    batching_parameters->mutable_thread_pool_name()->set_value(
        "model_server_batch_threads");
  }
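  // Illustrative note (not part of the original file): BatchingParameters
  // exposes further tuning fields that could be set in the block above; the
  // field names below (max_batch_size, batch_timeout_micros) are assumptions
  // about the batching config proto and may differ between versions, e.g.:
  //   batching_parameters->mutable_max_batch_size()->set_value(128);
  //   batching_parameters->mutable_batch_timeout_micros()->set_value(1000);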
  std::unique_ptr<ServerCore> core;
  TF_CHECK_OK(ServerCore::Create(
      config, std::bind(CreateSourceAdapter, source_adapter_config,
                        std::placeholders::_1, std::placeholders::_2),
      &CreateServableStateMonitor, &LoadDynamicModelConfig, &core));
  RunServer(port, std::move(core));

  return 0;
}