~nchohan/appscale/zk3.3.4

« back to all changes in this revision

Viewing changes to AppServer/google/net/proto/ProtocolBuffer.py

  • Committer: Chris Bunch
  • Date: 2012-02-17 08:19:21 UTC
  • mfrom: (787.2.3 appscale-raj-merge)
  • Revision ID: cgb@cs.ucsb.edu-20120217081921-pakidyksaenlpzur
merged with main branch, gaining rabbitmq and upgrades for hbase, cassandra, and hypertable, as well as upgrading to gae 1.6.1 for python and go

Show diffs side-by-side

added added

removed removed

Lines of Context:
16
16
#
17
17
 
18
18
 
 
19
 
 
20
 
19
21
import struct
20
22
import array
21
23
import string
23
25
from google.pyglib.gexcept import AbstractMethod
24
26
import httplib
25
27
from string import strip, split
 
28
 
26
29
__all__ = ['ProtocolMessage', 'Encoder', 'Decoder',
 
30
           'ExtendableProtocolMessage',
27
31
           'ProtocolBufferDecodeError',
28
32
           'ProtocolBufferEncodeError',
29
33
           'ProtocolBufferReturnError']
33
37
class ProtocolMessage:
34
38
 
35
39
 
 
40
 
 
41
 
 
42
 
36
43
  def __init__(self, contents=None):
37
44
    raise AbstractMethod
38
45
 
50
57
      self.Output(e)
51
58
      return e.buffer().tostring()
52
59
 
 
60
  def SerializeToString(self):
 
61
    return self.Encode()
 
62
 
 
63
  def SerializePartialToString(self):
 
64
    try:
 
65
      return self._CEncodePartial()
 
66
    except (AbstractMethod, AttributeError):
 
67
      e = Encoder()
 
68
      self.OutputPartial(e)
 
69
      return e.buffer().tostring()
 
70
 
53
71
  def _CEncode(self):
54
72
    raise AbstractMethod
55
73
 
 
74
  def _CEncodePartial(self):
 
75
    raise AbstractMethod
 
76
 
56
77
  def ParseFromString(self, s):
57
78
    self.Clear()
58
79
    self.MergeFromString(s)
59
 
    return
 
80
 
 
81
  def ParsePartialFromString(self, s):
 
82
    self.Clear()
 
83
    self.MergePartialFromString(s)
60
84
 
61
85
  def MergeFromString(self, s):
 
86
    self.MergePartialFromString(s)
 
87
    dbg = []
 
88
    if not self.IsInitialized(dbg):
 
89
      raise ProtocolBufferDecodeError, '\n\t'.join(dbg)
 
90
 
 
91
  def MergePartialFromString(self, s):
62
92
    try:
63
93
      self._CMergeFromString(s)
64
 
      dbg = []
65
 
      if not self.IsInitialized(dbg):
66
 
        raise ProtocolBufferDecodeError, '\n\t'.join(dbg)
67
94
    except AbstractMethod:
 
95
 
 
96
 
68
97
      a = array.array('B')
69
98
      a.fromstring(s)
70
99
      d = Decoder(a, 0, len(a))
71
 
      self.Merge(d)
72
 
      return
 
100
      self.TryMerge(d)
73
101
 
74
102
  def _CMergeFromString(self, s):
75
103
    raise AbstractMethod
134
162
  def ToShortASCII(self):
135
163
    return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII)
136
164
 
 
165
 
 
166
 
137
167
  _NUMERIC_ASCII = 0
138
168
  _SYMBOLIC_SHORT_ASCII = 1
139
169
  _SYMBOLIC_FULL_ASCII = 2
151
181
    raise AbstractMethod
152
182
 
153
183
  def __eq__(self, other):
 
184
 
 
185
 
 
186
 
 
187
 
 
188
 
154
189
    if other.__class__ is self.__class__:
155
190
      return self.Equals(other)
156
191
    return NotImplemented
157
192
 
158
193
  def __ne__(self, other):
 
194
 
 
195
 
 
196
 
 
197
 
 
198
 
159
199
    if other.__class__ is self.__class__:
160
200
      return not self.Equals(other)
161
201
    return NotImplemented
162
202
 
163
203
 
 
204
 
 
205
 
 
206
 
164
207
  def Output(self, e):
165
208
    dbg = []
166
209
    if not self.IsInitialized(dbg):
171
214
  def OutputUnchecked(self, e):
172
215
    raise AbstractMethod
173
216
 
 
217
  def OutputPartial(self, e):
 
218
    raise AbstractMethod
 
219
 
174
220
  def Parse(self, d):
175
221
    self.Clear()
176
222
    self.Merge(d)
195
241
    raise AbstractMethod
196
242
 
197
243
 
 
244
 
 
245
 
 
246
 
198
247
  def lengthVarInt32(self, n):
199
248
    return self.lengthVarInt64(n)
200
249
 
223
272
      return self.DebugFormatFixed64(value)
224
273
    return "%d" % value
225
274
  def DebugFormatString(self, value):
 
275
 
 
276
 
 
277
 
226
278
    def escape(c):
227
279
      o = ord(c)
228
280
      if o == 10: return r"\n"
248
300
    else:
249
301
      return "false"
250
302
 
 
303
 
 
304
TYPE_DOUBLE  = 1
 
305
TYPE_FLOAT   = 2
 
306
TYPE_INT64   = 3
 
307
TYPE_UINT64  = 4
 
308
TYPE_INT32   = 5
 
309
TYPE_FIXED64 = 6
 
310
TYPE_FIXED32 = 7
 
311
TYPE_BOOL    = 8
 
312
TYPE_STRING  = 9
 
313
TYPE_GROUP   = 10
 
314
TYPE_FOREIGN = 11
 
315
 
 
316
 
 
317
_TYPE_TO_DEBUG_STRING = {
 
318
    TYPE_INT32:   ProtocolMessage.DebugFormatInt32,
 
319
    TYPE_INT64:   ProtocolMessage.DebugFormatInt64,
 
320
    TYPE_UINT64:  ProtocolMessage.DebugFormatInt64,
 
321
    TYPE_FLOAT:   ProtocolMessage.DebugFormatFloat,
 
322
    TYPE_STRING:  ProtocolMessage.DebugFormatString,
 
323
    TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32,
 
324
    TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64,
 
325
    TYPE_BOOL:    ProtocolMessage.DebugFormatBool }
 
326
 
 
327
 
 
328
 
251
329
class Encoder:
252
330
 
 
331
 
253
332
  NUMERIC     = 0
254
333
  DOUBLE      = 1
255
334
  STRING      = 2
298
377
 
299
378
  def putVarInt32(self, v):
300
379
 
 
380
 
 
381
 
 
382
 
 
383
 
 
384
 
 
385
 
 
386
 
301
387
    buf_append = self.buf.append
302
388
    if v & 127 == v:
303
389
      buf_append(v)
347
433
    return
348
434
 
349
435
 
 
436
 
 
437
 
 
438
 
 
439
 
350
440
  def putFloat(self, v):
351
441
    a = array.array('B')
352
442
    a.fromstring(struct.pack("<f", v))
367
457
    return
368
458
 
369
459
  def putPrefixedString(self, v):
 
460
 
 
461
 
 
462
 
370
463
    v = str(v)
371
464
    self.putVarInt32(len(v))
372
465
    self.buf.fromstring(v)
375
468
  def putRawString(self, v):
376
469
    self.buf.fromstring(v)
377
470
 
 
471
  _TYPE_TO_METHOD = {
 
472
      TYPE_DOUBLE:   putDouble,
 
473
      TYPE_FLOAT:    putFloat,
 
474
      TYPE_FIXED64:  put64,
 
475
      TYPE_FIXED32:  put32,
 
476
      TYPE_INT32:    putVarInt32,
 
477
      TYPE_INT64:    putVarInt64,
 
478
      TYPE_UINT64:   putVarUint64,
 
479
      TYPE_BOOL:     putBoolean,
 
480
      TYPE_STRING:   putPrefixedString }
 
481
 
 
482
  _TYPE_TO_BYTE_SIZE = {
 
483
      TYPE_DOUBLE:  8,
 
484
      TYPE_FLOAT:   4,
 
485
      TYPE_FIXED64: 8,
 
486
      TYPE_FIXED32: 4,
 
487
      TYPE_BOOL:    1 }
378
488
 
379
489
class Decoder:
380
490
  def __init__(self, buf, idx, limit):
422
532
    else:
423
533
      raise ProtocolBufferDecodeError, "corrupted"
424
534
 
 
535
 
425
536
  def get8(self):
426
537
    if self.idx >= self.limit: raise ProtocolBufferDecodeError, "truncated"
427
538
    c = self.buf[self.idx]
459
570
            | (e << 16) | (d << 8) | c)
460
571
 
461
572
  def getVarInt32(self):
 
573
 
 
574
 
 
575
 
462
576
    b = self.get8()
463
577
    if not (b & 128):
464
578
      return b
531
645
    self.idx = self.limit
532
646
    return r.tostring()
533
647
 
 
648
  _TYPE_TO_METHOD = {
 
649
      TYPE_DOUBLE:   getDouble,
 
650
      TYPE_FLOAT:    getFloat,
 
651
      TYPE_FIXED64:  get64,
 
652
      TYPE_FIXED32:  get32,
 
653
      TYPE_INT32:    getVarInt32,
 
654
      TYPE_INT64:    getVarInt64,
 
655
      TYPE_UINT64:   getVarUint64,
 
656
      TYPE_BOOL:     getBoolean,
 
657
      TYPE_STRING:   getPrefixedString }
 
658
 
 
659
 
 
660
 
 
661
 
 
662
 
 
663
class ExtensionIdentifier(object):
 
664
  __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated',
 
665
               'default', 'containing_cls', 'composite_cls', 'message_name')
 
666
  def __init__(self, full_name, number, field_type, wire_tag, is_repeated,
 
667
               default):
 
668
    self.full_name = full_name
 
669
    self.number = number
 
670
    self.field_type = field_type
 
671
    self.wire_tag = wire_tag
 
672
    self.is_repeated = is_repeated
 
673
    self.default = default
 
674
 
 
675
class ExtendableProtocolMessage(ProtocolMessage):
 
676
  def HasExtension(self, extension):
 
677
    self._VerifyExtensionIdentifier(extension)
 
678
    return extension in self._extension_fields
 
679
 
 
680
  def ClearExtension(self, extension):
 
681
    self._VerifyExtensionIdentifier(extension)
 
682
    if extension in self._extension_fields:
 
683
      del self._extension_fields[extension]
 
684
 
 
685
  def GetExtension(self, extension, index=None):
 
686
    self._VerifyExtensionIdentifier(extension)
 
687
    if extension in self._extension_fields:
 
688
      result = self._extension_fields[extension]
 
689
    else:
 
690
      if extension.is_repeated:
 
691
        result = []
 
692
      elif extension.composite_cls:
 
693
        result = extension.composite_cls()
 
694
      else:
 
695
        result = extension.default
 
696
    if extension.is_repeated:
 
697
      result = result[index]
 
698
    return result
 
699
 
 
700
  def SetExtension(self, extension, *args):
 
701
    self._VerifyExtensionIdentifier(extension)
 
702
    if extension.composite_cls:
 
703
      raise TypeError(
 
704
          'Cannot assign to extension "%s" because it is a composite type.' %
 
705
          extension.full_name)
 
706
    if extension.is_repeated:
 
707
      if (len(args) != 2):
 
708
        raise TypeError(
 
709
            'SetExtension(extension, index, value) for repeated extension '
 
710
            'takes exactly 3 arguments: (%d given)' % len(args))
 
711
      index = args[0]
 
712
      value = args[1]
 
713
      self._extension_fields[extension][index] = value
 
714
    else:
 
715
      if (len(args) != 1):
 
716
        raise TypeError(
 
717
            'SetExtension(extension, value) for singular extension '
 
718
            'takes exactly 3 arguments: (%d given)' % len(args))
 
719
      value = args[0]
 
720
      self._extension_fields[extension] = value
 
721
 
 
722
  def MutableExtension(self, extension, index=None):
 
723
    self._VerifyExtensionIdentifier(extension)
 
724
    if extension.composite_cls is None:
 
725
      raise TypeError(
 
726
          'MutableExtension() cannot be applied to "%s", because it is not a '
 
727
          'composite type.' % extension.full_name)
 
728
    if extension.is_repeated:
 
729
      if index is None:
 
730
        raise TypeError(
 
731
            'MutableExtension(extension, index) for repeated extension '
 
732
            'takes exactly 2 arguments: (1 given)')
 
733
      return self.GetExtension(extension, index)
 
734
    if extension in self._extension_fields:
 
735
      return self._extension_fields[extension]
 
736
    else:
 
737
      result = extension.composite_cls()
 
738
      self._extension_fields[extension] = result
 
739
      return result
 
740
 
 
741
  def ExtensionList(self, extension):
 
742
    self._VerifyExtensionIdentifier(extension)
 
743
    if not extension.is_repeated:
 
744
      raise TypeError(
 
745
          'ExtensionList() cannot be applied to "%s", because it is not a '
 
746
          'repeated extension.' % extension.full_name)
 
747
    if extension in self._extension_fields:
 
748
      return self._extension_fields[extension]
 
749
    result = []
 
750
    self._extension_fields[extension] = result
 
751
    return result
 
752
 
 
753
  def ExtensionSize(self, extension):
 
754
    self._VerifyExtensionIdentifier(extension)
 
755
    if not extension.is_repeated:
 
756
      raise TypeError(
 
757
          'ExtensionSize() cannot be applied to "%s", because it is not a '
 
758
          'repeated extension.' % extension.full_name)
 
759
    if extension in self._extension_fields:
 
760
      return len(self._extension_fields[extension])
 
761
    return 0
 
762
 
 
763
  def AddExtension(self, extension, value=None):
 
764
    self._VerifyExtensionIdentifier(extension)
 
765
    if not extension.is_repeated:
 
766
      raise TypeError(
 
767
          'AddExtension() cannot be applied to "%s", because it is not a '
 
768
          'repeated extension.' % extension.full_name)
 
769
    if extension in self._extension_fields:
 
770
      field = self._extension_fields[extension]
 
771
    else:
 
772
      field = []
 
773
      self._extension_fields[extension] = field
 
774
 
 
775
    if extension.composite_cls:
 
776
      if value is not None:
 
777
        raise TypeError(
 
778
            'value must not be set in AddExtension() for "%s", because it is '
 
779
            'a message type extension. Set values on the returned message '
 
780
            'instead.' % extension.full_name)
 
781
      msg = extension.composite_cls()
 
782
      field.append(msg)
 
783
      return msg
 
784
 
 
785
    field.append(value)
 
786
 
 
787
  def _VerifyExtensionIdentifier(self, extension):
 
788
    if extension.containing_cls != self.__class__:
 
789
      raise TypeError("Containing type of %s is %s, but not %s."
 
790
                      % (extension.full_name,
 
791
                         extension.containing_cls.__name__,
 
792
                         self.__class__.__name__))
 
793
 
 
794
  def _MergeExtensionFields(self, x):
 
795
    for ext, val in x._extension_fields.items():
 
796
      if ext.is_repeated:
 
797
        for i in xrange(len(val)):
 
798
          if ext.composite_cls is None:
 
799
            self.AddExtension(ext, val[i])
 
800
          else:
 
801
            self.AddExtension(ext).MergeFrom(val[i])
 
802
      else:
 
803
        if ext.composite_cls is None:
 
804
          self.SetExtension(ext, val)
 
805
        else:
 
806
          self.MutableExtension(ext).MergeFrom(val)
 
807
 
 
808
  def _ListExtensions(self):
 
809
    result = [ext for ext in self._extension_fields.keys()
 
810
              if (not ext.is_repeated) or self.ExtensionSize(ext) > 0]
 
811
    result.sort(key = lambda item: item.number)
 
812
    return result
 
813
 
 
814
  def _ExtensionEquals(self, x):
 
815
    extensions = self._ListExtensions()
 
816
    if extensions != x._ListExtensions():
 
817
      return False
 
818
    for ext in extensions:
 
819
      if ext.is_repeated:
 
820
        if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False
 
821
        for e1, e2 in zip(self.ExtensionList(ext),
 
822
                          x.ExtensionList(ext)):
 
823
          if e1 != e2: return False
 
824
      else:
 
825
        if self.GetExtension(ext) != x.GetExtension(ext): return False
 
826
    return True
 
827
 
 
828
  def _OutputExtensionFields(self, out, partial, extensions, start_index,
 
829
                             end_field_number):
 
830
    def OutputSingleField(ext, value):
 
831
      out.putVarInt32(ext.wire_tag)
 
832
      if ext.field_type == TYPE_GROUP:
 
833
        if partial:
 
834
          value.OutputPartial(out)
 
835
        else:
 
836
          value.OutputUnchecked(out)
 
837
        out.putVarInt32(wire_tag + 1)
 
838
      elif ext.field_type == TYPE_FOREIGN:
 
839
        if partial:
 
840
          out.putVarInt32(value.ByteSizePartial())
 
841
          value.OutputPartial(out)
 
842
        else:
 
843
          out.putVarInt32(value.ByteSize())
 
844
          value.OutputUnchecked(out)
 
845
      else:
 
846
        Encoder._TYPE_TO_METHOD[ext.field_type](out, value)
 
847
 
 
848
    size = len(extensions)
 
849
    for ext_index in xrange(start_index, size):
 
850
      ext = extensions[ext_index]
 
851
      if ext.number >= end_field_number:
 
852
 
 
853
        return ext_index
 
854
      if ext.is_repeated:
 
855
        for i in xrange(len(self._extension_fields[ext])):
 
856
          OutputSingleField(ext, self._extension_fields[ext][i])
 
857
      else:
 
858
        OutputSingleField(ext, self._extension_fields[ext])
 
859
    return size
 
860
 
 
861
  def _ParseOneExtensionField(self, wire_tag, d):
 
862
    number = wire_tag >> 3
 
863
    if number in self._extensions_by_field_number:
 
864
      ext = self._extensions_by_field_number[number]
 
865
      if wire_tag != ext.wire_tag:
 
866
 
 
867
        return
 
868
      if ext.field_type == TYPE_FOREIGN:
 
869
        length = d.getVarInt32()
 
870
        tmp = Decoder(d.buffer(), d.pos(), d.pos() + length)
 
871
        if ext.is_repeated:
 
872
          self.AddExtension(ext).TryMerge(tmp)
 
873
        else:
 
874
          self.MutableExtension(ext).TryMerge(tmp)
 
875
        d.skip(length)
 
876
      elif ext.field_type == TYPE_GROUP:
 
877
        if ext.is_repeated:
 
878
          self.AddExtension(ext).TryMerge(d)
 
879
        else:
 
880
          self.MutableExtension(ext).TryMerge(d)
 
881
      else:
 
882
        value = Decoder._TYPE_TO_METHOD[ext.field_type](d)
 
883
        if ext.is_repeated:
 
884
          self.AddExtension(ext, value)
 
885
        else:
 
886
          self.SetExtension(ext, value)
 
887
    else:
 
888
 
 
889
      d.skipData(wire_tag)
 
890
 
 
891
  def _ExtensionByteSize(self, partial):
 
892
    size = 0
 
893
    for extension, value in self._extension_fields.items():
 
894
      ftype = extension.field_type
 
895
      tag_size = self.lengthVarInt64(extension.wire_tag)
 
896
      if ftype == TYPE_GROUP:
 
897
        tag_size *= 2
 
898
      if extension.is_repeated:
 
899
        size += tag_size * len(value)
 
900
        for single_value in value:
 
901
          size += self._FieldByteSize(ftype, single_value, partial)
 
902
      else:
 
903
        size += tag_size + self._FieldByteSize(ftype, value, partial)
 
904
    return size
 
905
 
 
906
  def _FieldByteSize(self, ftype, value, partial):
 
907
    size = 0
 
908
    if ftype == TYPE_STRING:
 
909
      size = self.lengthString(len(value))
 
910
    elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP:
 
911
      if partial:
 
912
        size = self.lengthString(value.ByteSizePartial())
 
913
      else:
 
914
        size = self.lengthString(value.ByteSize())
 
915
    elif ftype == TYPE_INT64 or ftype == TYPE_UINT64 or ftype == TYPE_INT32:
 
916
      size = self.lengthVarInt64(value)
 
917
    else:
 
918
      if ftype in Encoder._TYPE_TO_BYTE_SIZE:
 
919
        size = Encoder._TYPE_TO_BYTE_SIZE[ftype]
 
920
      else:
 
921
        raise AssertionError(
 
922
            'Extension type %d is not recognized.' % ftype)
 
923
    return size
 
924
 
 
925
  def _ExtensionDebugString(self, prefix, printElemNumber):
 
926
    res = ''
 
927
    extensions = self._ListExtensions()
 
928
    for extension in extensions:
 
929
      value = self._extension_fields[extension]
 
930
      if extension.is_repeated:
 
931
        cnt = 0
 
932
        for e in value:
 
933
          elm=""
 
934
          if printElemNumber: elm = "(%d)" % cnt
 
935
          if extension.composite_cls is not None:
 
936
            res += prefix + "[%s%s] {\n" % (extension.full_name, elm)
 
937
            res += e.__str__(prefix + "  ", printElemNumber)
 
938
            res += prefix + "}\n"
 
939
      else:
 
940
        if extension.composite_cls is not None:
 
941
          res += prefix + "[%s] {\n" % extension.full_name
 
942
          res += value.__str__(
 
943
              prefix + "  ", printElemNumber)
 
944
          res += prefix + "}\n"
 
945
        else:
 
946
          if extension.field_type in _TYPE_TO_DEBUG_STRING:
 
947
            text_value = _TYPE_TO_DEBUG_STRING[
 
948
                extension.field_type](self, value)
 
949
          else:
 
950
            text_value = self.DebugFormat(value)
 
951
          res += prefix + "[%s]: %s\n" % (extension.full_name, text_value)
 
952
    return res
 
953
 
 
954
  @staticmethod
 
955
  def _RegisterExtension(cls, extension, composite_cls=None):
 
956
    extension.containing_cls = cls
 
957
    extension.composite_cls = composite_cls
 
958
    if composite_cls is not None:
 
959
      extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME
 
960
    actual_handle = cls._extensions_by_field_number.setdefault(
 
961
        extension.number, extension)
 
962
    if actual_handle is not extension:
 
963
      raise AssertionError(
 
964
          'Extensions "%s" and "%s" both try to extend message type "%s" with'
 
965
          'field number %d.' %
 
966
          (extension.full_name, actual_handle.full_name,
 
967
           cls.__name__, extension.number))
534
968
 
535
969
class ProtocolBufferDecodeError(Exception): pass
536
970
class ProtocolBufferEncodeError(Exception): pass