531
645
self.idx = self.limit
532
646
return r.tostring()
649
TYPE_DOUBLE: getDouble,
650
TYPE_FLOAT: getFloat,
653
TYPE_INT32: getVarInt32,
654
TYPE_INT64: getVarInt64,
655
TYPE_UINT64: getVarUint64,
656
TYPE_BOOL: getBoolean,
657
TYPE_STRING: getPrefixedString }
663
class ExtensionIdentifier(object):
    """Handle describing one extension field of an extendable message.

    Attributes set by the constructor:
      full_name:   name used when reporting errors and in debug output.
      number:      field number; extensions are registered and sorted by it.
      field_type:  one of the TYPE_* constants, used to pick the codec.
      wire_tag:    tag written before the value and checked while parsing.
      is_repeated: true when the extension holds a list of values.
      default:     value reported for an unset singular scalar extension.

    The remaining slots (containing_cls, composite_cls, message_name) are not
    assigned here; they are filled in when the extension is registered on a
    message class.
    """

    __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated',
                 'default', 'containing_cls', 'composite_cls', 'message_name')

    def __init__(self, full_name, number, field_type, wire_tag, is_repeated,
                 default):
        self.full_name = full_name
        # `number` is declared in __slots__ and read by the registration and
        # sorting code, so it has to be stored alongside the other fields.
        self.number = number
        self.field_type = field_type
        self.wire_tag = wire_tag
        self.is_repeated = is_repeated
        self.default = default
675
class ExtendableProtocolMessage(ProtocolMessage):
676
def HasExtension(self, extension):
    """Return True when a value is currently stored for `extension`.

    The identifier is validated with _VerifyExtensionIdentifier() first, so
    an identifier registered on a different message class is rejected before
    the lookup in self._extension_fields happens.
    """
    self._VerifyExtensionIdentifier(extension)
    return extension in self._extension_fields
680
def ClearExtension(self, extension):
    """Drop the stored value for `extension`; do nothing when it is unset.

    The identifier is validated with _VerifyExtensionIdentifier() before
    self._extension_fields is touched.
    """
    self._VerifyExtensionIdentifier(extension)
    # pop() with a default removes the entry if present and is a no-op
    # otherwise, matching the original "if present: del" behaviour.
    self._extension_fields.pop(extension, None)
685
# GetExtension(extension, index=None): look up the value held for `extension`.
#
# Visible logic: the identifier is validated with
# _VerifyExtensionIdentifier(); when the extension is present in
# self._extension_fields the stored value is used. The remaining branches
# (presumably the "not stored" fallback) distinguish repeated extensions,
# composite extensions (a fresh extension.composite_cls() instance) and
# extension.default. For repeated extensions the result is finally indexed
# with `index`.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits some original lines (the repeated-fallback body, the `else:`
# lines and the final return are not shown) -- confirm against the full file.
def GetExtension(self, extension, index=None):
686
self._VerifyExtensionIdentifier(extension)
687
if extension in self._extension_fields:
688
result = self._extension_fields[extension]
690
if extension.is_repeated:
692
elif extension.composite_cls:
693
result = extension.composite_cls()
695
result = extension.default
696
if extension.is_repeated:
697
result = result[index]
700
# SetExtension(extension, *args): assign a value to a non-composite extension.
#
# Visible logic: after _VerifyExtensionIdentifier(), composite extensions are
# rejected ('Cannot assign to extension ...'). For a repeated extension the
# documented call form is SetExtension(extension, index, value) and the value
# is written to self._extension_fields[extension][index]; for a singular
# extension the form is SetExtension(extension, value) and the value replaces
# self._extension_fields[extension].
#
# NOTE(review): the singular-form message also says "takes exactly 3
# arguments" although the call form it quotes has two -- looks copied from
# the repeated form; confirm the intended count.
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the raise statements, the argument-count checks and the lines
# that unpack `index`/`value` from args -- confirm against the full file.
def SetExtension(self, extension, *args):
701
self._VerifyExtensionIdentifier(extension)
702
if extension.composite_cls:
704
'Cannot assign to extension "%s" because it is a composite type.' %
706
if extension.is_repeated:
709
'SetExtension(extension, index, value) for repeated extension '
710
'takes exactly 3 arguments: (%d given)' % len(args))
713
self._extension_fields[extension][index] = value
717
'SetExtension(extension, value) for singular extension '
718
'takes exactly 3 arguments: (%d given)' % len(args))
720
self._extension_fields[extension] = value
722
# MutableExtension(extension, index=None): return the composite value of an
# extension so the caller can modify it.
#
# Visible logic: after _VerifyExtensionIdentifier(), extensions whose
# composite_cls is None are rejected (message: "... it is not a composite
# type."). For a repeated extension the element is fetched through
# GetExtension(extension, index); the accompanying message indicates that
# `index` is required in that case. For a singular extension an existing
# entry in self._extension_fields is returned as-is; the later statements
# build extension.composite_cls() and store it under the extension.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits some original lines (the raise statements, the index check and
# the final return are not shown) -- confirm against the full file.
def MutableExtension(self, extension, index=None):
723
self._VerifyExtensionIdentifier(extension)
724
if extension.composite_cls is None:
726
'MutableExtension() cannot be applied to "%s", because it is not a '
727
'composite type.' % extension.full_name)
728
if extension.is_repeated:
731
'MutableExtension(extension, index) for repeated extension '
732
'takes exactly 2 arguments: (1 given)')
733
return self.GetExtension(extension, index)
734
if extension in self._extension_fields:
735
return self._extension_fields[extension]
737
result = extension.composite_cls()
738
self._extension_fields[extension] = result
741
# ExtensionList(extension): return the stored container of a repeated
# extension.
#
# Visible logic: after _VerifyExtensionIdentifier(), non-repeated extensions
# are rejected (message: "... it is not a repeated extension."). When the
# extension already has an entry in self._extension_fields that object itself
# is returned (not a copy); otherwise a `result` object is stored under the
# extension.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the raise statement, the construction of `result` and the final
# return -- confirm against the full file.
def ExtensionList(self, extension):
742
self._VerifyExtensionIdentifier(extension)
743
if not extension.is_repeated:
745
'ExtensionList() cannot be applied to "%s", because it is not a '
746
'repeated extension.' % extension.full_name)
747
if extension in self._extension_fields:
748
return self._extension_fields[extension]
750
self._extension_fields[extension] = result
753
# ExtensionSize(extension): number of elements held by a repeated extension.
#
# Visible logic: after _VerifyExtensionIdentifier(), non-repeated extensions
# are rejected (message: "... it is not a repeated extension."). When the
# extension has an entry in self._extension_fields, len() of that entry is
# returned.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the raise statement and whatever is returned when the extension
# is unset (presumably 0) -- confirm against the full file.
def ExtensionSize(self, extension):
754
self._VerifyExtensionIdentifier(extension)
755
if not extension.is_repeated:
757
'ExtensionSize() cannot be applied to "%s", because it is not a '
758
'repeated extension.' % extension.full_name)
759
if extension in self._extension_fields:
760
return len(self._extension_fields[extension])
763
# AddExtension(extension, value=None): add one element to a repeated
# extension.
#
# Visible logic: after _VerifyExtensionIdentifier(), non-repeated extensions
# are rejected (message: "... it is not a repeated extension."). The
# container `field` is taken from self._extension_fields when present, and is
# (re)stored under the extension. For composite extensions a non-None `value`
# is rejected -- the message tells callers to set values on the returned
# message instead -- and a new extension.composite_cls() instance is created.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the raise statements, the creation of a new container and the
# whole tail of the method (appending and returning) -- confirm against the
# full file.
def AddExtension(self, extension, value=None):
764
self._VerifyExtensionIdentifier(extension)
765
if not extension.is_repeated:
767
'AddExtension() cannot be applied to "%s", because it is not a '
768
'repeated extension.' % extension.full_name)
769
if extension in self._extension_fields:
770
field = self._extension_fields[extension]
773
self._extension_fields[extension] = field
775
if extension.composite_cls:
776
if value is not None:
778
'value must not be set in AddExtension() for "%s", because it is '
779
'a message type extension. Set values on the returned message '
780
'instead.' % extension.full_name)
781
msg = extension.composite_cls()
787
def _VerifyExtensionIdentifier(self, extension):
788
if extension.containing_cls != self.__class__:
789
raise TypeError("Containing type of %s is %s, but not %s."
790
% (extension.full_name,
791
extension.containing_cls.__name__,
792
self.__class__.__name__))
794
# _MergeExtensionFields(x): merge the extension values of message `x` into
# this message.
#
# Visible logic: every (ext, val) pair of x._extension_fields is visited.
# Element-wise handling (the xrange loop over val) appends scalars with
# AddExtension(ext, val[i]) and merges composites into a newly added element
# via AddExtension(ext).MergeFrom(val[i]). Whole-value handling assigns
# scalars with SetExtension(ext, val) and merges composites through
# MutableExtension(ext).MergeFrom(val).
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the branch lines that select between the two handlings
# (presumably on ext.is_repeated) and the `else:` lines -- confirm against
# the full file.
def _MergeExtensionFields(self, x):
795
for ext, val in x._extension_fields.items():
797
for i in xrange(len(val)):
798
if ext.composite_cls is None:
799
self.AddExtension(ext, val[i])
801
self.AddExtension(ext).MergeFrom(val[i])
803
if ext.composite_cls is None:
804
self.SetExtension(ext, val)
806
self.MutableExtension(ext).MergeFrom(val)
808
def _ListExtensions(self):
809
result = [ext for ext in self._extension_fields.keys()
810
if (not ext.is_repeated) or self.ExtensionSize(ext) > 0]
811
result.sort(key = lambda item: item.number)
814
# _ExtensionEquals(x): compare this message's extensions with those of `x`.
#
# Visible logic: the two _ListExtensions() results are compared first. Then,
# per extension, one path compares ExtensionSize() and the zipped elements of
# ExtensionList() pairwise, and the other path compares GetExtension()
# results; any mismatch returns False.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the result of the list mismatch check, the branch lines selecting
# between the two paths (presumably on ext.is_repeated) and the final return
# -- confirm against the full file.
def _ExtensionEquals(self, x):
815
extensions = self._ListExtensions()
816
if extensions != x._ListExtensions():
818
for ext in extensions:
820
if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False
821
for e1, e2 in zip(self.ExtensionList(ext),
822
x.ExtensionList(ext)):
823
if e1 != e2: return False
825
if self.GetExtension(ext) != x.GetExtension(ext): return False
828
# _OutputExtensionFields(out, partial, extensions, start_index, ...): write
# extension values to the encoder `out`, walking `extensions` from
# `start_index`.
#
# The nested OutputSingleField(ext, value) writes ext.wire_tag and then the
# value according to ext.field_type:
#   * TYPE_GROUP: the message body (OutputPartial or OutputUnchecked)
#     followed by a second tag.
#   * TYPE_FOREIGN: a varint byte size (ByteSizePartial or ByteSize) followed
#     by the message body.
#   * anything else: the writer from Encoder._TYPE_TO_METHOD.
# The outer loop reads each extension's value(s) from self._extension_fields,
# emitting list elements one by one on the element-wise path, and tests
# ext.number against end_field_number.
#
# NOTE(review): inside OutputSingleField the closing group tag is written as
# `wire_tag + 1`, but no visible line binds `wire_tag` in that scope (only
# ext.wire_tag is used above) -- likely meant ext.wire_tag + 1; confirm.
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the rest of the signature (end_field_number is referenced but not
# shown as a parameter), the `if partial:`/`else:` lines, the action taken
# when ext.number >= end_field_number and the return -- confirm against the
# full file.
def _OutputExtensionFields(self, out, partial, extensions, start_index,
830
def OutputSingleField(ext, value):
831
out.putVarInt32(ext.wire_tag)
832
if ext.field_type == TYPE_GROUP:
834
value.OutputPartial(out)
836
value.OutputUnchecked(out)
837
out.putVarInt32(wire_tag + 1)
838
elif ext.field_type == TYPE_FOREIGN:
840
out.putVarInt32(value.ByteSizePartial())
841
value.OutputPartial(out)
843
out.putVarInt32(value.ByteSize())
844
value.OutputUnchecked(out)
846
Encoder._TYPE_TO_METHOD[ext.field_type](out, value)
848
size = len(extensions)
849
for ext_index in xrange(start_index, size):
850
ext = extensions[ext_index]
851
if ext.number >= end_field_number:
855
for i in xrange(len(self._extension_fields[ext])):
856
OutputSingleField(ext, self._extension_fields[ext][i])
858
OutputSingleField(ext, self._extension_fields[ext])
861
# _ParseOneExtensionField(wire_tag, d): decode one extension value from
# decoder `d`, given the tag that was already read.
#
# Visible logic: the field number is wire_tag >> 3 and is looked up in
# self._extensions_by_field_number; the incoming tag is compared with the
# registered ext.wire_tag. Decoding then depends on ext.field_type:
#   * TYPE_FOREIGN: a varint length is read and a sub-Decoder over
#     d.buffer() from d.pos() to d.pos() + length is merged into the target
#     message via TryMerge.
#   * TYPE_GROUP: the target message merges directly from `d`.
#   * anything else: the reader from Decoder._TYPE_TO_METHOD produces a
#     scalar value.
# In each case the target is either a new element (AddExtension) or the
# singular slot (MutableExtension / SetExtension).
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the handling of a tag mismatch, the branch lines choosing between
# the Add* and Mutable*/Set* calls (presumably on ext.is_repeated), any
# advancing of `d` past the sub-message, and the handling of unknown field
# numbers -- confirm against the full file.
def _ParseOneExtensionField(self, wire_tag, d):
862
number = wire_tag >> 3
863
if number in self._extensions_by_field_number:
864
ext = self._extensions_by_field_number[number]
865
if wire_tag != ext.wire_tag:
868
if ext.field_type == TYPE_FOREIGN:
869
length = d.getVarInt32()
870
tmp = Decoder(d.buffer(), d.pos(), d.pos() + length)
872
self.AddExtension(ext).TryMerge(tmp)
874
self.MutableExtension(ext).TryMerge(tmp)
876
elif ext.field_type == TYPE_GROUP:
878
self.AddExtension(ext).TryMerge(d)
880
self.MutableExtension(ext).TryMerge(d)
882
value = Decoder._TYPE_TO_METHOD[ext.field_type](d)
884
self.AddExtension(ext, value)
886
self.SetExtension(ext, value)
891
# _ExtensionByteSize(partial): accumulate the encoded size of all stored
# extension values into `size`.
#
# Visible logic: for every (extension, value) in self._extension_fields the
# tag size is lengthVarInt64(extension.wire_tag). Repeated extensions add one
# tag per element plus _FieldByteSize() of each element; the other path adds
# one tag plus _FieldByteSize() of the value. `partial` is passed through to
# _FieldByteSize().
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the initialisation of `size`, the TYPE_GROUP adjustment made to
# tag_size, the `else:` line and the return -- confirm against the full file.
def _ExtensionByteSize(self, partial):
893
for extension, value in self._extension_fields.items():
894
ftype = extension.field_type
895
tag_size = self.lengthVarInt64(extension.wire_tag)
896
if ftype == TYPE_GROUP:
898
if extension.is_repeated:
899
size += tag_size * len(value)
900
for single_value in value:
901
size += self._FieldByteSize(ftype, single_value, partial)
903
size += tag_size + self._FieldByteSize(ftype, value, partial)
906
# _FieldByteSize(ftype, value, partial): encoded size of one value of type
# `ftype`, excluding its tag.
#
# Visible logic:
#   * TYPE_STRING: lengthString(len(value)).
#   * TYPE_FOREIGN / TYPE_GROUP: lengthString() of the message's
#     ByteSizePartial() or ByteSize().
#   * TYPE_INT64 / TYPE_UINT64 / TYPE_INT32: lengthVarInt64(value).
#   * other types: the size listed in Encoder._TYPE_TO_BYTE_SIZE.
# A type found in none of these raises AssertionError('Extension type %d is
# not recognized.').
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the `if partial:`/`else:` lines, the remaining `else:` lines and
# the return -- confirm against the full file.
def _FieldByteSize(self, ftype, value, partial):
908
if ftype == TYPE_STRING:
909
size = self.lengthString(len(value))
910
elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP:
912
size = self.lengthString(value.ByteSizePartial())
914
size = self.lengthString(value.ByteSize())
915
elif ftype == TYPE_INT64 or ftype == TYPE_UINT64 or ftype == TYPE_INT32:
916
size = self.lengthVarInt64(value)
918
if ftype in Encoder._TYPE_TO_BYTE_SIZE:
919
size = Encoder._TYPE_TO_BYTE_SIZE[ftype]
921
raise AssertionError(
922
'Extension type %d is not recognized.' % ftype)
925
# _ExtensionDebugString(prefix, printElemNumber): build the text rendering of
# the stored extensions into `res`.
#
# Visible logic: extensions come from _ListExtensions() and their values from
# self._extension_fields. Composite values are rendered as a
# "[full_name] {" ... "}" block whose body is the value's
# __str__(prefix + ..., printElemNumber); on the per-element path the header
# also carries an "(index)" suffix when printElemNumber is true. Scalar
# values are rendered as "[full_name]: text", where the text comes from the
# formatter in _TYPE_TO_DEBUG_STRING for the field type when one exists and
# from self.DebugFormat(value) otherwise.
#
# NOTE(review): this listing interleaves viewer line numbers with the code
# and omits the initialisation of `res`, the per-element loop header that
# binds `e`, `cnt` and `elm`, the scalar branch for repeated values, the
# `else:` lines and the return -- confirm against the full file.
def _ExtensionDebugString(self, prefix, printElemNumber):
927
extensions = self._ListExtensions()
928
for extension in extensions:
929
value = self._extension_fields[extension]
930
if extension.is_repeated:
934
if printElemNumber: elm = "(%d)" % cnt
935
if extension.composite_cls is not None:
936
res += prefix + "[%s%s] {\n" % (extension.full_name, elm)
937
res += e.__str__(prefix + " ", printElemNumber)
938
res += prefix + "}\n"
940
if extension.composite_cls is not None:
941
res += prefix + "[%s] {\n" % extension.full_name
942
res += value.__str__(
943
prefix + " ", printElemNumber)
944
res += prefix + "}\n"
946
if extension.field_type in _TYPE_TO_DEBUG_STRING:
947
text_value = _TYPE_TO_DEBUG_STRING[
948
extension.field_type](self, value)
950
text_value = self.DebugFormat(value)
951
res += prefix + "[%s]: %s\n" % (extension.full_name, text_value)
955
# _RegisterExtension(cls, extension, composite_cls=None): attach `extension`
# to the message class `cls`.
#
# Visible logic: extension.containing_cls and extension.composite_cls are
# set; when composite_cls is given, extension.message_name is taken from
# composite_cls._PROTO_DESCRIPTOR_NAME (it is left unassigned otherwise).
# The extension is then recorded in cls._extensions_by_field_number under
# extension.number with setdefault(), so an earlier registration for the same
# number is kept; if the recorded handle is a different object an
# AssertionError naming both extensions, the class and the number is raised.
#
# NOTE(review): `cls` as first parameter suggests a classmethod, but no
# decorator is visible here. This listing interleaves viewer line numbers
# with the code and omits the end of the message format string -- confirm
# against the full file.
def _RegisterExtension(cls, extension, composite_cls=None):
956
extension.containing_cls = cls
957
extension.composite_cls = composite_cls
958
if composite_cls is not None:
959
extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME
960
actual_handle = cls._extensions_by_field_number.setdefault(
961
extension.number, extension)
962
if actual_handle is not extension:
963
raise AssertionError(
964
'Extensions "%s" and "%s" both try to extend message type "%s" with'
966
(extension.full_name, actual_handle.full_name,
967
cls.__name__, extension.number))
535
969
class ProtocolBufferDecodeError(Exception):
    """Exception type for protocol buffer decoding problems.

    Adds nothing to Exception; the raise sites are outside this excerpt.
    """
536
970
class ProtocolBufferEncodeError(Exception):
    """Exception type for protocol buffer encoding problems.

    Adds nothing to Exception; the raise sites are outside this excerpt.
    """