PageRenderTime 53ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/mordor/protobuf.cpp

http://github.com/mozy/mordor
C++ | 579 lines | 473 code | 84 blank | 22 comment | 100 complexity | 473563f1a4f94af425956ffbe7a59a5f MD5 | raw file
Possible License(s): BSD-3-Clause
  1. // Copyright (c) 2010 - Mozy, Inc.
  2. #include "protobuf.h"
  3. #include "mordor/assert.h"
  4. #include "mordor/streams/buffer.h"
  5. #include "mordor/string.h"
  6. #ifdef MSVC
  7. // Disable some warnings, but only while
  8. // processing the google generated code
  9. #pragma warning(push)
  10. #pragma warning(disable : 4244)
  11. #endif
  12. #include <boost/algorithm/string/regex.hpp>
  13. #include <boost/foreach.hpp>
  14. #include <boost/lexical_cast.hpp>
  15. #undef TYPE_BOOL // avoid collision from Mac OS X's ConditionalMacros.h
  16. #include <google/protobuf/descriptor.h>
  17. #include <google/protobuf/io/zero_copy_stream.h>
  18. #include <google/protobuf/message.h>
  19. #ifdef MSVC
  20. #pragma warning(pop)
  21. #endif
  22. using namespace google::protobuf;
  23. using namespace Mordor::JSON;
  24. #ifdef MSVC
  25. #ifdef _DEBUG
  26. #pragma comment(lib, "libprotobufd.lib")
  27. #else
  28. #pragma comment(lib, "libprotobuf.lib")
  29. #endif
  30. #endif
  31. namespace Mordor {
  32. class BufferZeroCopyInputStream : public io::ZeroCopyInputStream
  33. {
  34. public:
  35. BufferZeroCopyInputStream(const Buffer &buffer)
  36. : m_iovs(buffer.readBuffers()),
  37. m_currentIov(0),
  38. m_currentIovOffset(0),
  39. m_complete(0)
  40. {}
  41. bool Next(const void **data, int *size)
  42. {
  43. if (m_currentIov >= m_iovs.size())
  44. return false;
  45. MORDOR_ASSERT(m_currentIovOffset <= m_iovs[m_currentIov].iov_len);
  46. *data = (char *)m_iovs[m_currentIov].iov_base + m_currentIovOffset;
  47. *size = (int)(m_iovs[m_currentIov].iov_len - m_currentIovOffset);
  48. m_complete += *size;
  49. m_currentIovOffset = 0;
  50. ++m_currentIov;
  51. return true;
  52. }
  53. void BackUp(int count)
  54. {
  55. MORDOR_ASSERT(count >= 0);
  56. MORDOR_ASSERT(count <= m_complete);
  57. m_complete -= count;
  58. while (count) {
  59. if (m_currentIovOffset == 0) {
  60. MORDOR_ASSERT(m_currentIov > 0);
  61. m_currentIovOffset = m_iovs[--m_currentIov].iov_len;
  62. }
  63. size_t todo = (std::min)(m_currentIovOffset, (size_t)count);
  64. m_currentIovOffset -= todo;
  65. count -= (int)todo;
  66. }
  67. }
  68. bool Skip(int count) {
  69. MORDOR_ASSERT(count >= 0);
  70. while (count) {
  71. if (m_currentIov >= m_iovs.size())
  72. return false;
  73. size_t todo = (std::min)((size_t)m_iovs[m_currentIov].iov_len -
  74. m_currentIovOffset, (size_t)count);
  75. m_currentIovOffset += todo;
  76. count -= (int)todo;
  77. m_complete += todo;
  78. if (m_currentIovOffset == m_iovs[m_currentIov].iov_len) {
  79. m_currentIovOffset = 0;
  80. ++m_currentIov;
  81. }
  82. }
  83. return true;
  84. }
  85. int64 ByteCount() const { return m_complete; }
  86. private:
  87. std::vector<iovec> m_iovs;
  88. size_t m_currentIov;
  89. size_t m_currentIovOffset;
  90. int64 m_complete;
  91. };
  92. class BufferZeroCopyOutputStream : public io::ZeroCopyOutputStream
  93. {
  94. public:
  95. BufferZeroCopyOutputStream(Buffer &buffer, size_t bufferSize = 1024)
  96. : m_buffer(buffer),
  97. m_bufferSize(bufferSize),
  98. m_pendingProduce(0),
  99. m_total(0)
  100. {}
  101. ~BufferZeroCopyOutputStream()
  102. {
  103. m_buffer.produce(m_pendingProduce);
  104. }
  105. bool Next(void **data, int *size)
  106. {
  107. m_buffer.produce(m_pendingProduce);
  108. m_pendingProduce = 0;
  109. // TODO: protect against std::bad_alloc?
  110. iovec iov = m_buffer.writeBuffer(m_bufferSize, false);
  111. *data = iov.iov_base;
  112. m_total += m_pendingProduce = iov.iov_len;
  113. *size = (int)m_pendingProduce;
  114. return true;
  115. }
  116. void BackUp(int count)
  117. {
  118. MORDOR_ASSERT(count <= (int)m_pendingProduce);
  119. m_pendingProduce -= count;
  120. m_total -= count;
  121. }
  122. int64 ByteCount() const
  123. {
  124. return m_total;
  125. }
  126. private:
  127. Buffer &m_buffer;
  128. size_t m_bufferSize, m_pendingProduce;
  129. int64 m_total;
  130. };
  131. void serializeToBuffer(const Message &proto, Buffer &buffer)
  132. {
  133. BufferZeroCopyOutputStream stream(buffer);
  134. if (!proto.SerializeToZeroCopyStream(&stream))
  135. MORDOR_THROW_EXCEPTION(std::invalid_argument("proto"));
  136. }
  137. void parseFromBuffer(Message &proto, const Buffer &buffer)
  138. {
  139. BufferZeroCopyInputStream stream(buffer);
  140. if (!proto.ParseFromZeroCopyStream(&stream))
  141. MORDOR_THROW_EXCEPTION(std::invalid_argument("buffer"));
  142. }
  143. // begin anonymous namespace for reflection stuff
  144. namespace {
  145. void setFieldValue(Message *message, const FieldDescriptor *descriptor, const Value &fieldValue, int index=-1)
  146. {
  147. const Reflection* reflection = message->GetReflection();
  148. #define SET_FIELD(setter, type, proto_type, converter)\
  149. /*BOOST_STATIC_ASSERT(\
  150. (#setter == "Int32" && #type == "long long") ||\
  151. (#setter == "UInt32" && #type == "long long") ||\
  152. (#setter == "Int64" && #type == "long long") ||\
  153. (#setter == "UInt64" && #type == "long long") ||\
  154. (#setter == "Float" && #type == "double") ||\
  155. (#setter == "Double" && #type == "double") ||\
  156. (#setter == "Bool" && #type == "bool") ||\
  157. (#setter == "String" && #type == "std::string"));*/\
  158. if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {\
  159. type value = converter(boost::get<type>(fieldValue));\
  160. reflection->Set##setter(message, descriptor, (proto_type)value);\
  161. } else if (index >= 0) {\
  162. type value = converter(boost::get<type>(fieldValue));\
  163. reflection->SetRepeated##setter(message, descriptor, index, (proto_type)value);\
  164. } else {\
  165. const Array &array = boost::get<Array>(fieldValue);\
  166. BOOST_FOREACH(Value v, array) {\
  167. type value = converter(boost::get<type>(v));\
  168. reflection->Add##setter(message, descriptor, (proto_type)value);\
  169. }\
  170. }
  171. try {
  172. switch (descriptor->cpp_type()) {
  173. case FieldDescriptor::CPPTYPE_INT32:
  174. SET_FIELD(Int32, long long, int32_t, )
  175. break;
  176. case FieldDescriptor::CPPTYPE_INT64:
  177. SET_FIELD(Int64, long long, int64_t, )
  178. break;
  179. case FieldDescriptor::CPPTYPE_UINT32:
  180. SET_FIELD(UInt32, long long, uint32_t, )
  181. break;
  182. case FieldDescriptor::CPPTYPE_UINT64:
  183. SET_FIELD(UInt64, long long, uint64_t, )
  184. break;
  185. case FieldDescriptor::CPPTYPE_FLOAT:
  186. SET_FIELD(Float, double, float, )
  187. break;
  188. case FieldDescriptor::CPPTYPE_DOUBLE:
  189. SET_FIELD(Double, double, double, )
  190. break;
  191. case FieldDescriptor::CPPTYPE_BOOL:
  192. SET_FIELD(Bool, bool, bool, )
  193. break;
  194. case FieldDescriptor::CPPTYPE_STRING:
  195. // convert hexstring bytes to data
  196. if (descriptor->type() == FieldDescriptor::TYPE_BYTES) {
  197. SET_FIELD(String, std::string, std::string, dataFromHexstring)
  198. } else {
  199. SET_FIELD(String, std::string, std::string, )
  200. }
  201. break;
  202. case FieldDescriptor::CPPTYPE_ENUM:
  203. if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {
  204. std::string value = boost::get<std::string>(fieldValue);
  205. const EnumValueDescriptor * val = descriptor->enum_type()->FindValueByName(value);
  206. if (!val) {
  207. throw std::runtime_error("invalid enum " + value);
  208. }
  209. reflection->SetEnum(message, descriptor, val);
  210. } else {
  211. const Array &array = boost::get<Array>(fieldValue);
  212. BOOST_FOREACH(Value v, array) {
  213. std::string value = boost::get<std::string>(v);
  214. const EnumValueDescriptor * val = descriptor->enum_type()->FindValueByName(value);
  215. if (!val) {
  216. throw std::runtime_error("invalid enum " + value);
  217. }
  218. reflection->AddEnum(message, descriptor, val);
  219. }
  220. }
  221. break;
  222. default:
  223. MORDOR_NOTREACHED();
  224. }
  225. } catch (boost::bad_get &) {
  226. throw std::runtime_error(message->GetDescriptor()->name() + "." +
  227. descriptor->name() + " is invalid");
  228. }
  229. #undef SET_FIELD
  230. }
  231. Value getFieldValue(const Message &message, const FieldDescriptor *descriptor, int index=-1)
  232. {
  233. const Reflection *reflection = message.GetReflection();
  234. #define GET_FIELD(getter_type, type, converter)\
  235. if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {\
  236. if (reflection->HasField(message, descriptor)) {\
  237. type value = reflection->Get##getter_type(message, descriptor);\
  238. return converter(value);\
  239. } else {\
  240. return boost::blank();\
  241. }\
  242. } else if (index >= 0) {\
  243. type value = reflection->GetRepeated##getter_type(message, descriptor, index);\
  244. return converter(value);\
  245. } else {\
  246. Array array;\
  247. int field_size = reflection->FieldSize(message, descriptor);\
  248. for (int i = 0; i < field_size; i++) {\
  249. type value = reflection->GetRepeated##getter_type(message, descriptor, i);\
  250. array.push_back(converter(value));\
  251. }\
  252. return array;\
  253. }
  254. try {
  255. switch (descriptor->cpp_type()) {
  256. case FieldDescriptor::CPPTYPE_INT32:
  257. GET_FIELD(Int32, long long, )
  258. case FieldDescriptor::CPPTYPE_INT64:
  259. GET_FIELD(Int64, long long, )
  260. case FieldDescriptor::CPPTYPE_UINT32:
  261. GET_FIELD(UInt32, long long, )
  262. case FieldDescriptor::CPPTYPE_UINT64:
  263. GET_FIELD(UInt64, long long, )
  264. case FieldDescriptor::CPPTYPE_FLOAT:
  265. GET_FIELD(Float, double, )
  266. case FieldDescriptor::CPPTYPE_DOUBLE:
  267. GET_FIELD(Double, double, )
  268. case FieldDescriptor::CPPTYPE_BOOL:
  269. GET_FIELD(Bool, bool, )
  270. case FieldDescriptor::CPPTYPE_STRING:
  271. if (descriptor->type() == FieldDescriptor::TYPE_BYTES) {
  272. // convert to hexstring
  273. GET_FIELD(String, std::string, hexstringFromData);
  274. } else {
  275. GET_FIELD(String, std::string, )
  276. }
  277. case FieldDescriptor::CPPTYPE_ENUM:
  278. if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {\
  279. if (reflection->HasField(message, descriptor)) {
  280. const EnumValueDescriptor *value = reflection->GetEnum(message, descriptor);
  281. MORDOR_ASSERT(value);
  282. return value->name();
  283. }else {
  284. return boost::blank();
  285. }
  286. } else {
  287. Array array;
  288. int field_size = reflection->FieldSize(message, descriptor);
  289. for (int i = 0; i < field_size; i++) {
  290. const EnumValueDescriptor *value = reflection->GetRepeatedEnum(message, descriptor, i);
  291. array.push_back(value->name());
  292. }
  293. return array;
  294. }
  295. default:
  296. MORDOR_NOTREACHED();
  297. }
  298. } catch (boost::bad_get &) {
  299. throw std::runtime_error(message.GetDescriptor()->name() + "." + descriptor->name() + " is invalid");
  300. }
  301. #undef GET_FIELD
  302. }
  303. const FieldDescriptor *getFieldDescription(Message *msg, const std::string &fieldName, int &index)
  304. {
  305. static boost::regex index_regex("^([^ ]*)\\[([0-9]+)\\]$");
  306. const Reflection* reflection = msg->GetReflection();
  307. index = -1;
  308. size_t pos = fieldName.find('.');
  309. std::string field = fieldName.substr(0, pos);
  310. // check indexed field
  311. if (boost::regex_match(field, index_regex)) {
  312. std::string idx = field;
  313. boost::replace_all_regex(idx, index_regex, std::string("$2"));
  314. index = boost::lexical_cast<int>(idx);
  315. field = field.substr(0, field.find('['));
  316. }
  317. const FieldDescriptor *desc = msg->GetDescriptor()->FindFieldByName(field);
  318. if (desc == NULL) {
  319. throw std::runtime_error(field + " field cannot be found");
  320. }
  321. if (desc->label() != FieldDescriptor::LABEL_REPEATED && index != -1) {
  322. throw std::runtime_error(field + " is not repeated");
  323. }
  324. if (desc->label() == FieldDescriptor::LABEL_REPEATED && index == -1) {
  325. throw std::runtime_error(field + " should be repeated");
  326. }
  327. if (index >= 0 && index >= reflection->FieldSize(*msg, desc)) {
  328. throw std::runtime_error(field + " out of bound");
  329. }
  330. if (pos != std::string::npos && desc->type() != FieldDescriptor::TYPE_MESSAGE) {
  331. throw std::runtime_error(field + " is not be a sub-message");
  332. }
  333. if (pos == std::string::npos && desc->type() == FieldDescriptor::TYPE_MESSAGE) {
  334. throw std::runtime_error(field + " should be a sub-message");
  335. }
  336. return desc;
  337. }
  338. } // end of anonymous ns
  339. void serializeToJsonObject(const Message &message, Object &object, bool validate, bool includeNull)
  340. {
  341. const Reflection* reflection = message.GetReflection();
  342. for (int i = 0; i < message.GetDescriptor()->field_count(); i++) {
  343. const FieldDescriptor* descriptor = message.GetDescriptor()->field(i);
  344. std::string fieldName = descriptor->name();
  345. if (descriptor->type() == FieldDescriptor::TYPE_MESSAGE) {
  346. if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {
  347. Object msgObj;
  348. if (reflection->HasField(message, descriptor)) {
  349. const Message &subMessage = reflection->GetMessage(message, descriptor,
  350. MessageFactory::generated_factory());
  351. serializeToJsonObject(subMessage, msgObj, validate, includeNull);
  352. }
  353. else if (validate && descriptor->is_required()) {
  354. throw std::runtime_error(message.GetDescriptor()->name() + "." + descriptor->name() + " is required");
  355. }
  356. object[fieldName] = msgObj;
  357. } else {
  358. Array array;
  359. int field_size = reflection->FieldSize(message, descriptor);
  360. for (int i = 0; i < field_size; i++) {
  361. Object msgObj;
  362. const Message &subMessage = reflection->GetRepeatedMessage(message, descriptor, i);
  363. serializeToJsonObject(subMessage, msgObj, validate, includeNull);
  364. array.push_back(msgObj);
  365. }
  366. object[fieldName] = array;
  367. }
  368. continue;
  369. }
  370. // other value types
  371. Value value = getFieldValue(message, descriptor);
  372. if (includeNull || !value.isBlank()) {
  373. object[fieldName] = value;
  374. }
  375. }
  376. }
  377. void parseFromJsonObject(Message *message, const Object &object, bool validate)
  378. {
  379. const Reflection* reflection = message->GetReflection();
  380. MORDOR_ASSERT(reflection);
  381. for (int i = 0; i < message->GetDescriptor()->field_count(); i++) {
  382. const FieldDescriptor* descriptor = message->GetDescriptor()->field(i);
  383. std::string field_name = descriptor->name();
  384. MORDOR_ASSERT(descriptor);
  385. if (descriptor->type() == FieldDescriptor::TYPE_MESSAGE) {
  386. Object::const_iterator it = object.find(descriptor->name());
  387. if (it != object.end()) {
  388. if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {
  389. Message *subMessage = reflection->MutableMessage(message, descriptor,
  390. MessageFactory::generated_factory());
  391. const Object &value = boost::get<Object>(it->second);
  392. parseFromJsonObject(subMessage, value, validate);
  393. } else {
  394. const Array &array = boost::get<Array>(it->second);
  395. BOOST_FOREACH(Value v, array) {
  396. const Object &childObject = boost::get<Object>(v);
  397. Message *subMessage = reflection->AddMessage(message, descriptor,
  398. MessageFactory::generated_factory());
  399. parseFromJsonObject(subMessage, childObject, validate);
  400. }
  401. }
  402. }
  403. } else {
  404. Object::const_iterator it = object.find(field_name);
  405. if (it == object.end() || it->second.isBlank()) {
  406. if (validate && descriptor->is_required()) {
  407. throw std::runtime_error(message->GetDescriptor()->name() + "." + field_name + " is required");
  408. }
  409. continue; // skip optional field
  410. }
  411. setFieldValue(message, descriptor, it->second);
  412. }
  413. }
  414. }
  415. Message* forName(const std::string& typeName)
  416. {
  417. Message* message = NULL;
  418. const Descriptor* descriptor = DescriptorPool::generated_pool()->FindMessageTypeByName(typeName);
  419. if (descriptor)
  420. {
  421. const Message* prototype = MessageFactory::generated_factory()->GetPrototype(descriptor);
  422. if (prototype)
  423. {
  424. message = prototype->New();
  425. }
  426. }
  427. return message;
  428. }
  429. std::string toJson(const Message &message, bool validate, bool includeNull)
  430. {
  431. Object root, msgObj;
  432. serializeToJsonObject(message, msgObj, validate, includeNull);
  433. root[message.GetDescriptor()->full_name()] = msgObj;
  434. std::ostringstream os;
  435. os << root;
  436. return os.str();
  437. }
  438. Message * fromJson(const Buffer &buffer, bool validate)
  439. {
  440. Value root;
  441. Parser parser(root);
  442. parser.run(buffer);
  443. if (!parser.final() || parser.error()) {
  444. return NULL;
  445. }
  446. Object &rootObject = boost::get<Object>(root);
  447. std::string messageType = rootObject.begin()->first;
  448. Message *message = forName(messageType);
  449. if (!message) {
  450. return NULL;
  451. }
  452. Object &messageObject = boost::get<Object>(rootObject.begin()->second);
  453. parseFromJsonObject(message, messageObject, validate);
  454. return message;
  455. }
  456. // setField(partialObject, "store_dat.len", 100);
  457. // setField(namedObject, "fs_type", "DIRECTORY");
  458. void setField(Message *msg, const std::string &fieldName, const Value &value)
  459. {
  460. const Reflection* reflection = msg->GetReflection();
  461. int index;
  462. // get child property desc
  463. const FieldDescriptor *desc = getFieldDescription(msg, fieldName, index);
  464. MORDOR_ASSERT(desc);
  465. if (desc->type() == FieldDescriptor::TYPE_MESSAGE) {
  466. Message *subMsg = (index >=0) ?
  467. reflection->MutableRepeatedMessage(msg, desc, index) :
  468. reflection->MutableMessage(msg, desc);
  469. size_t pos = fieldName.find('.');
  470. MORDOR_ASSERT(pos != std::string::npos);
  471. std::string remaining = fieldName.substr(pos + 1);
  472. return setField(subMsg, remaining, value);
  473. }
  474. setFieldValue(msg, desc, value, index);
  475. }
  476. // getField(partialObject, "store_dat.frag[1].tdn_id")
  477. Value getField(Message *msg, const std::string &fieldName)
  478. {
  479. const Reflection* reflection = msg->GetReflection();
  480. int index;
  481. // get child property desc
  482. const FieldDescriptor *desc = getFieldDescription(msg, fieldName, index);
  483. MORDOR_ASSERT(desc);
  484. if (desc->type() == FieldDescriptor::TYPE_MESSAGE) {
  485. Message *subMsg = (index >=0) ?
  486. reflection->MutableRepeatedMessage(msg, desc, index) :
  487. reflection->MutableMessage(msg, desc);
  488. size_t pos = fieldName.find('.');
  489. MORDOR_ASSERT(pos != std::string::npos);
  490. std::string remaining = fieldName.substr(pos + 1);
  491. return getField(subMsg, remaining);
  492. }
  493. return getFieldValue(*msg, desc, index);
  494. }
  495. }