Generate #defines for initializing message structures.

Usage like:
MyMessage foo = MyMessage_init_default;

MyMessage_init_default will initialize to default values defined in .proto.

MyMessage_init_zero will initialize to null/zero values. Same results as {}
or {0}, but will avoid compiler warnings by initializing everything explicitly.

Update issue 79
Status: FixedInGit
This commit is contained in:
Petteri Aimonen
2014-08-04 18:40:40 +03:00
parent 1d7f60fec3
commit ec3bff4ba1
5 changed files with 86 additions and 22 deletions

View File

@@ -292,28 +292,37 @@ class Field:
result = None result = None
return result return result
def default_decl(self, declaration_only = False): def get_initializer(self, null_init):
'''Return definition for this field's default value.''' '''Return literal expression for this field's default value.'''
if self.default is None:
return None
ctype, default = self.ctype, self.default if self.pbtype == 'MESSAGE':
array_decl = '' if null_init:
return '%s_init_zero' % self.ctype
else:
return '%s_init_default' % self.ctype
if self.default is None or null_init:
if self.pbtype == 'STRING':
return '""'
elif self.pbtype == 'BYTES':
return '{0, {0}}'
elif self.pbtype == 'ENUM':
return '(%s)0' % self.ctype
else:
return '0'
default = str(self.default)
if self.pbtype == 'STRING': if self.pbtype == 'STRING':
if self.allocation != 'STATIC': default = default.encode('utf-8').encode('string_escape')
return None # Not implemented
array_decl = '[%d]' % self.max_size
default = str(self.default).encode('string_escape')
default = default.replace('"', '\\"') default = default.replace('"', '\\"')
default = '"' + default + '"' default = '"' + default + '"'
elif self.pbtype == 'BYTES': elif self.pbtype == 'BYTES':
if self.allocation != 'STATIC': data = default.decode('string_escape')
return None # Not implemented
data = self.default.decode('string_escape')
data = ['0x%02x' % ord(c) for c in data] data = ['0x%02x' % ord(c) for c in data]
if len(data) == 0:
default = '{0, {0}}'
else:
default = '{%d, {%s}}' % (len(data), ','.join(data)) default = '{%d, {%s}}' % (len(data), ','.join(data))
elif self.pbtype in ['FIXED32', 'UINT32']: elif self.pbtype in ['FIXED32', 'UINT32']:
default += 'u' default += 'u'
@@ -322,6 +331,25 @@ class Field:
elif self.pbtype in ['SFIXED64', 'INT64']: elif self.pbtype in ['SFIXED64', 'INT64']:
default += 'll' default += 'll'
return default
def default_decl(self, declaration_only = False):
'''Return definition for this field's default value.'''
if self.default is None:
return None
ctype = self.ctype
default = self.get_initializer(False)
array_decl = ''
if self.pbtype == 'STRING':
if self.allocation != 'STATIC':
return None # Not implemented
array_decl = '[%d]' % self.max_size
elif self.pbtype == 'BYTES':
if self.allocation != 'STATIC':
return None # Not implemented
if declaration_only: if declaration_only:
return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl) return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
else: else:
@@ -553,6 +581,32 @@ class Message:
result += types + '\n' result += types + '\n'
return result return result
def get_initializer(self, null_init):
if not self.ordered_fields:
return '{0}'
parts = []
for field in self.ordered_fields:
if field.allocation == 'STATIC':
if field.rules == 'REPEATED':
parts.append('0')
parts.append('{'
+ ', '.join([field.get_initializer(null_init)] * field.max_count)
+ '}')
elif field.rules == 'OPTIONAL':
parts.append('false')
parts.append(field.get_initializer(null_init))
else:
parts.append(field.get_initializer(null_init))
elif field.allocation == 'POINTER':
parts.append('NULL')
elif field.allocation == 'CALLBACK':
if field.pbtype == 'EXTENSION':
parts.append('NULL')
else:
parts.append('{{NULL}, NULL}')
return '{' + ', '.join(parts) + '}'
def default_decl(self, declaration_only = False): def default_decl(self, declaration_only = False):
result = "" result = ""
for field in self.fields: for field in self.fields:
@@ -755,6 +809,15 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
yield msg.default_decl(True) yield msg.default_decl(True)
yield '\n' yield '\n'
yield '/* Initializer values for message structs */\n'
for msg in messages:
identifier = '%s_init_default' % msg.name
yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
for msg in messages:
identifier = '%s_init_zero' % msg.name
yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
yield '\n'
yield '/* Field tags (for use in manual encoding/decoding) */\n' yield '/* Field tags (for use in manual encoding/decoding) */\n'
for msg in sort_dependencies(messages): for msg in sort_dependencies(messages):
for field in msg.fields: for field in msg.fields:

View File

@@ -19,7 +19,8 @@
the decoding and checks the fields. */ the decoding and checks the fields. */
bool check_alltypes(pb_istream_t *stream, int mode) bool check_alltypes(pb_istream_t *stream, int mode)
{ {
AllTypes alltypes; /* Uses _init_default to just make sure that it works. */
AllTypes alltypes = AllTypes_init_default;
/* Fill with garbage to better detect initialization errors */ /* Fill with garbage to better detect initialization errors */
memset(&alltypes, 0xAA, sizeof(alltypes)); memset(&alltypes, 0xAA, sizeof(alltypes));

View File

@@ -13,7 +13,7 @@ int main(int argc, char **argv)
int mode = (argc > 1) ? atoi(argv[1]) : 0; int mode = (argc > 1) ? atoi(argv[1]) : 0;
/* Initialize the structure with constants */ /* Initialize the structure with constants */
AllTypes alltypes = {0}; AllTypes alltypes = AllTypes_init_zero;
alltypes.req_int32 = -1001; alltypes.req_int32 = -1001;
alltypes.req_int64 = -1002; alltypes.req_int64 = -1002;

View File

@@ -16,7 +16,7 @@
bool print_person(pb_istream_t *stream) bool print_person(pb_istream_t *stream)
{ {
int i; int i;
Person person; Person person = Person_init_zero;
if (!pb_decode(stream, Person_fields, &person)) if (!pb_decode(stream, Person_fields, &person))
return false; return false;

View File

@@ -12,7 +12,7 @@
bool print_person(pb_istream_t *stream) bool print_person(pb_istream_t *stream)
{ {
int i; int i;
Person person; Person person = Person_init_zero;
if (!pb_decode(stream, Person_fields, &person)) if (!pb_decode(stream, Person_fields, &person))
return false; return false;