PageRenderTime 247ms CodeModel.GetById 80ms app.highlight 100ms RepoModel.GetById 61ms app.codeStats 0ms

/mordor/protobuf.cpp

http://github.com/mozy/mordor
C++ | 579 lines | 473 code | 84 blank | 22 comment | 100 complexity | 473563f1a4f94af425956ffbe7a59a5f MD5 | raw file
  1// Copyright (c) 2010 - Mozy, Inc.
  2
  3#include "protobuf.h"
  4
  5#include "mordor/assert.h"
  6#include "mordor/streams/buffer.h"
  7#include "mordor/string.h"
  8
  9#ifdef MSVC
 10// Disable some warnings, but only while
 11// processing the google generated code
 12#pragma warning(push)
 13#pragma warning(disable : 4244)
 14#endif
 15
 16
 17#include <boost/algorithm/string/regex.hpp>
 18#include <boost/foreach.hpp>
 19#include <boost/lexical_cast.hpp>
 20
 21#undef TYPE_BOOL // avoid collision from Mac OS X's ConditionalMacros.h
 22#include <google/protobuf/descriptor.h>
 23#include <google/protobuf/io/zero_copy_stream.h>
 24#include <google/protobuf/message.h>
 25
 26#ifdef MSVC
 27#pragma warning(pop)
 28#endif
 29
 30using namespace google::protobuf;
 31using namespace Mordor::JSON;
 32
 33#ifdef MSVC
 34#ifdef _DEBUG
 35#pragma comment(lib, "libprotobufd.lib")
 36#else
 37#pragma comment(lib, "libprotobuf.lib")
 38#endif
 39#endif
 40
 41namespace Mordor {
 42
 43class BufferZeroCopyInputStream : public io::ZeroCopyInputStream
 44{
 45public:
 46    BufferZeroCopyInputStream(const Buffer &buffer)
 47        : m_iovs(buffer.readBuffers()),
 48          m_currentIov(0),
 49          m_currentIovOffset(0),
 50          m_complete(0)
 51    {}
 52
 53    bool Next(const void **data, int *size)
 54    {
 55        if (m_currentIov >= m_iovs.size())
 56            return false;
 57        MORDOR_ASSERT(m_currentIovOffset <= m_iovs[m_currentIov].iov_len);
 58        *data = (char *)m_iovs[m_currentIov].iov_base + m_currentIovOffset;
 59        *size = (int)(m_iovs[m_currentIov].iov_len - m_currentIovOffset);
 60
 61        m_complete += *size;
 62        m_currentIovOffset = 0;
 63        ++m_currentIov;
 64
 65        return true;
 66    }
 67
 68    void BackUp(int count)
 69    {
 70        MORDOR_ASSERT(count >= 0);
 71        MORDOR_ASSERT(count <= m_complete);
 72        m_complete -= count;
 73        while (count) {
 74            if (m_currentIovOffset == 0) {
 75                MORDOR_ASSERT(m_currentIov > 0);
 76                m_currentIovOffset = m_iovs[--m_currentIov].iov_len;
 77            }
 78            size_t todo = (std::min)(m_currentIovOffset, (size_t)count);
 79            m_currentIovOffset -= todo;
 80            count -= (int)todo;
 81        }
 82    }
 83
 84    bool Skip(int count) {
 85        MORDOR_ASSERT(count >= 0);
 86        while (count) {
 87            if (m_currentIov >= m_iovs.size())
 88                return false;
 89            size_t todo = (std::min)((size_t)m_iovs[m_currentIov].iov_len -
 90                m_currentIovOffset, (size_t)count);
 91            m_currentIovOffset += todo;
 92            count -= (int)todo;
 93            m_complete += todo;
 94            if (m_currentIovOffset == m_iovs[m_currentIov].iov_len) {
 95                m_currentIovOffset = 0;
 96                ++m_currentIov;
 97            }
 98        }
 99        return true;
100    }
101
102    int64 ByteCount() const { return m_complete; }
103
104private:
105    std::vector<iovec> m_iovs;
106    size_t m_currentIov;
107    size_t m_currentIovOffset;
108    int64 m_complete;
109};
110
111class BufferZeroCopyOutputStream : public io::ZeroCopyOutputStream
112{
113public:
114    BufferZeroCopyOutputStream(Buffer &buffer, size_t bufferSize = 1024)
115        : m_buffer(buffer),
116          m_bufferSize(bufferSize),
117          m_pendingProduce(0),
118          m_total(0)
119    {}
120    ~BufferZeroCopyOutputStream()
121    {
122        m_buffer.produce(m_pendingProduce);
123    }
124
125    bool Next(void **data, int *size)
126    {
127        m_buffer.produce(m_pendingProduce);
128        m_pendingProduce = 0;
129
130        // TODO: protect against std::bad_alloc?
131        iovec iov = m_buffer.writeBuffer(m_bufferSize, false);
132        *data = iov.iov_base;
133        m_total += m_pendingProduce = iov.iov_len;
134        *size = (int)m_pendingProduce;
135
136        return true;
137    }
138
139    void BackUp(int count)
140    {
141        MORDOR_ASSERT(count <= (int)m_pendingProduce);
142        m_pendingProduce -= count;
143        m_total -= count;
144    }
145
146    int64 ByteCount() const
147    {
148        return m_total;
149    }
150
151private:
152    Buffer &m_buffer;
153    size_t m_bufferSize, m_pendingProduce;
154    int64 m_total;
155};
156
157void serializeToBuffer(const Message &proto, Buffer &buffer)
158{
159    BufferZeroCopyOutputStream stream(buffer);
160    if (!proto.SerializeToZeroCopyStream(&stream))
161        MORDOR_THROW_EXCEPTION(std::invalid_argument("proto"));
162}
163
164void parseFromBuffer(Message &proto, const Buffer &buffer)
165{
166    BufferZeroCopyInputStream stream(buffer);
167    if (!proto.ParseFromZeroCopyStream(&stream))
168        MORDOR_THROW_EXCEPTION(std::invalid_argument("buffer"));
169}
170
171
172// begin anonymous namespace for reflection stuff
173namespace {
174
175void setFieldValue(Message *message, const FieldDescriptor *descriptor, const Value &fieldValue, int index=-1)
176{
177
178    const Reflection* reflection = message->GetReflection();
179
180#define SET_FIELD(setter, type, proto_type, converter)\
181    /*BOOST_STATIC_ASSERT(\
182        (#setter == "Int32" && #type == "long long") ||\
183        (#setter == "UInt32" && #type == "long long") ||\
184        (#setter == "Int64" && #type == "long long") ||\
185        (#setter == "UInt64" && #type == "long long") ||\
186        (#setter == "Float" && #type == "double") ||\
187        (#setter == "Double" && #type == "double") ||\
188        (#setter == "Bool" && #type == "bool") ||\
189        (#setter == "String" && #type == "std::string"));*/\
190    if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {\
191        type value = converter(boost::get<type>(fieldValue));\
192        reflection->Set##setter(message, descriptor, (proto_type)value);\
193    } else if (index >= 0) {\
194        type value = converter(boost::get<type>(fieldValue));\
195        reflection->SetRepeated##setter(message, descriptor, index, (proto_type)value);\
196    } else {\
197        const Array &array = boost::get<Array>(fieldValue);\
198        BOOST_FOREACH(Value v, array) {\
199            type value = converter(boost::get<type>(v));\
200            reflection->Add##setter(message, descriptor, (proto_type)value);\
201        }\
202    }
203
204
205    try {
206        switch (descriptor->cpp_type()) {
207            case FieldDescriptor::CPPTYPE_INT32:
208                SET_FIELD(Int32, long long, int32_t, )
209                break;
210            case FieldDescriptor::CPPTYPE_INT64:
211                SET_FIELD(Int64, long long, int64_t, )
212                break;
213            case FieldDescriptor::CPPTYPE_UINT32:
214                SET_FIELD(UInt32, long long, uint32_t, )
215                break;
216            case FieldDescriptor::CPPTYPE_UINT64:
217                SET_FIELD(UInt64, long long, uint64_t, )
218                break;
219            case FieldDescriptor::CPPTYPE_FLOAT:
220                SET_FIELD(Float, double, float, )
221                break;
222            case FieldDescriptor::CPPTYPE_DOUBLE:
223                SET_FIELD(Double, double, double, )
224                break;
225            case FieldDescriptor::CPPTYPE_BOOL:
226                SET_FIELD(Bool, bool, bool, )
227                break;
228            case FieldDescriptor::CPPTYPE_STRING:
229                // convert hexstring bytes to data
230                if (descriptor->type() == FieldDescriptor::TYPE_BYTES) {
231                    SET_FIELD(String, std::string, std::string, dataFromHexstring)
232                } else {
233                    SET_FIELD(String, std::string, std::string, )
234                }
235                break;
236            case FieldDescriptor::CPPTYPE_ENUM:
237                if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {
238                    std::string value = boost::get<std::string>(fieldValue);
239                    const EnumValueDescriptor * val = descriptor->enum_type()->FindValueByName(value);
240                    if (!val) {
241                        throw std::runtime_error("invalid enum " + value);
242                    }
243                    reflection->SetEnum(message, descriptor, val);
244                } else {
245                    const Array &array = boost::get<Array>(fieldValue);
246                    BOOST_FOREACH(Value v, array) {
247                        std::string value = boost::get<std::string>(v);
248                        const EnumValueDescriptor * val = descriptor->enum_type()->FindValueByName(value);
249                        if (!val) {
250                            throw std::runtime_error("invalid enum " + value);
251                        }
252                        reflection->AddEnum(message, descriptor, val);
253                    }
254                }
255                break;
256            default:
257                MORDOR_NOTREACHED();
258        }
259    } catch (boost::bad_get &) {
260        throw std::runtime_error(message->GetDescriptor()->name() + "." +
261            descriptor->name() + " is invalid");
262    }
263#undef SET_FIELD
264}
265
266Value getFieldValue(const Message &message, const FieldDescriptor *descriptor, int index=-1)
267{
268    const Reflection *reflection = message.GetReflection();
269
270#define GET_FIELD(getter_type, type, converter)\
271    if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {\
272        if (reflection->HasField(message, descriptor)) {\
273            type value = reflection->Get##getter_type(message, descriptor);\
274            return converter(value);\
275        } else {\
276            return boost::blank();\
277        }\
278    } else if (index >= 0) {\
279        type value = reflection->GetRepeated##getter_type(message, descriptor, index);\
280        return converter(value);\
281    } else {\
282        Array array;\
283        int field_size = reflection->FieldSize(message, descriptor);\
284        for (int i = 0; i < field_size; i++) {\
285            type value = reflection->GetRepeated##getter_type(message, descriptor, i);\
286            array.push_back(converter(value));\
287        }\
288        return array;\
289    }
290
291    try {
292        switch (descriptor->cpp_type()) {
293            case FieldDescriptor::CPPTYPE_INT32:
294                GET_FIELD(Int32, long long, )
295            case FieldDescriptor::CPPTYPE_INT64:
296                GET_FIELD(Int64, long long, )
297            case FieldDescriptor::CPPTYPE_UINT32:
298                GET_FIELD(UInt32, long long, )
299            case FieldDescriptor::CPPTYPE_UINT64:
300                GET_FIELD(UInt64, long long, )
301            case FieldDescriptor::CPPTYPE_FLOAT:
302                GET_FIELD(Float, double, )
303            case FieldDescriptor::CPPTYPE_DOUBLE:
304                GET_FIELD(Double, double, )
305            case FieldDescriptor::CPPTYPE_BOOL:
306                GET_FIELD(Bool, bool, )
307            case FieldDescriptor::CPPTYPE_STRING:
308                if (descriptor->type() == FieldDescriptor::TYPE_BYTES) {
309                    // convert to hexstring
310                    GET_FIELD(String, std::string, hexstringFromData);
311                } else {
312                    GET_FIELD(String, std::string, )
313                }
314            case FieldDescriptor::CPPTYPE_ENUM:
315                if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {\
316                    if (reflection->HasField(message, descriptor)) {
317                        const EnumValueDescriptor *value = reflection->GetEnum(message, descriptor);
318                        MORDOR_ASSERT(value);
319                        return value->name();
320                    }else {
321                        return boost::blank();
322                    }
323                } else {
324                    Array array;
325                    int field_size = reflection->FieldSize(message, descriptor);
326                    for (int i = 0; i < field_size; i++) {
327                        const EnumValueDescriptor *value = reflection->GetRepeatedEnum(message, descriptor, i);
328                        array.push_back(value->name());
329                    }
330                    return array;
331                }
332            default:
333                MORDOR_NOTREACHED();
334        }
335    } catch (boost::bad_get &) {
336        throw std::runtime_error(message.GetDescriptor()->name() + "." + descriptor->name() + " is invalid");
337    }
338#undef GET_FIELD
339}
340
341
342const FieldDescriptor *getFieldDescription(Message *msg, const std::string &fieldName, int &index)
343{
344    static boost::regex index_regex("^([^ ]*)\\[([0-9]+)\\]$");
345    const Reflection* reflection = msg->GetReflection();
346
347
348    index = -1;
349    size_t pos = fieldName.find('.');
350
351    std::string field = fieldName.substr(0, pos);
352
353    // check indexed field
354    if (boost::regex_match(field, index_regex)) {
355        std::string idx = field;
356        boost::replace_all_regex(idx, index_regex, std::string("$2"));
357        index = boost::lexical_cast<int>(idx);
358        field = field.substr(0, field.find('['));
359    }
360
361    const FieldDescriptor *desc = msg->GetDescriptor()->FindFieldByName(field);
362    if (desc == NULL) {
363        throw std::runtime_error(field + " field cannot be found");
364    }
365    if (desc->label() != FieldDescriptor::LABEL_REPEATED && index != -1) {
366        throw std::runtime_error(field + " is not repeated");
367    }
368    if (desc->label() == FieldDescriptor::LABEL_REPEATED && index == -1) {
369        throw std::runtime_error(field + " should be repeated");
370    }
371
372    if (index >= 0 && index >= reflection->FieldSize(*msg, desc)) {
373        throw std::runtime_error(field + " out of bound");
374    }
375
376    if (pos != std::string::npos && desc->type() != FieldDescriptor::TYPE_MESSAGE) {
377        throw std::runtime_error(field + " is not be a sub-message");
378    }
379    if (pos == std::string::npos && desc->type() == FieldDescriptor::TYPE_MESSAGE) {
380        throw std::runtime_error(field + " should be a sub-message");
381    }
382
383    return desc;
384}
385
386
387} // end of anonymous ns
388
389
390void serializeToJsonObject(const Message &message, Object &object, bool validate, bool includeNull)
391{
392
393    const Reflection* reflection = message.GetReflection();
394    for (int i = 0; i < message.GetDescriptor()->field_count(); i++) {
395        const FieldDescriptor* descriptor = message.GetDescriptor()->field(i);
396        std::string fieldName = descriptor->name();
397
398        if (descriptor->type() == FieldDescriptor::TYPE_MESSAGE) {
399            if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {
400                Object msgObj;
401                if (reflection->HasField(message, descriptor)) {
402                    const Message &subMessage = reflection->GetMessage(message, descriptor,
403                            MessageFactory::generated_factory());
404                    serializeToJsonObject(subMessage, msgObj, validate, includeNull);
405                }
406                else if (validate && descriptor->is_required()) {
407                    throw std::runtime_error(message.GetDescriptor()->name() + "." + descriptor->name() + " is required");
408                }
409                object[fieldName] = msgObj;
410            } else {
411                Array array;
412                int field_size = reflection->FieldSize(message, descriptor);
413                for (int i = 0; i < field_size; i++) {
414                    Object msgObj;
415                    const Message &subMessage = reflection->GetRepeatedMessage(message, descriptor, i);
416                    serializeToJsonObject(subMessage, msgObj, validate, includeNull);
417                    array.push_back(msgObj);
418                }
419                object[fieldName] = array;
420            }
421            continue;
422        }
423
424        // other value types
425        Value value = getFieldValue(message, descriptor);
426        if (includeNull || !value.isBlank()) {
427            object[fieldName] = value;
428        }
429    }
430}
431
432void parseFromJsonObject(Message *message, const Object &object, bool validate)
433{
434    const Reflection* reflection = message->GetReflection();
435    MORDOR_ASSERT(reflection);
436    for (int i = 0; i < message->GetDescriptor()->field_count(); i++) {
437        const FieldDescriptor* descriptor = message->GetDescriptor()->field(i);
438        std::string field_name = descriptor->name();
439
440        MORDOR_ASSERT(descriptor);
441
442        if (descriptor->type() == FieldDescriptor::TYPE_MESSAGE) {
443            Object::const_iterator it = object.find(descriptor->name());
444            if (it != object.end()) {
445                if (descriptor->label() != FieldDescriptor::LABEL_REPEATED) {
446                    Message *subMessage = reflection->MutableMessage(message, descriptor,
447                        MessageFactory::generated_factory());
448                    const Object &value = boost::get<Object>(it->second);
449                    parseFromJsonObject(subMessage, value, validate);
450                } else {
451                    const Array &array = boost::get<Array>(it->second);
452                    BOOST_FOREACH(Value v, array) {
453                        const Object &childObject = boost::get<Object>(v);
454                        Message *subMessage = reflection->AddMessage(message, descriptor,
455                            MessageFactory::generated_factory());
456                        parseFromJsonObject(subMessage, childObject, validate);
457                    }
458                }
459            }
460        } else {
461            Object::const_iterator it = object.find(field_name);
462            if (it == object.end() || it->second.isBlank()) {
463                if (validate && descriptor->is_required()) {
464                    throw std::runtime_error(message->GetDescriptor()->name() + "." + field_name + " is required");
465                }
466                continue; // skip optional field
467            }
468            setFieldValue(message, descriptor, it->second);
469        }
470    }
471}
472
473Message* forName(const std::string& typeName)
474{
475    Message* message = NULL;
476    const Descriptor* descriptor = DescriptorPool::generated_pool()->FindMessageTypeByName(typeName);
477    if (descriptor)
478    {
479        const Message* prototype = MessageFactory::generated_factory()->GetPrototype(descriptor);
480        if (prototype)
481        {
482            message = prototype->New();
483        }
484    }
485    return message;
486}
487
488std::string toJson(const Message &message, bool validate, bool includeNull)
489{
490    Object root, msgObj;
491
492    serializeToJsonObject(message, msgObj, validate, includeNull);
493    root[message.GetDescriptor()->full_name()] = msgObj;
494
495    std::ostringstream os;
496    os << root;
497    return os.str();
498}
499
500
501Message * fromJson(const Buffer &buffer, bool validate)
502{
503    Value root;
504    Parser parser(root);
505    parser.run(buffer);
506    if (!parser.final() || parser.error()) {
507        return NULL;
508    }
509
510    Object &rootObject = boost::get<Object>(root);
511    std::string messageType = rootObject.begin()->first;
512
513    Message *message = forName(messageType);
514    if (!message) {
515        return NULL;
516    }
517
518    Object &messageObject = boost::get<Object>(rootObject.begin()->second);
519    parseFromJsonObject(message, messageObject, validate);
520
521    return message;
522}
523
524
525
526// setField(partialObject, "store_dat.len", 100);
527// setField(namedObject, "fs_type", "DIRECTORY");
528void setField(Message *msg, const std::string &fieldName, const Value &value)
529{
530    const Reflection* reflection = msg->GetReflection();
531
532    int index;
533    // get child property desc
534    const FieldDescriptor *desc = getFieldDescription(msg, fieldName, index);
535
536    MORDOR_ASSERT(desc);
537    if (desc->type() == FieldDescriptor::TYPE_MESSAGE) {
538        Message *subMsg = (index >=0) ?
539            reflection->MutableRepeatedMessage(msg, desc, index) :
540            reflection->MutableMessage(msg, desc);
541
542        size_t pos = fieldName.find('.');
543        MORDOR_ASSERT(pos != std::string::npos);
544
545        std::string remaining = fieldName.substr(pos + 1);
546
547        return setField(subMsg, remaining, value);
548    }
549    setFieldValue(msg, desc, value, index);
550}
551
552// getField(partialObject, "store_dat.frag[1].tdn_id")
553Value getField(Message *msg, const std::string &fieldName)
554{
555    const Reflection* reflection = msg->GetReflection();
556
557    int index;
558    // get child property desc
559    const FieldDescriptor *desc = getFieldDescription(msg, fieldName, index);
560
561    MORDOR_ASSERT(desc);
562    if (desc->type() == FieldDescriptor::TYPE_MESSAGE) {
563        Message *subMsg = (index >=0) ?
564            reflection->MutableRepeatedMessage(msg, desc, index) :
565            reflection->MutableMessage(msg, desc);
566
567        size_t pos = fieldName.find('.');
568        MORDOR_ASSERT(pos != std::string::npos);
569
570        std::string remaining = fieldName.substr(pos + 1);
571
572        return getField(subMsg, remaining);
573    }
574    return getFieldValue(*msg, desc, index);
575}
576
577
578
579}