PageRenderTime 74ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 0ms

/generator/nanopb_generator.py

https://code.google.com/
Python | 436 lines | 395 code | 30 blank | 11 comment | 20 complexity | a74bf27f232191176abf15d5b71da41f MD5 | raw file
  1. '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
  2. import google.protobuf.descriptor_pb2 as descriptor
  3. import nanopb_pb2
  4. import os.path
# Shorthand for the descriptor class whose TYPE_* and LABEL_* constants are
# used throughout this file.
FieldD = descriptor.FieldDescriptorProto

# Lookup table for the scalar protobuf field types.
# Values are tuple (c type, pb ltype):
#   c type   - name of the C type used for the struct member
#   pb ltype - name of the nanopb PB_LTYPE_* constant used for the field
# Enum, string, bytes and message fields are not listed here; Field.__init__
# handles them separately.
datatypes = {
    FieldD.TYPE_BOOL: ('bool', 'PB_LTYPE_VARINT'),
    FieldD.TYPE_DOUBLE: ('double', 'PB_LTYPE_FIXED64'),
    FieldD.TYPE_FIXED32: ('uint32_t', 'PB_LTYPE_FIXED32'),
    FieldD.TYPE_FIXED64: ('uint64_t', 'PB_LTYPE_FIXED64'),
    FieldD.TYPE_FLOAT: ('float', 'PB_LTYPE_FIXED32'),
    FieldD.TYPE_INT32: ('int32_t', 'PB_LTYPE_VARINT'),
    FieldD.TYPE_INT64: ('int64_t', 'PB_LTYPE_VARINT'),
    FieldD.TYPE_SFIXED32: ('int32_t', 'PB_LTYPE_FIXED32'),
    FieldD.TYPE_SFIXED64: ('int64_t', 'PB_LTYPE_FIXED64'),
    FieldD.TYPE_SINT32: ('int32_t', 'PB_LTYPE_SVARINT'),
    FieldD.TYPE_SINT64: ('int64_t', 'PB_LTYPE_SVARINT'),
    FieldD.TYPE_UINT32: ('uint32_t', 'PB_LTYPE_VARINT'),
    FieldD.TYPE_UINT64: ('uint64_t', 'PB_LTYPE_VARINT')
}
  22. class Names:
  23. '''Keeps a set of nested names and formats them to C identifier.
  24. You can subclass this with your own implementation.
  25. '''
  26. def __init__(self, parts = ()):
  27. if isinstance(parts, Names):
  28. parts = parts.parts
  29. self.parts = tuple(parts)
  30. def __str__(self):
  31. return '_'.join(self.parts)
  32. def __add__(self, other):
  33. if isinstance(other, (str, unicode)):
  34. return Names(self.parts + (other,))
  35. elif isinstance(other, tuple):
  36. return Names(self.parts + other)
  37. else:
  38. raise ValueError("Name parts should be of type str")
  39. def names_from_type_name(type_name):
  40. '''Parse Names() from FieldDescriptorProto type_name'''
  41. if type_name[0] != '.':
  42. raise NotImplementedError("Lookup of non-absolute type names is not supported")
  43. return Names(type_name[1:].split('.'))
  44. class Enum:
  45. def __init__(self, names, desc):
  46. '''desc is EnumDescriptorProto'''
  47. self.names = names + desc.name
  48. self.values = [(self.names + x.name, x.number) for x in desc.value]
  49. def __str__(self):
  50. result = 'typedef enum {\n'
  51. result += ',\n'.join([" %s = %d" % x for x in self.values])
  52. result += '\n} %s;' % self.names
  53. return result
  54. class Field:
  55. def __init__(self, struct_name, desc):
  56. '''desc is FieldDescriptorProto'''
  57. self.tag = desc.number
  58. self.struct_name = struct_name
  59. self.name = desc.name
  60. self.default = None
  61. self.max_size = None
  62. self.max_count = None
  63. self.array_decl = ""
  64. # Parse nanopb-specific field options
  65. if desc.options.HasExtension(nanopb_pb2.nanopb):
  66. ext = desc.options.Extensions[nanopb_pb2.nanopb]
  67. if ext.HasField("max_size"):
  68. self.max_size = ext.max_size
  69. if ext.HasField("max_count"):
  70. self.max_count = ext.max_count
  71. if desc.HasField('default_value'):
  72. self.default = desc.default_value
  73. # Decide HTYPE
  74. # HTYPE is the high-order nibble of nanopb field description,
  75. # defining whether value is required/optional/repeated.
  76. is_callback = False
  77. if desc.label == FieldD.LABEL_REQUIRED:
  78. self.htype = 'PB_HTYPE_REQUIRED'
  79. elif desc.label == FieldD.LABEL_OPTIONAL:
  80. self.htype = 'PB_HTYPE_OPTIONAL'
  81. elif desc.label == FieldD.LABEL_REPEATED:
  82. if self.max_count is None:
  83. is_callback = True
  84. else:
  85. self.htype = 'PB_HTYPE_ARRAY'
  86. self.array_decl = '[%d]' % self.max_count
  87. else:
  88. raise NotImplementedError(desc.label)
  89. # Decide LTYPE and CTYPE
  90. # LTYPE is the low-order nibble of nanopb field description,
  91. # defining how to decode an individual value.
  92. # CTYPE is the name of the c type to use in the struct.
  93. if datatypes.has_key(desc.type):
  94. self.ctype, self.ltype = datatypes[desc.type]
  95. elif desc.type == FieldD.TYPE_ENUM:
  96. self.ltype = 'PB_LTYPE_VARINT'
  97. self.ctype = names_from_type_name(desc.type_name)
  98. if self.default is not None:
  99. self.default = self.ctype + self.default
  100. elif desc.type == FieldD.TYPE_STRING:
  101. self.ltype = 'PB_LTYPE_STRING'
  102. if self.max_size is None:
  103. is_callback = True
  104. else:
  105. self.ctype = 'char'
  106. self.array_decl += '[%d]' % self.max_size
  107. elif desc.type == FieldD.TYPE_BYTES:
  108. self.ltype = 'PB_LTYPE_BYTES'
  109. if self.max_size is None:
  110. is_callback = True
  111. else:
  112. self.ctype = self.struct_name + self.name + 't'
  113. elif desc.type == FieldD.TYPE_MESSAGE:
  114. self.ltype = 'PB_LTYPE_SUBMESSAGE'
  115. self.ctype = self.submsgname = names_from_type_name(desc.type_name)
  116. else:
  117. raise NotImplementedError(desc.type)
  118. if is_callback:
  119. self.htype = 'PB_HTYPE_CALLBACK'
  120. self.ctype = 'pb_callback_t'
  121. self.array_decl = ''
  122. def __cmp__(self, other):
  123. return cmp(self.tag, other.tag)
  124. def __str__(self):
  125. if self.htype == 'PB_HTYPE_OPTIONAL':
  126. result = ' bool has_' + self.name + ';\n'
  127. elif self.htype == 'PB_HTYPE_ARRAY':
  128. result = ' size_t ' + self.name + '_count;\n'
  129. else:
  130. result = ''
  131. result += ' %s %s%s;' % (self.ctype, self.name, self.array_decl)
  132. return result
  133. def types(self):
  134. '''Return definitions for any special types this field might need.'''
  135. if self.ltype == 'PB_LTYPE_BYTES' and self.max_size is not None:
  136. result = 'typedef struct {\n'
  137. result += ' size_t size;\n'
  138. result += ' uint8_t bytes[%d];\n' % self.max_size
  139. result += '} %s;\n' % self.ctype
  140. else:
  141. result = None
  142. return result
  143. def default_decl(self, declaration_only = False):
  144. '''Return definition for this field's default value.'''
  145. if self.default is None:
  146. return None
  147. if self.ltype == 'PB_LTYPE_STRING':
  148. ctype = 'char'
  149. if self.max_size is None:
  150. return None # Not implemented
  151. else:
  152. array_decl = '[%d]' % (self.max_size + 1)
  153. default = str(self.default).encode('string_escape')
  154. default = default.replace('"', '\\"')
  155. default = '"' + default + '"'
  156. elif self.ltype == 'PB_LTYPE_BYTES':
  157. data = self.default.decode('string_escape')
  158. data = ['0x%02x' % ord(c) for c in data]
  159. if self.max_size is None:
  160. return None # Not implemented
  161. else:
  162. ctype = self.ctype
  163. default = '{%d, {%s}}' % (len(data), ','.join(data))
  164. array_decl = ''
  165. else:
  166. ctype, default = self.ctype, self.default
  167. array_decl = ''
  168. if declaration_only:
  169. return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
  170. else:
  171. return 'const %s %s_default%s = %s;' % (ctype, self.struct_name + self.name, array_decl, default)
  172. def pb_field_t(self, prev_field_name):
  173. '''Return the pb_field_t initializer to use in the constant array.
  174. prev_field_name is the name of the previous field or None.
  175. '''
  176. result = ' {%d, ' % self.tag
  177. result += self.htype
  178. if self.ltype is not None:
  179. result += ' | ' + self.ltype
  180. result += ',\n'
  181. if prev_field_name is None:
  182. result += ' offsetof(%s, %s),' % (self.struct_name, self.name)
  183. else:
  184. result += ' pb_delta_end(%s, %s, %s),' % (self.struct_name, self.name, prev_field_name)
  185. if self.htype == 'PB_HTYPE_OPTIONAL':
  186. result += '\n pb_delta(%s, has_%s, %s),' % (self.struct_name, self.name, self.name)
  187. elif self.htype == 'PB_HTYPE_ARRAY':
  188. result += '\n pb_delta(%s, %s_count, %s),' % (self.struct_name, self.name, self.name)
  189. else:
  190. result += ' 0,'
  191. if self.htype == 'PB_HTYPE_ARRAY':
  192. result += '\n pb_membersize(%s, %s[0]),' % (self.struct_name, self.name)
  193. result += ('\n pb_membersize(%s, %s) / pb_membersize(%s, %s[0]),'
  194. % (self.struct_name, self.name, self.struct_name, self.name))
  195. else:
  196. result += '\n pb_membersize(%s, %s),' % (self.struct_name, self.name)
  197. result += ' 0,'
  198. if self.ltype == 'PB_LTYPE_SUBMESSAGE':
  199. result += '\n &%s_fields}' % self.submsgname
  200. elif self.default is None or self.htype == 'PB_HTYPE_CALLBACK':
  201. result += ' 0}'
  202. else:
  203. result += '\n &%s_default}' % (self.struct_name + self.name)
  204. return result
  205. class Message:
  206. def __init__(self, names, desc):
  207. self.name = names
  208. self.fields = [Field(self.name, f) for f in desc.field]
  209. self.ordered_fields = self.fields[:]
  210. self.ordered_fields.sort()
  211. def get_dependencies(self):
  212. '''Get list of type names that this structure refers to.'''
  213. return [str(field.ctype) for field in self.fields]
  214. def __str__(self):
  215. result = 'typedef struct {\n'
  216. result += '\n'.join([str(f) for f in self.ordered_fields])
  217. result += '\n} %s;' % self.name
  218. return result
  219. def types(self):
  220. result = ""
  221. for field in self.fields:
  222. types = field.types()
  223. if types is not None:
  224. result += types + '\n'
  225. return result
  226. def default_decl(self, declaration_only = False):
  227. result = ""
  228. for field in self.fields:
  229. default = field.default_decl(declaration_only)
  230. if default is not None:
  231. result += default + '\n'
  232. return result
  233. def fields_declaration(self):
  234. result = 'extern const pb_field_t %s_fields[%d];' % (self.name, len(self.fields) + 1)
  235. return result
  236. def fields_definition(self):
  237. result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, len(self.fields) + 1)
  238. prev = None
  239. for field in self.ordered_fields:
  240. result += field.pb_field_t(prev)
  241. result += ',\n\n'
  242. prev = field.name
  243. result += ' PB_LAST_FIELD\n};'
  244. return result
  245. def iterate_messages(desc, names = Names()):
  246. '''Recursively find all messages. For each, yield name, DescriptorProto.'''
  247. if hasattr(desc, 'message_type'):
  248. submsgs = desc.message_type
  249. else:
  250. submsgs = desc.nested_type
  251. for submsg in submsgs:
  252. sub_names = names + submsg.name
  253. yield sub_names, submsg
  254. for x in iterate_messages(submsg, sub_names):
  255. yield x
  256. def parse_file(fdesc):
  257. '''Takes a FileDescriptorProto and returns tuple (enum, messages).'''
  258. enums = []
  259. messages = []
  260. if fdesc.package:
  261. base_name = Names(fdesc.package.split('.'))
  262. else:
  263. base_name = Names()
  264. for enum in fdesc.enum_type:
  265. enums.append(Enum(base_name, enum))
  266. for names, message in iterate_messages(fdesc, base_name):
  267. messages.append(Message(names, message))
  268. for enum in message.enum_type:
  269. enums.append(Enum(names, enum))
  270. return enums, messages
  271. def toposort2(data):
  272. '''Topological sort.
  273. From http://code.activestate.com/recipes/577413-topological-sort/
  274. This function is under the MIT license.
  275. '''
  276. for k, v in data.items():
  277. v.discard(k) # Ignore self dependencies
  278. extra_items_in_deps = reduce(set.union, data.values()) - set(data.keys())
  279. data.update(dict([(item, set()) for item in extra_items_in_deps]))
  280. while True:
  281. ordered = set(item for item,dep in data.items() if not dep)
  282. if not ordered:
  283. break
  284. for item in sorted(ordered):
  285. yield item
  286. data = dict([(item, (dep - ordered)) for item,dep in data.items()
  287. if item not in ordered])
  288. assert not data, "A cyclic dependency exists amongst %r" % data
  289. def sort_dependencies(messages):
  290. '''Sort a list of Messages based on dependencies.'''
  291. dependencies = {}
  292. message_by_name = {}
  293. for message in messages:
  294. dependencies[str(message.name)] = set(message.get_dependencies())
  295. message_by_name[str(message.name)] = message
  296. for msgname in toposort2(dependencies):
  297. if msgname in message_by_name:
  298. yield message_by_name[msgname]
  299. def generate_header(dependencies, headername, enums, messages):
  300. '''Generate content for a header file.
  301. Generates strings, which should be concatenated and stored to file.
  302. '''
  303. yield '/* Automatically generated nanopb header */\n'
  304. symbol = headername.replace('.', '_').upper()
  305. yield '#ifndef _PB_%s_\n' % symbol
  306. yield '#define _PB_%s_\n' % symbol
  307. yield '#include <pb.h>\n\n'
  308. for dependency in dependencies:
  309. noext = os.path.splitext(dependency)[0]
  310. yield '#include "%s.pb.h"\n' % noext
  311. yield '\n'
  312. yield '/* Enum definitions */\n'
  313. for enum in enums:
  314. yield str(enum) + '\n\n'
  315. yield '/* Struct definitions */\n'
  316. for msg in sort_dependencies(messages):
  317. yield msg.types()
  318. yield str(msg) + '\n\n'
  319. yield '/* Default values for struct fields */\n'
  320. for msg in messages:
  321. yield msg.default_decl(True)
  322. yield '\n'
  323. yield '/* Struct field encoding specification for nanopb */\n'
  324. for msg in messages:
  325. yield msg.fields_declaration() + '\n'
  326. yield '\n#endif\n'
  327. def generate_source(headername, enums, messages):
  328. '''Generate content for a source file.'''
  329. yield '/* Automatically generated nanopb constant definitions */\n'
  330. yield '#include "%s"\n\n' % headername
  331. for msg in messages:
  332. yield msg.default_decl(False)
  333. yield '\n\n'
  334. for msg in messages:
  335. yield msg.fields_definition() + '\n\n'
  336. if __name__ == '__main__':
  337. import sys
  338. import os.path
  339. if len(sys.argv) != 2:
  340. print "Usage: " + sys.argv[0] + " file.pb"
  341. print "where file.pb has been compiled from .proto by:"
  342. print "protoc -ofile.pb file.proto"
  343. print "Output fill be written to file.pb.h and file.pb.c"
  344. sys.exit(1)
  345. data = open(sys.argv[1], 'rb').read()
  346. fdesc = descriptor.FileDescriptorSet.FromString(data)
  347. enums, messages = parse_file(fdesc.file[0])
  348. noext = os.path.splitext(sys.argv[1])[0]
  349. headername = noext + '.pb.h'
  350. sourcename = noext + '.pb.c'
  351. headerbasename = os.path.basename(headername)
  352. print "Writing to " + headername + " and " + sourcename
  353. # List of .proto files that should not be included in the C header file
  354. # even if they are mentioned in the source .proto.
  355. excludes = ['nanopb.proto']
  356. dependencies = [d for d in fdesc.file[0].dependency if d not in excludes]
  357. header = open(headername, 'w')
  358. for part in generate_header(dependencies, headerbasename, enums, messages):
  359. header.write(part)
  360. source = open(sourcename, 'w')
  361. for part in generate_source(headerbasename, enums, messages):
  362. source.write(part)