Generate message size #defines also for messages defined in multiple files.

Add testcase for the same.
This commit is contained in:
Petteri Aimonen
2013-10-23 21:01:11 +03:00
parent 2bfd497eea
commit 49bd3f35a0
5 changed files with 105 additions and 8 deletions

View File

@@ -94,6 +94,44 @@ assert varint_max_size(0) == 1
assert varint_max_size(127) == 1 assert varint_max_size(127) == 1
assert varint_max_size(128) == 2 assert varint_max_size(128) == 2
class EncodedSize:
'''Class used to represent the encoded size of a field or a message.
Consists of a combination of symbolic sizes and integer sizes.'''
def __init__(self, value = 0, symbols = []):
if isinstance(value, (str, Names)):
symbols = [str(value)]
value = 0
self.value = value
self.symbols = symbols
def __add__(self, other):
if isinstance(other, int):
return EncodedSize(self.value + other, self.symbols)
elif isinstance(other, (str, Names)):
return EncodedSize(self.value, self.symbols + [str(other)])
elif isinstance(other, EncodedSize):
return EncodedSize(self.value + other.value, self.symbols + other.symbols)
else:
raise ValueError("Cannot add size: " + repr(other))
def __mul__(self, other):
if isinstance(other, int):
return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
else:
raise ValueError("Cannot multiply size: " + repr(other))
def __str__(self):
if not self.symbols:
return str(self.value)
else:
return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'
def upperlimit(self):
if not self.symbols:
return self.value
else:
return 2**32 - 1
class Enum: class Enum:
def __init__(self, names, desc, enum_options): def __init__(self, names, desc, enum_options):
'''desc is EnumDescriptorProto''' '''desc is EnumDescriptorProto'''
@@ -301,23 +339,27 @@ class Field:
if self.allocation != 'STATIC': if self.allocation != 'STATIC':
return None return None
encsize = self.enc_size
if self.pbtype == 'MESSAGE': if self.pbtype == 'MESSAGE':
for msg in allmsgs: for msg in allmsgs:
if msg.name == self.submsgname: if msg.name == self.submsgname:
encsize = msg.encoded_size(allmsgs) encsize = msg.encoded_size(allmsgs)
if encsize is None: if encsize is None:
return None # Submessage size is indeterminate return None # Submessage size is indeterminate
encsize += varint_max_size(encsize) # submsg length is encoded also
# Include submessage length prefix
encsize += varint_max_size(encsize.upperlimit())
break break
else: else:
# Submessage cannot be found, this currently occurs when # Submessage cannot be found, this currently occurs when
# the submessage type is defined in a different file. # the submessage type is defined in a different file.
return None # Instead of direct numeric value, reference the size that
# has been #defined in the other file.
if encsize is None: encsize = EncodedSize(self.submsgname + 'size')
elif self.enc_size is None:
raise RuntimeError("Could not determine encoded size for %s.%s" raise RuntimeError("Could not determine encoded size for %s.%s"
% (self.struct_name, self.name)) % (self.struct_name, self.name))
else:
encsize = EncodedSize(self.enc_size)
encsize += varint_max_size(self.tag << 3) # Tag + wire type encsize += varint_max_size(self.tag << 3) # Tag + wire type
@@ -362,7 +404,7 @@ class ExtensionRange(Field):
# We exclude extensions from the count, because they cannot be known # We exclude extensions from the count, because they cannot be known
# until runtime. Other option would be to return None here, but this # until runtime. Other option would be to return None here, but this
# way the value remains useful if extensions are not used. # way the value remains useful if extensions are not used.
return 0 return EncodedSize(0)
class ExtensionField(Field): class ExtensionField(Field):
def __init__(self, struct_name, desc, field_options): def __init__(self, struct_name, desc, field_options):
@@ -491,7 +533,7 @@ class Message:
'''Return the maximum size that this message can take when encoded. '''Return the maximum size that this message can take when encoded.
If the size cannot be determined, returns None. If the size cannot be determined, returns None.
''' '''
size = 0 size = EncodedSize(0)
for field in self.fields: for field in self.fields:
fsize = field.encoded_size(allmsgs) fsize = field.encoded_size(allmsgs)
if fsize is None: if fsize is None:
@@ -674,7 +716,7 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
msize = msg.encoded_size(messages) msize = msg.encoded_size(messages)
if msize is not None: if msize is not None:
identifier = '%s_size' % msg.name identifier = '%s_size' % msg.name
yield '#define %-40s %d\n' % (identifier, msize) yield '#define %-40s %s\n' % (identifier, msize)
yield '\n' yield '\n'
yield '#ifdef __cplusplus\n' yield '#ifdef __cplusplus\n'

View File

@@ -0,0 +1,11 @@
# Test the generation of message size #defines
Import('env')
incpath = env.Clone()
incpath.Append(PROTOCPATH = '#message_sizes')
incpath.NanopbProto("messages1")
incpath.NanopbProto("messages2")
incpath.Program(['dummy.c', 'messages1.pb.c', 'messages2.pb.c'])

View File

@@ -0,0 +1,9 @@
/* Just test that the file can be compiled successfully. */
#include "messages2.pb.h"
int main()
{
return xmit_size;
}

View File

@@ -0,0 +1,27 @@
enum MessageStatus {
FAIL = 0;
OK = 1;
};
message MessageInfo {
required fixed32 msg_id = 1;
optional fixed32 interface_id = 2;
}
message MessageResponseInfo {
required fixed64 interface_id = 1;
required fixed32 seq = 2;
required fixed32 msg_id = 3;
}
message MessageHeader {
required MessageInfo info = 1;
optional MessageResponseInfo response_info = 2;
optional MessageResponse response = 3;
}
message MessageResponse {
required MessageStatus status = 1;
required fixed32 seq = 2;
}

View File

@@ -0,0 +1,8 @@
import 'nanopb.proto';
import 'messages1.proto';
message xmit {
required MessageHeader header = 1;
required bytes data = 2 [(nanopb).max_size = 128];
}