PageRenderTime 52ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/a4io/src/input_stream_a4_impl.h

https://github.com/a4/a4
C Header | 403 lines | 322 code | 54 blank | 27 comment | 79 complexity | 429238bd0275b6733e33b2f13b500a65 MD5 | raw file
  1. #ifndef _A4_INPUT_STREAM_A4_IMPL_
  2. #define _A4_INPUT_STREAM_A4_IMPL_
  3. #include <a4/types.h>
  4. #include <string>
  5. #include <tuple>
  6. #include <unordered_set>
  7. #include <vector>
  8. #include "base_compressed_streams.h"
  9. #include "proto_class_pool.h"
  10. #include "zero_copy_resource.h"
  11. #include <a4/message.h>
  12. #include <a4/io/A4Stream.pb.h>
  13. #include <boost/thread.hpp>
  14. #include <boost/thread/locks.hpp>
  15. typedef boost::unique_lock<boost::mutex> Lock;
  16. const uint32_t HIGH_BIT = 1 << 31;
  17. const std::string START_MAGIC = "A4STREAM";
  18. const std::string END_MAGIC = "KTHXBYE4";
  19. const int START_MAGIC_len = 8;
  20. const int END_MAGIC_len = 8;
  21. #include <a4/input_stream_impl.h>
  22. namespace a4 {
  23. namespace io {
  24. class InputStreamA4Impl : public a4::io::InputStreamImpl
  25. {
  26. public:
  27. InputStreamA4Impl(UNIQUE<ZeroCopyStreamResource>, std::string name);
  28. virtual ~InputStreamA4Impl();
  29. /// Returns the next regular message in the stream.
  30. shared<A4Message> next(bool skip_metadata=true);
  31. /// Returns the next bare message in the stream.
  32. shared<A4Message> next_bare_message();
  33. /// Returns the next regular or metadata message in the stream.
  34. shared<A4Message> next_with_metadata() { return next(false); }
  35. /// Return the current metadata message.
  36. shared<const A4Message> current_metadata() {return _current_metadata; }
  37. /// Seek to the given header/metadata combination.
  38. /// If carry==false, specifying a metadata index not in that header section
  39. /// causes an exception, otherwise the next header is used, or false is returned on EOF.
  40. bool seek_to(uint32_t header, int32_t metadata, bool carry=false);
  41. /// Skip to the start of the next metadata block. Return false if EOF, true if not.
  42. bool skip_to_next_metadata() {
  43. return seek_to(_current_header_index, _current_metadata_index+1, true);
  44. }
  45. /// True if new metadata has appeared since the last call to this function.
  46. bool new_metadata() {
  47. if (!_started) startup();
  48. if (_new_metadata) {
  49. _new_metadata = false;
  50. return true;
  51. }
  52. return false;
  53. }
  54. /// True if the stream has not ended or encountered an error.
  55. bool good() { return _good; }
  56. /// True if the stream has encountered an error.
  57. bool error() { return _error; }
  58. /// True if the stream has finished without error.
  59. bool end() { return !_error && !_good; }
  60. /// explicitely close the stream
  61. void close() {
  62. _coded_in.reset();
  63. _compressed_in.reset();
  64. _raw_in.reset();
  65. _good = false;
  66. };
  67. size_t ByteCount() { return _raw_in->ByteCount(); }
  68. std::string str() { return _inputname; };
  69. const std::vector<std::vector<shared<a4::io::A4Message>>>& all_metadata() {
  70. if (_metadata_per_header.size() == 0) {
  71. if (_started)
  72. FATAL("Coding Bug: all_metadata first called after reading started!");
  73. startup(true);
  74. }
  75. return _metadata_per_header;
  76. }
  77. const std::vector<StreamFooter>& footers() {
  78. if (_footers.size() == 0) {
  79. if (_started)
  80. FATAL("Coding Bug: footers() first called after reading started!");
  81. startup(true);
  82. }
  83. return _footers;
  84. }
  85. std::vector<const google::protobuf::FileDescriptor*> get_filedescriptors() {
  86. if (not _current_class_pool) {
  87. if (_started)
  88. FATAL("Coding Bug: footers() first called after reading started!");
  89. startup(true);
  90. }
  91. return _current_class_pool->get_filedescriptors();
  92. }
  93. void set_hint_copy(bool hint_copy);
  94. bool try_read(Message & msg, const google::protobuf::Descriptor* d);
  95. private:
  96. UNIQUE<ZeroCopyStreamResource> _raw_in;
  97. UNIQUE<BaseCompressedInputStream> _compressed_in;
  98. shared<google::protobuf::io::CodedInputStream> _coded_in;
  99. shared<ProtoClassPool> _current_class_pool;
  100. // variables set at construction time
  101. std::string _inputname;
  102. // status variables
  103. bool _good, _error, _started, _discovery_complete, _do_reset_metadata;
  104. uint64_t _items_read;
  105. unsigned int _current_header_index;
  106. int32_t _current_metadata_index;
  107. // metadata-related status
  108. bool _new_metadata, _current_metadata_refers_forward;
  109. shared<A4Message> _current_metadata;
  110. shared<A4Message> _pickup;
  111. std::vector<std::vector<uint64_t>> _metadata_offset_per_header;
  112. std::vector<std::vector<shared<A4Message>>> _metadata_per_header;
  113. std::vector<bool> _headers_forward;
  114. std::vector<StreamFooter> _footers;
  115. // internal functions
  116. void startup(bool discovery_requested=false);
  117. void reset_coded_stream();
  118. bool discover_all_metadata();
  119. bool start_compression(const a4::io::StartCompressedSection& cs);
  120. bool stop_compression(const a4::io::EndCompressedSection& cs);
  121. void drop_compression();
  122. bool read_header(bool discovery_requested=false);
  123. int64_t seek(int64_t position);
  124. int64_t seek_back(int64_t position);
  125. shared<A4Message> bare_message();
  126. shared<A4Message> next_message();
  127. bool handle_compressed_section(shared<A4Message> msg);
  128. bool handle_stream_command(shared<A4Message> msg);
  129. bool handle_metadata(shared<A4Message> msg);
  130. bool carry_metadata(uint32_t& header, int32_t& metadata);
  131. void notify_last_unread_message();
  132. shared<A4Message> _last_unread_message;
  133. // set error/end status and return A4Message
  134. bool set_error();
  135. bool set_end();
  136. bool _hint_copy;
  137. };
  138. inline
  139. bool InputStreamA4Impl::try_read(Message & msg, const google::protobuf::Descriptor* d) {
  140. if (!_started)
  141. startup();
  142. if (!_good)
  143. return false;
  144. if (_hint_copy) notify_last_unread_message();
  145. if (_items_read++ % 10000 == 0) {
  146. reset_coded_stream();
  147. }
  148. uint32_t size = 0;
  149. if (!_coded_in->ReadLittleEndian32(&size)) {
  150. if (_compressed_in && _compressed_in->ByteCount() == 0) {
  151. FATAL("Reading from compressed section failed!");
  152. } else {
  153. FATAL("Unexpected end of file or corruption [0]!");
  154. }
  155. }
  156. uint32_t class_id = 0;
  157. if (size & HIGH_BIT) {
  158. size = size & (HIGH_BIT - 1);
  159. if (!_coded_in->ReadLittleEndian32(&class_id))
  160. FATAL("Unexpected end of file [1]!");
  161. }
  162. if (_current_class_pool->check_match(class_id, d)) {
  163. auto lim = _coded_in->PushLimit(size);
  164. if (not msg.ParseFromCodedStream(_coded_in.get())) {
  165. FATAL("Failed to read expected event!");
  166. }
  167. _coded_in->PopLimit(lim);
  168. return true;
  169. } else {
  170. auto _message = _current_class_pool->parse_message(class_id, _coded_in, size);
  171. _pickup.reset(new A4Message(class_id, _message, _current_class_pool));
  172. return false;
  173. }
  174. }
  175. inline
  176. shared<A4Message> InputStreamA4Impl::bare_message() {
  177. if (_pickup) {
  178. auto res = _pickup;
  179. _pickup.reset();
  180. return res;
  181. }
  182. if (!_started)
  183. startup();
  184. if (!_good)
  185. return shared<A4Message>();
  186. if (_hint_copy) notify_last_unread_message();
  187. if (_items_read++ % 10000 == 0) {
  188. reset_coded_stream();
  189. }
  190. uint32_t size = 0;
  191. if (!_coded_in->ReadLittleEndian32(&size)) {
  192. if (_compressed_in && _compressed_in->ByteCount() == 0) {
  193. FATAL("Reading from compressed section failed! inside compression: ",
  194. bool(_compressed_in), " compressed bytecount: ",
  195. _compressed_in ? _compressed_in->ByteCount() : 0,
  196. " raw_in bytecount: ", _raw_in->ByteCount());
  197. } else {
  198. FATAL("Unexpected end of file or corruption [0]!");
  199. }
  200. }
  201. uint32_t class_id = 0;
  202. if (size & HIGH_BIT) {
  203. size = size & (HIGH_BIT - 1);
  204. if (!_coded_in->ReadLittleEndian32(&class_id))
  205. FATAL("Unexpected end of file [1]!");
  206. }
  207. //VERBOSE("Next part: ", _raw_in->ByteCount(), " -- ", size, " - ", class_id);
  208. if (_hint_copy) {
  209. _last_unread_message.reset(new A4Message(class_id, size, _coded_in, _current_class_pool));
  210. return _last_unread_message;
  211. } else {
  212. auto _message = _current_class_pool->parse_message(class_id, _coded_in, size);
  213. return make_shared<A4Message>(class_id, _message, _current_class_pool);
  214. }
  215. }
  216. /// Deals with a4.io.StartCompressedSection and a4.io.EndCompressedSection messages
  217. inline
  218. bool InputStreamA4Impl::handle_compressed_section(shared<A4Message> msg) {
  219. if (msg->is<StartCompressedSection>()) {
  220. if (!start_compression(*msg->as<StartCompressedSection>()))
  221. FATAL("Unable to start compressed section!");
  222. return true;
  223. } else if (msg->is<EndCompressedSection>()) {
  224. if (!stop_compression(*msg->as<EndCompressedSection>()))
  225. FATAL("Unable to stop compressed section!");
  226. return true;
  227. }
  228. return false;
  229. }
  230. /// Deals with all internal messages, with class id between 100 and 200.
  231. inline
  232. bool InputStreamA4Impl::handle_stream_command(shared<A4Message> msg) {
  233. if (msg->class_id() < 100 || msg->class_id() > 200)
  234. return false;
  235. if (msg->is<StreamFooter>()) {
  236. if (_hint_copy) notify_last_unread_message();
  237. uint32_t size;
  238. if (!_coded_in->ReadLittleEndian32(&size))
  239. FATAL("Unexpected end of file [3]!");
  240. std::string magic;
  241. if (!_coded_in->ReadString(&magic, 8))
  242. FATAL("Unexpected end of file [4]!");
  243. if (0 != magic.compare(END_MAGIC))
  244. FATAL("Corrupt footer! Read: ", magic);
  245. if (_coded_in->ExpectAtEnd()) {
  246. // Regular end of stream
  247. _good = false;
  248. return true;
  249. }
  250. _current_header_index++;
  251. if (!read_header()) {
  252. if (_error)
  253. FATAL("Corrupt header!");
  254. _good = false; // Slightly strange but regular end of stream
  255. return true;
  256. }
  257. _current_metadata = shared<A4Message>();
  258. if (!_current_metadata_refers_forward) {
  259. _current_metadata_index = 0;
  260. if (_metadata_per_header[_current_header_index].size() > 0)
  261. _current_metadata = _metadata_per_header[_current_header_index][0];
  262. } else {
  263. _current_metadata_index = -1;
  264. }
  265. _do_reset_metadata = false; // if we had an increment before, ignore it
  266. _new_metadata = true; // a footer invalidates metadata
  267. return true;
  268. } else if (msg->is<StreamHeader>()) {
  269. FATAL("Unexpected header!");
  270. } else if (msg->is<ProtoClass>()) {
  271. _current_class_pool->add_protoclass(*msg->as<ProtoClass>());
  272. return true;
  273. }
  274. {
  275. static Lock l;
  276. static std::unordered_set<uint32_t> warned_ids;
  277. const auto id = msg->class_id();
  278. if (warned_ids.find(id) != warned_ids.end()) {
  279. warned_ids.insert(id);
  280. WARNING("Encountered unexpected internal class id: ", id,
  281. ". The input may be from a newer version of A4? "
  282. "Continuing anyway..");
  283. }
  284. }
  285. return false;
  286. }
  287. inline
  288. bool InputStreamA4Impl::handle_metadata(shared<A4Message> msg) {
  289. //if (_current_header_index > 0) {
  290. // for (int i = 0; i < _current_header_index; i++) {
  291. // _metadata_per_header[i].clear();
  292. // }
  293. //}
  294. if (msg->metadata()) {
  295. _current_metadata_index++;
  296. if (_current_metadata_refers_forward) {
  297. _current_metadata = msg;
  298. _new_metadata = true;
  299. } else {
  300. _do_reset_metadata = true;
  301. }
  302. return true;
  303. }
  304. return false;
  305. }
  306. inline
  307. shared<A4Message> InputStreamA4Impl::next_message() {
  308. if (_do_reset_metadata) {
  309. _do_reset_metadata = false;
  310. _current_metadata = shared<A4Message>();
  311. if (_metadata_per_header.size() > _current_header_index) {
  312. auto& header_metadata = _metadata_per_header[_current_header_index];
  313. if (static_cast<int32_t>(header_metadata.size()) > _current_metadata_index)
  314. _current_metadata = header_metadata[_current_metadata_index];
  315. }
  316. _new_metadata = true;
  317. }
  318. shared<A4Message> msg = bare_message();
  319. if (msg and handle_compressed_section(msg))
  320. return next_message();
  321. return msg;
  322. }
  323. inline
  324. shared<A4Message> InputStreamA4Impl::next(bool skip_metadata) {
  325. shared<A4Message> msg = next_message();
  326. if (msg and handle_stream_command(msg))
  327. return next(skip_metadata);
  328. if (msg and handle_metadata(msg) && skip_metadata)
  329. return next(skip_metadata);
  330. return msg;
  331. }
  332. inline
  333. shared<A4Message> InputStreamA4Impl::next_bare_message() {
  334. shared<A4Message> msg = next_message();
  335. if (msg and handle_stream_command(msg))
  336. return msg;
  337. if (msg and handle_metadata(msg))
  338. return msg;
  339. return msg;
  340. }
  341. inline
  342. void InputStreamA4Impl::notify_last_unread_message() {
  343. if (_last_unread_message) {
  344. _last_unread_message->invalidate_stream();
  345. _last_unread_message.reset();
  346. }
  347. }
  348. };};
  349. #endif