Implement support for oneofs (C unions).
Basic test included, should probably add an oneof to the AllTypes test also. Update issue 131 Status: Started
This commit is contained in:
@@ -171,6 +171,7 @@ class Field:
|
||||
'''desc is FieldDescriptorProto'''
|
||||
self.tag = desc.number
|
||||
self.struct_name = struct_name
|
||||
self.union_name = None
|
||||
self.name = desc.name
|
||||
self.default = None
|
||||
self.max_size = None
|
||||
@@ -300,57 +301,91 @@ class Field:
|
||||
if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
|
||||
result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype)
|
||||
else:
|
||||
result = None
|
||||
result = ''
|
||||
return result
|
||||
|
||||
def get_initializer(self, null_init):
|
||||
'''Return literal expression for this field's default value.'''
|
||||
|
||||
def get_dependencies(self):
|
||||
'''Get list of type names used by this field.'''
|
||||
if self.allocation == 'STATIC':
|
||||
return [str(self.ctype)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_initializer(self, null_init, inner_init_only = False):
|
||||
'''Return literal expression for this field's default value.
|
||||
null_init: If True, initialize to a 0 value instead of default from .proto
|
||||
inner_init_only: If True, exclude initialization for any count/has fields
|
||||
'''
|
||||
|
||||
inner_init = None
|
||||
if self.pbtype == 'MESSAGE':
|
||||
if null_init:
|
||||
return '%s_init_zero' % self.ctype
|
||||
inner_init = '%s_init_zero' % self.ctype
|
||||
else:
|
||||
return '%s_init_default' % self.ctype
|
||||
|
||||
if self.default is None or null_init:
|
||||
inner_init = '%s_init_default' % self.ctype
|
||||
elif self.default is None or null_init:
|
||||
if self.pbtype == 'STRING':
|
||||
return '""'
|
||||
inner_init = '""'
|
||||
elif self.pbtype == 'BYTES':
|
||||
return '{0, {0}}'
|
||||
inner_init = '{0, {0}}'
|
||||
elif self.pbtype == 'ENUM':
|
||||
return '(%s)0' % self.ctype
|
||||
inner_init = '(%s)0' % self.ctype
|
||||
else:
|
||||
return '0'
|
||||
|
||||
default = str(self.default)
|
||||
|
||||
if self.pbtype == 'STRING':
|
||||
default = default.encode('utf-8').encode('string_escape')
|
||||
default = default.replace('"', '\\"')
|
||||
default = '"' + default + '"'
|
||||
elif self.pbtype == 'BYTES':
|
||||
data = default.decode('string_escape')
|
||||
data = ['0x%02x' % ord(c) for c in data]
|
||||
if len(data) == 0:
|
||||
default = '{0, {0}}'
|
||||
inner_init = '0'
|
||||
else:
|
||||
if self.pbtype == 'STRING':
|
||||
inner_init = self.default.encode('utf-8').encode('string_escape')
|
||||
inner_init = inner_init.replace('"', '\\"')
|
||||
inner_init = '"' + inner_init + '"'
|
||||
elif self.pbtype == 'BYTES':
|
||||
data = str(self.default).decode('string_escape')
|
||||
data = ['0x%02x' % ord(c) for c in data]
|
||||
if len(data) == 0:
|
||||
inner_init = '{0, {0}}'
|
||||
else:
|
||||
inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
|
||||
elif self.pbtype in ['FIXED32', 'UINT32']:
|
||||
inner_init = str(self.default) + 'u'
|
||||
elif self.pbtype in ['FIXED64', 'UINT64']:
|
||||
inner_init = str(self.default) + 'ull'
|
||||
elif self.pbtype in ['SFIXED64', 'INT64']:
|
||||
inner_init = str(self.default) + 'll'
|
||||
else:
|
||||
default = '{%d, {%s}}' % (len(data), ','.join(data))
|
||||
elif self.pbtype in ['FIXED32', 'UINT32']:
|
||||
default += 'u'
|
||||
elif self.pbtype in ['FIXED64', 'UINT64']:
|
||||
default += 'ull'
|
||||
elif self.pbtype in ['SFIXED64', 'INT64']:
|
||||
default += 'll'
|
||||
inner_init = str(self.default)
|
||||
|
||||
return default
|
||||
|
||||
if inner_init_only:
|
||||
return inner_init
|
||||
|
||||
outer_init = None
|
||||
if self.allocation == 'STATIC':
|
||||
if self.rules == 'REPEATED':
|
||||
outer_init = '0, {'
|
||||
outer_init += ', '.join([inner_init] * self.max_count)
|
||||
outer_init += '}'
|
||||
elif self.rules == 'OPTIONAL':
|
||||
outer_init = 'false, ' + inner_init
|
||||
else:
|
||||
outer_init = inner_init
|
||||
elif self.allocation == 'POINTER':
|
||||
if self.rules == 'REPEATED':
|
||||
outer_init = '0, NULL'
|
||||
else:
|
||||
outer_init = 'NULL'
|
||||
elif self.allocation == 'CALLBACK':
|
||||
if self.pbtype == 'EXTENSION':
|
||||
outer_init = 'NULL'
|
||||
else:
|
||||
outer_init = '{{NULL}, NULL}'
|
||||
|
||||
return outer_init
|
||||
|
||||
def default_decl(self, declaration_only = False):
|
||||
'''Return definition for this field's default value.'''
|
||||
if self.default is None:
|
||||
return None
|
||||
|
||||
ctype = self.ctype
|
||||
default = self.get_initializer(False)
|
||||
default = self.get_initializer(False, True)
|
||||
array_decl = ''
|
||||
|
||||
if self.pbtype == 'STRING':
|
||||
@@ -375,7 +410,13 @@ class Field:
|
||||
'''Return the pb_field_t initializer to use in the constant array.
|
||||
prev_field_name is the name of the previous field or None.
|
||||
'''
|
||||
result = ' PB_FIELD(%3d, ' % self.tag
|
||||
|
||||
if self.rules == 'ONEOF':
|
||||
result = ' PB_ONEOF_FIELD(%s, ' % self.union_name
|
||||
else:
|
||||
result = ' PB_FIELD('
|
||||
|
||||
result += '%3d, ' % self.tag
|
||||
result += '%-8s, ' % self.pbtype
|
||||
result += '%s, ' % self.rules
|
||||
result += '%-8s, ' % self.allocation
|
||||
@@ -403,6 +444,8 @@ class Field:
|
||||
if self.pbtype == 'MESSAGE':
|
||||
if self.rules == 'REPEATED' and self.allocation == 'STATIC':
|
||||
return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name)
|
||||
elif self.rules == 'ONEOF':
|
||||
return 'pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name)
|
||||
else:
|
||||
return 'pb_membersize(%s, %s)' % (self.struct_name, self.name)
|
||||
|
||||
@@ -534,6 +577,71 @@ class ExtensionField(Field):
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation of oneofs (unions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OneOf(Field):
|
||||
def __init__(self, oneof_desc):
|
||||
self.name = oneof_desc.name
|
||||
self.ctype = 'union'
|
||||
self.fields = []
|
||||
|
||||
def add_field(self, field):
|
||||
if field.allocation == 'CALLBACK':
|
||||
raise Exception("Callback fields inside of oneof are not supported"
|
||||
+ " (field %s)" % field.fullname)
|
||||
|
||||
field.union_name = self.name
|
||||
field.rules = 'ONEOF'
|
||||
self.fields.append(field)
|
||||
self.fields.sort(key = lambda f: f.tag)
|
||||
|
||||
# Sort by the lowest tag number inside union
|
||||
self.tag = min([f.tag for f in self.fields])
|
||||
|
||||
def __cmp__(self, other):
|
||||
return cmp(self.tag, other.tag)
|
||||
|
||||
def __str__(self):
|
||||
result = ''
|
||||
if self.fields:
|
||||
result += ' pb_size_t which_' + self.name + ";\n"
|
||||
result += ' union {\n'
|
||||
for f in self.fields:
|
||||
result += ' ' + str(f).replace('\n', '\n ') + '\n'
|
||||
result += ' } ' + self.name + ';'
|
||||
return result
|
||||
|
||||
def types(self):
|
||||
return ''.join([f.types() for f in self.fields])
|
||||
|
||||
def get_dependencies(self):
|
||||
deps = []
|
||||
for f in self.fields:
|
||||
deps += f.get_dependencies()
|
||||
return deps
|
||||
|
||||
def get_initializer(self, null_init):
|
||||
return '0, {' + self.fields[0].get_initializer(null_init) + '}'
|
||||
|
||||
def default_decl(self, declaration_only = False):
|
||||
return None
|
||||
|
||||
def tags(self):
|
||||
return '\n'.join([f.tags() for f in self.fields])
|
||||
|
||||
def pb_field_t(self, prev_field_name):
|
||||
prev_field_name = prev_field_name or self.name
|
||||
result = ',\n'.join([f.pb_field_t(prev_field_name) for f in self.fields])
|
||||
return result
|
||||
|
||||
def largest_field_value(self):
|
||||
return max([f.largest_field_value() for f in self.fields])
|
||||
|
||||
def encoded_size(self, allmsgs):
|
||||
return max([f.encoded_size(allmsgs) for f in self.fields])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation of messages (structures)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -543,11 +651,24 @@ class Message:
|
||||
def __init__(self, names, desc, message_options):
|
||||
self.name = names
|
||||
self.fields = []
|
||||
|
||||
self.oneofs = []
|
||||
|
||||
if hasattr(desc, 'oneof_decl'):
|
||||
for f in desc.oneof_decl:
|
||||
oneof = OneOf(f)
|
||||
self.oneofs.append(oneof)
|
||||
self.fields.append(oneof)
|
||||
|
||||
for f in desc.field:
|
||||
field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
|
||||
if field_options.type != nanopb_pb2.FT_IGNORE:
|
||||
self.fields.append(Field(self.name, f, field_options))
|
||||
if field_options.type == nanopb_pb2.FT_IGNORE:
|
||||
continue
|
||||
|
||||
field = Field(self.name, f, field_options)
|
||||
if hasattr(f, 'oneof_index') and f.HasField('oneof_index'):
|
||||
self.oneofs[f.oneof_index].add_field(field)
|
||||
else:
|
||||
self.fields.append(field)
|
||||
|
||||
if len(desc.extension_range) > 0:
|
||||
field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
|
||||
@@ -561,7 +682,10 @@ class Message:
|
||||
|
||||
def get_dependencies(self):
|
||||
'''Get list of type names that this structure refers to.'''
|
||||
return [str(field.ctype) for field in self.fields if field.allocation == 'STATIC']
|
||||
deps = []
|
||||
for f in self.fields:
|
||||
deps += f.get_dependencies()
|
||||
return deps
|
||||
|
||||
def __str__(self):
|
||||
result = 'typedef struct _%s {\n' % self.name
|
||||
@@ -586,39 +710,15 @@ class Message:
|
||||
return result
|
||||
|
||||
def types(self):
|
||||
result = ""
|
||||
for field in self.fields:
|
||||
types = field.types()
|
||||
if types is not None:
|
||||
result += types + '\n'
|
||||
return result
|
||||
|
||||
return ''.join([f.types() for f in self.fields])
|
||||
|
||||
def get_initializer(self, null_init):
|
||||
if not self.ordered_fields:
|
||||
return '{0}'
|
||||
|
||||
parts = []
|
||||
for field in self.ordered_fields:
|
||||
if field.allocation == 'STATIC':
|
||||
if field.rules == 'REPEATED':
|
||||
parts.append('0')
|
||||
parts.append('{'
|
||||
+ ', '.join([field.get_initializer(null_init)] * field.max_count)
|
||||
+ '}')
|
||||
elif field.rules == 'OPTIONAL':
|
||||
parts.append('false')
|
||||
parts.append(field.get_initializer(null_init))
|
||||
else:
|
||||
parts.append(field.get_initializer(null_init))
|
||||
elif field.allocation == 'POINTER':
|
||||
if field.rules == 'REPEATED':
|
||||
parts.append('0')
|
||||
parts.append('NULL')
|
||||
elif field.allocation == 'CALLBACK':
|
||||
if field.pbtype == 'EXTENSION':
|
||||
parts.append('NULL')
|
||||
else:
|
||||
parts.append('{{NULL}, NULL}')
|
||||
parts.append(field.get_initializer(null_init))
|
||||
return '{' + ', '.join(parts) + '}'
|
||||
|
||||
def default_decl(self, declaration_only = False):
|
||||
@@ -629,18 +729,39 @@ class Message:
|
||||
result += default + '\n'
|
||||
return result
|
||||
|
||||
def count_required_fields(self):
|
||||
'''Returns number of required fields inside this message'''
|
||||
count = 0
|
||||
for f in self.fields:
|
||||
if f not in self.oneofs:
|
||||
if f.rules == 'REQUIRED':
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def count_all_fields(self):
|
||||
count = 0
|
||||
for f in self.fields:
|
||||
if f in self.oneofs:
|
||||
count += len(f.fields)
|
||||
else:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def fields_declaration(self):
|
||||
result = 'extern const pb_field_t %s_fields[%d];' % (self.name, len(self.fields) + 1)
|
||||
result = 'extern const pb_field_t %s_fields[%d];' % (self.name, self.count_all_fields() + 1)
|
||||
return result
|
||||
|
||||
def fields_definition(self):
|
||||
result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, len(self.fields) + 1)
|
||||
result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, self.count_all_fields() + 1)
|
||||
|
||||
prev = None
|
||||
for field in self.ordered_fields:
|
||||
result += field.pb_field_t(prev)
|
||||
result += ',\n'
|
||||
prev = field.name
|
||||
if isinstance(field, OneOf):
|
||||
prev = field.name + '.' + field.fields[-1].name
|
||||
else:
|
||||
prev = field.name
|
||||
|
||||
result += ' PB_LAST_FIELD\n};'
|
||||
return result
|
||||
@@ -894,9 +1015,8 @@ def generate_source(headername, enums, messages, extensions, options):
|
||||
|
||||
# Add checks for numeric limits
|
||||
if messages:
|
||||
count_required_fields = lambda m: len([f for f in msg.fields if f.rules == 'REQUIRED'])
|
||||
largest_msg = max(messages, key = count_required_fields)
|
||||
largest_count = count_required_fields(largest_msg)
|
||||
largest_msg = max(messages, key = lambda m: m.count_required_fields())
|
||||
largest_count = largest_msg.count_required_fields()
|
||||
if largest_count > 64:
|
||||
yield '\n/* Check that missing required fields will be properly detected */\n'
|
||||
yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
|
||||
|
||||
Reference in New Issue
Block a user