PageRenderTime 51ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/common/rpc/StreamRpcChannel.cpp

https://code.google.com/
C++ | 611 lines | 409 code | 98 blank | 104 comment | 71 complexity | 1e6bcb9dd8f932e229300561319b22b3 MD5 | raw file
Possible License(s): LGPL-2.1, GPL-2.0
  1. /*
  2. * This program is free software; you can redistribute it and/or modify
  3. * it under the terms of the GNU General Public License as published by
  4. * the Free Software Foundation; either version 2 of the License, or
  5. * (at your option) any later version.
  6. *
  7. * This program is distributed in the hope that it will be useful,
  8. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. * GNU Library General Public License for more details.
  11. *
  12. * You should have received a copy of the GNU General Public License
  13. * along with this program; if not, write to the Free Software
  14. * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
  15. *
  16. * StreamRpcChannel.cpp
  17. * Interface for the UDP RPC Channel
  18. * Copyright (C) 2005-2008 Simon Newton
  19. */
  20. #include <errno.h>
  21. #include <google/protobuf/service.h>
  22. #include <google/protobuf/message.h>
  23. #include <google/protobuf/descriptor.h>
  24. #include <google/protobuf/dynamic_message.h>
  25. #include <string>
  26. #include "common/rpc/Rpc.pb.h"
  27. #include "common/rpc/SimpleRpcController.h"
  28. #include "common/rpc/StreamRpcChannel.h"
  29. #include "ola/Callback.h"
  30. #include "ola/Logging.h"
  31. namespace ola {
  32. namespace rpc {
  33. using google::protobuf::ServiceDescriptor;
  // Name of the ExportMap map variable (label "type") that counts received
  // messages per RPC message type.
  34. const char StreamRpcChannel::K_RPC_RECEIVED_TYPE_VAR[] = "rpc-received-type";
  // Name of the counter incremented for every message parsed in HandleNewMsg().
  35. const char StreamRpcChannel::K_RPC_RECEIVED_VAR[] = "rpc-received";
  // Name of the counter incremented when SendMsg() fails to send a message.
  36. const char StreamRpcChannel::K_RPC_SENT_ERROR_VAR[] = "rpc-send-errors";
  // Name of the counter incremented when SendMsg() succeeds.
  37. const char StreamRpcChannel::K_RPC_SENT_VAR[] = "rpc-sent";
  // Output type name that marks a method as streaming (no response is sent).
  38. const char StreamRpcChannel::STREAMING_NO_RESPONSE[] = "STREAMING_NO_RESPONSE";
  39. StreamRpcChannel::StreamRpcChannel(
  40. Service *service,
  41. ola::io::ConnectedDescriptor *descriptor,
  42. ExportMap *export_map)
  43. : m_service(service),
  44. m_on_close(NULL),
  45. m_descriptor(descriptor),
  46. m_seq(0),
  47. m_buffer(NULL),
  48. m_buffer_size(0),
  49. m_expected_size(0),
  50. m_current_size(0),
  51. m_export_map(export_map),
  52. m_recv_type_map(NULL) {
  53. descriptor->SetOnData(
  54. ola::NewCallback(this, &StreamRpcChannel::DescriptorReady));
  55. // init the counters
  56. const char *vars[] = {
  57. K_RPC_RECEIVED_VAR,
  58. K_RPC_SENT_ERROR_VAR,
  59. K_RPC_SENT_VAR,
  60. };
  61. if (m_export_map) {
  62. for (unsigned int i = 0; i < sizeof(vars) / sizeof(vars[0]); ++i)
  63. m_export_map->GetCounterVar(string(vars[i]));
  64. m_recv_type_map = m_export_map->GetUIntMapVar(K_RPC_RECEIVED_TYPE_VAR,
  65. "type");
  66. }
  67. }
  68. StreamRpcChannel::~StreamRpcChannel() {
  69. if (m_on_close)
  70. delete m_on_close;
  71. free(m_buffer);
  72. }
  73. /*
  74. * Receive a message for this RPCChannel. Called when data is available on the
  75. * descriptor.
  76. */
  77. void StreamRpcChannel::DescriptorReady() {
  78. if (!m_expected_size) {
  79. // this is a new msg
  80. unsigned int version;
  81. if (ReadHeader(&version, &m_expected_size) < 0)
  82. return;
  83. if (!m_expected_size)
  84. return;
  85. if (version != PROTOCOL_VERSION) {
  86. OLA_WARN << "protocol mismatch " << version << " != " <<
  87. PROTOCOL_VERSION;
  88. return;
  89. }
  90. m_current_size = 0;
  91. m_buffer_size = AllocateMsgBuffer(m_expected_size);
  92. if (m_buffer_size < m_expected_size) {
  93. OLA_WARN << "buffer size to small " << m_buffer_size << " < " <<
  94. m_expected_size;
  95. return;
  96. }
  97. }
  98. unsigned int data_read;
  99. if (m_descriptor->Receive(m_buffer + m_current_size,
  100. m_expected_size - m_current_size,
  101. data_read) < 0) {
  102. OLA_WARN << "something went wrong in descriptor recv\n";
  103. return;
  104. }
  105. m_current_size += data_read;
  106. if (m_current_size == m_expected_size) {
  107. // we've got all of this message so parse it.
  108. if (!HandleNewMsg(m_buffer, m_expected_size)) {
  109. // this probably means we've messed the framing up, close the channel
  110. OLA_WARN << "Errors detected on RPC channel, closing";
  111. m_descriptor->Close();
  112. }
  113. m_expected_size = 0;
  114. }
  115. return;
  116. }
  117. /*
  118. * Set the Closure to be called if a write on this channel fails. This is
  119. * different from the Descriptor on close handler which is called when reads hit
  120. * EOF/
  121. */
  122. void StreamRpcChannel::SetOnClose(SingleUseCallback0<void> *closure) {
  123. if (closure != m_on_close) {
  124. delete m_on_close;
  125. m_on_close = closure;
  126. }
  127. }
  128. /*
  129. * Call a method with the given request and reply
  130. * TODO(simonn): reduce the number of copies here
  131. */
  132. void StreamRpcChannel::CallMethod(
  133. const MethodDescriptor *method,
  134. RpcController *controller,
  135. const Message *request,
  136. Message *reply,
  137. google::protobuf::Closure *done) {
  138. string output;
  139. RpcMessage message;
  140. bool is_streaming = false;
  141. // Streaming methods are those with a reply set to STREAMING_NO_RESPONSE and
  142. // no controller, request or closure provided
  143. if (method->output_type()->name() == STREAMING_NO_RESPONSE) {
  144. if (controller || reply || done) {
  145. OLA_FATAL << "Calling streaming method " << method->name() <<
  146. " but a controller, reply or closure in non-NULL";
  147. return;
  148. }
  149. is_streaming = true;
  150. }
  151. message.set_type(is_streaming ? STREAM_REQUEST : REQUEST);
  152. message.set_id(m_seq++);
  153. message.set_name(method->name());
  154. request->SerializeToString(&output);
  155. message.set_buffer(output);
  156. bool r = SendMsg(&message);
  157. if (is_streaming)
  158. return;
  159. if (!r) {
  160. // send failed, call the handler now
  161. controller->SetFailed("Failed to send request");
  162. done->Run();
  163. return;
  164. }
  165. OutstandingResponse *response = GetOutstandingResponse(message.id());
  166. if (response) {
  167. // fail any outstanding response with the same id
  168. OLA_WARN << "response " << response->id << " already pending, failing " <<
  169. "now";
  170. response->controller->SetFailed("Duplicate request found");
  171. InvokeCallbackAndCleanup(response);
  172. }
  173. response = new OutstandingResponse();
  174. response->id = message.id();
  175. response->controller = controller;
  176. response->callback = done;
  177. response->reply = reply;
  178. m_responses[message.id()] = response;
  179. }
  180. /*
  181. * Called when a response is ready.
  182. */
  183. void StreamRpcChannel::RequestComplete(OutstandingRequest *request) {
  184. string output;
  185. RpcMessage message;
  186. if (request->controller->Failed()) {
  187. SendRequestFailed(request);
  188. return;
  189. }
  190. message.set_type(RESPONSE);
  191. message.set_id(request->id);
  192. request->response->SerializeToString(&output);
  193. message.set_buffer(output);
  194. SendMsg(&message);
  195. DeleteOutstandingRequest(request);
  196. }
  197. // private
  198. //-----------------------------------------------------------------------------
  199. /*
  200. * Write an RpcMessage to the write descriptor.
  201. */
  202. bool StreamRpcChannel::SendMsg(RpcMessage *msg) {
  203. if (!m_descriptor->ValidReadDescriptor()) {
  204. OLA_WARN << "RPC descriptor closed, not sending messages";
  205. return false;
  206. }
  207. string output;
  208. msg->SerializeToString(&output);
  209. int length = output.length();
  210. uint32_t header;
  211. StreamRpcHeader::EncodeHeader(&header, PROTOCOL_VERSION, length);
  212. ssize_t ret = m_descriptor->Send(reinterpret_cast<const uint8_t*>(&header),
  213. sizeof(header));
  214. ret = m_descriptor->Send(reinterpret_cast<const uint8_t*>(output.data()),
  215. length);
  216. if (ret != length) {
  217. if (ret == -1)
  218. OLA_WARN << "Send failed " << strerror(errno);
  219. else
  220. OLA_WARN << "Failed to send full datagram, closing channel";
  221. // At the point framing is screwed and we should shut the channel down
  222. m_descriptor->Close();
  223. if (m_on_close)
  224. m_on_close->Run();
  225. if (m_export_map)
  226. (*m_export_map->GetCounterVar(K_RPC_SENT_ERROR_VAR))++;
  227. return false;
  228. }
  229. if (m_export_map)
  230. (*m_export_map->GetCounterVar(K_RPC_SENT_VAR))++;
  231. return true;
  232. }
  233. /*
  234. * Allocate an incomming message buffer
  235. * @param size the size of the new buffer to allocate
  236. * @returns the size of the new buffer
  237. */
  238. int StreamRpcChannel::AllocateMsgBuffer(unsigned int size) {
  239. unsigned int requested_size = size;
  240. uint8_t *new_buffer;
  241. if (size < m_buffer_size)
  242. return size;
  243. if (m_buffer_size == 0 && size < INITIAL_BUFFER_SIZE)
  244. requested_size = INITIAL_BUFFER_SIZE;
  245. if (requested_size > MAX_BUFFER_SIZE)
  246. return m_buffer_size;
  247. new_buffer = static_cast<uint8_t*>(realloc(m_buffer, requested_size));
  248. if (new_buffer < 0)
  249. return m_buffer_size;
  250. m_buffer = new_buffer;
  251. m_buffer_size = requested_size;
  252. return requested_size;
  253. }
  254. /*
  255. * Read 4 bytes and decode the header fields.
  256. * @returns: -1 if there is no data is available, version and size are 0
  257. */
  258. int StreamRpcChannel::ReadHeader(unsigned int *version,
  259. unsigned int *size) const {
  260. uint32_t header;
  261. unsigned int data_read = 0;
  262. *version = *size = 0;
  263. if (m_descriptor->Receive(reinterpret_cast<uint8_t*>(&header),
  264. sizeof(header), data_read)) {
  265. OLA_WARN << "read header error: " << strerror(errno);
  266. return -1;
  267. }
  268. if (!data_read)
  269. return 0;
  270. StreamRpcHeader::DecodeHeader(header, version, size);
  271. return 0;
  272. }
  273. /*
  274. * Parse a new message and handle it.
  275. */
  276. bool StreamRpcChannel::HandleNewMsg(uint8_t *data, unsigned int size) {
  277. RpcMessage msg;
  278. if (!msg.ParseFromArray(data, size)) {
  279. OLA_WARN << "Failed to parse RPC";
  280. return false;
  281. }
  282. if (m_export_map)
  283. (*m_export_map->GetCounterVar(K_RPC_RECEIVED_VAR))++;
  284. switch (msg.type()) {
  285. case REQUEST:
  286. if (m_recv_type_map)
  287. (*m_recv_type_map)["request"]++;
  288. HandleRequest(&msg);
  289. break;
  290. case RESPONSE:
  291. if (m_recv_type_map)
  292. (*m_recv_type_map)["response"]++;
  293. HandleResponse(&msg);
  294. break;
  295. case RESPONSE_CANCEL:
  296. if (m_recv_type_map)
  297. (*m_recv_type_map)["cancelled"]++;
  298. HandleCanceledResponse(&msg);
  299. break;
  300. case RESPONSE_FAILED:
  301. if (m_recv_type_map)
  302. (*m_recv_type_map)["failed"]++;
  303. HandleFailedResponse(&msg);
  304. break;
  305. case RESPONSE_NOT_IMPLEMENTED:
  306. if (m_recv_type_map)
  307. (*m_recv_type_map)["not-implemented"]++;
  308. HandleNotImplemented(&msg);
  309. break;
  310. case STREAM_REQUEST:
  311. if (m_recv_type_map)
  312. (*m_recv_type_map)["stream_request"]++;
  313. HandleStreamRequest(&msg);
  314. break;
  315. default:
  316. OLA_WARN << "not sure of msg type " << msg.type();
  317. break;
  318. }
  319. return true;
  320. }
  321. /*
  322. * Handle a new RPC method call.
  323. */
  324. void StreamRpcChannel::HandleRequest(RpcMessage *msg) {
  325. if (!m_service) {
  326. OLA_WARN << "no service registered";
  327. return;
  328. }
  329. const ServiceDescriptor *service = m_service->GetDescriptor();
  330. if (!service) {
  331. OLA_WARN << "failed to get service descriptor";
  332. return;
  333. }
  334. const MethodDescriptor *method = service->FindMethodByName(msg->name());
  335. if (!method) {
  336. OLA_WARN << "failed to get method descriptor";
  337. SendNotImplemented(msg->id());
  338. return;
  339. }
  340. Message* request_pb = m_service->GetRequestPrototype(method).New();
  341. Message* response_pb = m_service->GetResponsePrototype(method).New();
  342. if (!request_pb || !response_pb) {
  343. OLA_WARN << "failed to get request or response objects";
  344. return;
  345. }
  346. if (!request_pb->ParseFromString(msg->buffer())) {
  347. OLA_WARN << "parsing of request pb failed";
  348. return;
  349. }
  350. OutstandingRequest *request = new OutstandingRequest();
  351. request->id = msg->id();
  352. request->controller = new SimpleRpcController();
  353. request->response = response_pb;
  354. if (m_requests.find(msg->id()) != m_requests.end()) {
  355. OLA_WARN << "dup sequence number for request " << msg->id();
  356. SendRequestFailed(m_requests[msg->id()]);
  357. }
  358. m_requests[msg->id()] = request;
  359. google::protobuf::Closure *callback = NewCallback(
  360. this, &StreamRpcChannel::RequestComplete, request);
  361. m_service->CallMethod(method, request->controller, request_pb, response_pb,
  362. callback);
  363. delete request_pb;
  364. }
  365. /*
  366. * Handle a streaming RPC call. This doesn't return any response to the client.
  367. */
  368. void StreamRpcChannel::HandleStreamRequest(RpcMessage *msg) {
  369. if (!m_service) {
  370. OLA_WARN << "no service registered";
  371. return;
  372. }
  373. const ServiceDescriptor *service = m_service->GetDescriptor();
  374. if (!service) {
  375. OLA_WARN << "failed to get service descriptor";
  376. return;
  377. }
  378. const MethodDescriptor *method = service->FindMethodByName(msg->name());
  379. if (!method) {
  380. OLA_WARN << "failed to get method descriptor";
  381. SendNotImplemented(msg->id());
  382. return;
  383. }
  384. if (method->output_type()->name() != STREAMING_NO_RESPONSE) {
  385. OLA_WARN << "Streaming request recieved for " << method->name() <<
  386. ", but the output type isn't STREAMING_NO_RESPONSE";
  387. return;
  388. }
  389. Message* request_pb = m_service->GetRequestPrototype(method).New();
  390. if (!request_pb) {
  391. OLA_WARN << "failed to get request or response objects";
  392. return;
  393. }
  394. if (!request_pb->ParseFromString(msg->buffer())) {
  395. OLA_WARN << "parsing of request pb failed";
  396. return;
  397. }
  398. m_service->CallMethod(method, NULL, request_pb, NULL, NULL);
  399. delete request_pb;
  400. }
  401. // server side
  402. /*
  403. * Notify the caller that the request failed.
  404. */
  405. void StreamRpcChannel::SendRequestFailed(OutstandingRequest *request) {
  406. RpcMessage message;
  407. message.set_type(RESPONSE_FAILED);
  408. message.set_id(request->id);
  409. message.set_buffer(request->controller->ErrorText());
  410. SendMsg(&message);
  411. DeleteOutstandingRequest(request);
  412. }
  413. /*
  414. * Sent if we get a request for a non-existant method.
  415. */
  416. void StreamRpcChannel::SendNotImplemented(int msg_id) {
  417. RpcMessage message;
  418. message.set_type(RESPONSE_NOT_IMPLEMENTED);
  419. message.set_id(msg_id);
  420. SendMsg(&message);
  421. }
  422. /*
  423. * Cleanup an outstanding request after the response has been returned
  424. */
  425. void StreamRpcChannel::DeleteOutstandingRequest(OutstandingRequest *request) {
  426. m_requests.erase(request->id);
  427. delete request->controller;
  428. delete request->response;
  429. delete request;
  430. }
  431. // client side methods
  432. /*
  433. * Handle a RPC response by invoking the callback.
  434. */
  435. void StreamRpcChannel::HandleResponse(RpcMessage *msg) {
  436. OutstandingResponse *response = GetOutstandingResponse(msg->id());
  437. if (response) {
  438. response->reply->ParseFromString(msg->buffer());
  439. InvokeCallbackAndCleanup(response);
  440. }
  441. }
  442. /*
  443. * Handle a RPC response by invoking the callback.
  444. */
  445. void StreamRpcChannel::HandleFailedResponse(RpcMessage *msg) {
  446. OutstandingResponse *response = GetOutstandingResponse(msg->id());
  447. if (response) {
  448. response->controller->SetFailed(msg->buffer());
  449. InvokeCallbackAndCleanup(response);
  450. }
  451. }
  452. /*
  453. * Handle a RPC response by invoking the callback.
  454. */
  455. void StreamRpcChannel::HandleCanceledResponse(RpcMessage *msg) {
  456. OLA_INFO << "Received a canceled response";
  457. OutstandingResponse *response = GetOutstandingResponse(msg->id());
  458. if (response) {
  459. response->controller->SetFailed(msg->buffer());
  460. InvokeCallbackAndCleanup(response);
  461. }
  462. }
  463. /*
  464. * Handle a NOT_IMPLEMENTED by invoking the callback.
  465. */
  466. void StreamRpcChannel::HandleNotImplemented(RpcMessage *msg) {
  467. OLA_INFO << "Received a non-implemented response";
  468. OutstandingResponse *response = GetOutstandingResponse(msg->id());
  469. if (response) {
  470. response->controller->SetFailed("Not Implemented");
  471. InvokeCallbackAndCleanup(response);
  472. }
  473. }
  474. /*
  475. * Find the outstanding response with id msg_id.
  476. */
  477. OutstandingResponse *StreamRpcChannel::GetOutstandingResponse(int msg_id) {
  478. if (m_responses.find(msg_id) != m_responses.end()) {
  479. return m_responses[msg_id];
  480. }
  481. return NULL;
  482. }
  483. /*
  484. * Run the callback for a request.
  485. */
  486. void StreamRpcChannel::InvokeCallbackAndCleanup(OutstandingResponse *response) {
  487. if (response) {
  488. int id = response->id;
  489. response->callback->Run();
  490. delete response;
  491. m_responses.erase(id);
  492. }
  493. }
  494. // StreamRpcHeader
  495. //--------------------------------------------------------------
  496. /**
  497. * Encode a header
  498. */
  499. void StreamRpcHeader::EncodeHeader(uint32_t *header, unsigned int version,
  500. unsigned int size) {
  501. *header = (version << 28) & VERSION_MASK;
  502. *header |= size & SIZE_MASK;
  503. }
  504. /**
  505. * Decode a header
  506. */
  507. void StreamRpcHeader::DecodeHeader(uint32_t header, unsigned int *version,
  508. unsigned int *size) {
  509. *version = (header & VERSION_MASK) >> 28;
  510. *size = header & SIZE_MASK;
  511. }
  512. } // rpc
  513. } // ola