1
# Protocol Buffers - Google's data interchange format
2
# Copyright 2008 Google Inc. All rights reserved.
3
# http://code.google.com/p/protobuf/
5
# Redistribution and use in source and binary forms, with or without
6
# modification, are permitted provided that the following conditions are
# met:
9
# * Redistributions of source code must retain the above copyright
10
# notice, this list of conditions and the following disclaimer.
11
# * Redistributions in binary form must reproduce the above
12
# copyright notice, this list of conditions and the following disclaimer
13
# in the documentation and/or other materials provided with the
# distribution.
15
# * Neither the name of Google Inc. nor the names of its
16
# contributors may be used to endorse or promote products derived from
17
# this software without specific prior written permission.
19
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
# This code is meant to work on Python 2.4 and above only.
33
# TODO(robinson): Helpers for verbose, common checks like seeing if a
34
# descriptor's cpp_type is CPPTYPE_MESSAGE.
36
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""
51
__author__ = 'robinson@google.com (Will Robinson)'
54
try:
  from cStringIO import StringIO
except ImportError:
  from StringIO import StringIO
import struct
import weakref
60
# We use "as" to avoid name collisions with variables.
61
from google.protobuf.internal import containers
62
from google.protobuf.internal import decoder
63
from google.protobuf.internal import encoder
64
from google.protobuf.internal import message_listener as message_listener_mod
65
from google.protobuf.internal import type_checkers
66
from google.protobuf.internal import wire_format
67
from google.protobuf import descriptor as descriptor_mod
68
from google.protobuf import message as message_mod
69
from google.protobuf import text_format
71
_FieldDescriptor = descriptor_mod.FieldDescriptor
74
def NewMessage(descriptor, dictionary):
  """Prepares the class dictionary of a message class before it is created.

  Args:
    descriptor: Descriptor describing the message type.
    dictionary: Class dictionary that will be used to build the class.  Nested
      extensions and a '__slots__' entry are added to it in place.
  """
  _AddClassAttributesForNestedExtensions(descriptor, dictionary)
  _AddSlots(descriptor, dictionary)
79
def InitMessage(descriptor, cls):
  """Installs the generated-message machinery on an already-created class.

  Creates the per-class decoder and extension registries, attaches
  encoding/decoding helpers to every field descriptor, and then adds enum
  constants, __init__, field properties, static methods, message methods and
  private helpers to cls.

  Args:
    descriptor: Descriptor describing the message type.
    cls: The message class being initialized.
  """
  cls._decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = {}
  if (descriptor.has_options and
      descriptor.GetOptions().message_set_wire_format):
    # The MessageSet item decoder receives the (still empty) extensions-by-
    # number dict itself; RegisterExtension() fills it in later.
    cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
        decoder.MessageSetItemDecoder(cls._extensions_by_number))

  # Attach stuff to each FieldDescriptor for quick lookup later on.
  for field in descriptor.fields:
    _AttachFieldHelpers(cls, field)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(cls)
101
# Stateless helpers for GeneratedProtocolMessageType below.
102
# Outside clients should not access these directly.
104
# I opted not to make any of these methods on the metaclass, to make it more
105
# clear that I'm not really using any state there and to keep clients from
106
# thinking that they have direct access to these construction helpers.
109
def _PropertyName(proto_field_name):
110
"""Returns the name of the public property attribute which
111
clients can use to get and (in some cases) set the value
112
of a protocol message field.
115
proto_field_name: The protocol message field name, exactly
116
as it appears (or would appear) in a .proto file.
118
# TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
119
# nnorwitz makes my day by writing:
121
# FYI. See the keyword module in the stdlib. This could be as simple as:
123
# if keyword.iskeyword(proto_field_name):
124
# return proto_field_name + "_"
125
# return proto_field_name
127
# Kenton says: The above is a BAD IDEA. People rely on being able to use
128
# getattr() and setattr() to reflectively manipulate field values. If we
129
# rename the properties, then every such user has to also make sure to apply
130
# the same transformation. Note that currently if you name a field "yield",
131
# you can still access it just fine using getattr/setattr -- it's not even
132
# that cumbersome to do so.
133
# TODO(kenton): Remove this method entirely if/when everyone agrees with my
135
return proto_field_name
138
def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid.

  Args:
    message: The message instance the handle is being used with.
    extension_handle: The object the caller passed as an extension handle.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, or extends a message type other than message's.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    # Wrap in a 1-tuple so that a tuple handle is reported instead of breaking
    # the % formatting with a TypeError.
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   (extension_handle,))

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
156
def _AddSlots(message_descriptor, dictionary):
157
"""Adds a __slots__ entry to dictionary, containing the names of all valid
158
attributes for this message type.
161
message_descriptor: A Descriptor instance describing this message type.
162
dictionary: Class dictionary to which we'll add a '__slots__' entry.
164
dictionary['__slots__'] = ['_cached_byte_size',
165
'_cached_byte_size_dirty',
167
'_is_present_in_parent',
169
'_listener_for_children',
173
def _IsMessageSetExtension(field):
174
return (field.is_extension and
175
field.containing_type.has_options and
176
field.containing_type.GetOptions().message_set_wire_format and
177
field.type == _FieldDescriptor.TYPE_MESSAGE and
178
field.message_type == field.extension_scope and
179
field.label == _FieldDescriptor.LABEL_OPTIONAL)
182
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoder, sizer, default constructor and decoders for a field.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor, and
  registers the field's decoder(s) in cls._decoders_by_tag.

  Args:
    cls: The message class being initialized.
    field_descriptor: FieldDescriptor of a field or extension of cls.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    # Registers a decoder keyed on the encoded tag bytes for this field number
    # and the given wire type.
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    cls._decoders_by_tag[tag_bytes] = (
        type_checkers.TYPE_TO_DECODER[field_descriptor.type](
            field_descriptor.number, is_repeated, is_packed,
            field_descriptor, field_descriptor._default_constructor))

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
217
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
218
extension_dict = descriptor.extensions_by_name
219
for extension_name, extension_field in extension_dict.iteritems():
220
assert extension_name not in dictionary
221
dictionary[extension_name] = extension_field
224
def _AddEnumValues(descriptor, cls):
225
"""Sets class-level attributes for all enum fields defined in this message.
228
descriptor: Descriptor object for this message type.
229
cls: Class we're constructing for this message type.
231
for enum_type in descriptor.enum_types:
232
for enum_value in enum_type.values:
233
setattr(cls, enum_value.name, enum_value.number)
236
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    return field.default_value
  return MakeScalarDefault
283
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts keyword arguments named after the message's
  fields.  Repeated and message-typed values are copied into freshly
  constructed containers/sub-messages.  Scalar values are assigned through the
  field's property (setattr).

  Args:
    message_descriptor: Descriptor describing the message type.
    cls: The class we're constructing.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Every keyword argument sets a field, so the cached size starts out
    # stale whenever any were given.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.iteritems():
      # NOTE: _GetFieldByName raises ValueError for unknown names; the None
      # check below is kept as a defensive guard.
      field = _GetFieldByName(message_descriptor, field_name)
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          for val in field_value:
            copy.add().MergeFrom(val)
        else:  # Scalar
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        copy.MergeFrom(field_value)
        self._fields[field] = copy
      else:
        setattr(self, field_name, field_value)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
318
def _GetFieldByName(message_descriptor, field_name):
319
"""Returns a field descriptor by field name.
322
message_descriptor: A Descriptor describing all fields in message.
323
field_name: The name of the field to retrieve.
325
The field descriptor associated with the field name.
328
return message_descriptor.fields_by_name[field_name]
330
raise ValueError('Protocol message has no "%s" field.' % field_name)
333
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  For extendable messages, also adds an 'Extensions' property.
  """
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))
344
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also adds a <FIELD_NAME>_FIELD_NUMBER class constant.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + "_FIELD_NUMBER"
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)
369
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (built by
  the field's default constructor).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
413
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Reading an unset field returns the field's default value.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
  default_value = field.default_value

  def getter(self):
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    type_checker.CheckValue(new_value)
    self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
450
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly (doing so raises AttributeError).  The
  sub-message is created lazily on first access.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  message_type = field.message_type

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = message_type._concrete_class()
      field_value._SetListener(self._listener_for_children)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
496
def _AddPropertiesForExtensions(descriptor, cls):
497
"""Adds properties for all fields in this protocol message type."""
498
extension_dict = descriptor.extensions_by_name
499
for extension_name, extension_field in extension_dict.iteritems():
500
constant_name = extension_name.upper() + "_FIELD_NUMBER"
501
setattr(cls, constant_name, extension_field.number)
504
def _AddStaticMethods(cls):
505
# TODO(robinson): This probably needs to be thread-safe(?)
506
def RegisterExtension(extension_handle):
507
extension_handle.containing_type = cls.DESCRIPTOR
508
_AttachFieldHelpers(cls, extension_handle)
510
# Try to insert our extension, failing if an extension with the same number
512
actual_handle = cls._extensions_by_number.setdefault(
513
extension_handle.number, extension_handle)
514
if actual_handle is not extension_handle:
515
raise AssertionError(
516
'Extensions "%s" and "%s" both try to extend message type "%s" with '
518
(extension_handle.full_name, actual_handle.full_name,
519
cls.DESCRIPTOR.full_name, extension_handle.number))
521
cls._extensions_by_name[extension_handle.full_name] = extension_handle
523
handle = extension_handle # avoid line wrapping
524
if _IsMessageSetExtension(handle):
525
# MessageSet extension. Also register under type name.
526
cls._extensions_by_name[
527
extension_handle.message_type.full_name] = extension_handle
529
cls.RegisterExtension = staticmethod(RegisterExtension)
533
message.MergeFromString(s)
535
cls.FromString = staticmethod(FromString)
538
def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields().

  Repeated fields count when non-empty, sub-messages when they are marked
  present in their parent, and any stored scalar always counts.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  else:
    return True
550
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Adds ListFields(), which returns the present (FieldDescriptor, value) pairs
  sorted by field number.
  """

  def ListFields(self):
    all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields
561
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Adds HasField(field_name), valid only for non-repeated fields; any other
  name raises ValueError.
  """

  singular_fields = {}
  for field in message_descriptor.fields:
    if field.label != _FieldDescriptor.LABEL_REPEATED:
      singular_fields[field.name] = field

  def HasField(self, field_name):
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # A lazily created sub-message only counts once it is marked present.
      value = self._fields.get(field)
      return value is not None and value._is_present_in_parent
    else:
      return field in self._fields
  cls.HasField = HasField
584
def _AddClearFieldMethod(message_descriptor, cls):
585
"""Helper for _AddMessageMethods()."""
586
def ClearField(self, field_name):
588
field = message_descriptor.fields_by_name[field_name]
590
raise ValueError('Protocol message has no "%s" field.' % field_name)
592
if field in self._fields:
593
# Note: If the field is a sub-message, its listener will still point
594
# at us. That's fine, because the worst than can happen is that it
595
# will call _Modified() and invalidate our byte size. Big deal.
596
del self._fields[field]
598
# Always call _Modified() -- even if nothing was changed, this is
599
# a mutating method, and thus calling it should cause the field to become
600
# present in the parent message.
603
cls.ClearField = ClearField
606
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  Adds ClearExtension(extension_handle), which validates the handle, removes
  any stored value and marks the message modified.
  """
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension
618
def _AddClearMethod(message_descriptor, cls):
619
"""Helper for _AddMessageMethods()."""
627
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  Adds HasExtension(extension_handle).  Raises KeyError for invalid or
  repeated extension handles.
  """
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension
642
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Adds __eq__: messages are equal when they share a DESCRIPTOR and their
  ListFields() results compare equal.
  """
  def __eq__(self, other):
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    if self is other:
      return True

    return self.ListFields() == other.ListFields()

  cls.__eq__ = __eq__
657
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Adds __str__, which renders the message in text format.
  """
  def __str__(self):
    return text_format.MessageToString(self)
  cls.__str__ = __str__
664
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Adds __unicode__, which renders the message in text format as UTF-8 and
  decodes it.
  """

  def __unicode__(self):
    return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
  cls.__unicode__ = __unicode__
672
def _AddSetListenerMethod(cls):
673
"""Helper for _AddMessageMethods()."""
674
def SetListener(self, listener):
676
self._listener = message_listener_mod.NullMessageListener()
678
self._listener = listener
679
cls._SetListener = SetListener
682
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    EncodeError: If field_type has no registered byte-size function.
  """
  # Only the table lookup is guarded, so a KeyError raised inside the size
  # function itself is not misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
702
def _AddByteSizeMethod(message_descriptor, cls):
703
"""Helper for _AddMessageMethods()."""
706
if not self._cached_byte_size_dirty:
707
return self._cached_byte_size
710
for field_descriptor, field_value in self.ListFields():
711
size += field_descriptor._sizer(field_value)
713
self._cached_byte_size = size
714
self._cached_byte_size_dirty = False
715
self._listener_for_children.dirty = False
718
cls.ByteSize = ByteSize
721
def _AddSerializeToStringMethod(message_descriptor, cls):
722
"""Helper for _AddMessageMethods()."""
724
def SerializeToString(self):
725
# Check if the message has all of its required fields set.
727
if not self.IsInitialized():
728
raise message_mod.EncodeError(
729
'Message is missing required fields: ' +
730
','.join(self.FindInitializationErrors()))
731
return self.SerializePartialToString()
732
cls.SerializeToString = SerializeToString
735
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Adds SerializePartialToString(), which serializes without checking required
  fields, and _InternalSerialize(write_bytes), which feeds each present field
  to its _encoder.
  """

  def SerializePartialToString(self):
    out = StringIO()
    self._InternalSerialize(out.write)
    return out.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes):
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value)
  cls._InternalSerialize = InternalSerialize
750
def _AddMergeFromStringMethod(message_descriptor, cls):
751
"""Helper for _AddMessageMethods()."""
752
def MergeFromString(self, serialized):
753
length = len(serialized)
755
if self._InternalParse(serialized, 0, length) != length:
756
# The only reason _InternalParse would return early is if it
757
# encountered an end-group tag.
758
raise message_mod.DecodeError('Unexpected end-group tag.')
760
raise message_mod.DecodeError('Truncated message.')
761
except struct.error, e:
762
raise message_mod.DecodeError(e)
763
return length # Return this for legacy reasons.
764
cls.MergeFromString = MergeFromString
766
local_ReadTag = decoder.ReadTag
767
local_SkipField = decoder.SkipField
768
decoders_by_tag = cls._decoders_by_tag
770
def InternalParse(self, buffer, pos, end):
772
field_dict = self._fields
774
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
775
field_decoder = decoders_by_tag.get(tag_bytes)
776
if field_decoder is None:
777
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
782
pos = field_decoder(buffer, new_pos, end, self, field_dict)
784
cls._InternalParse = InternalParse
787
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                           if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    for field, value in self._fields.iteritems():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for i, element in enumerate(value):
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
869
def _AddMergeFromMethod(cls):
  """Adds MergeFrom(msg), which merges another instance of cls into self.

  Repeated fields are appended, present sub-messages are merged recursively,
  and scalar values overwrite ours.  Raises TypeError for a msg of a
  different class.
  """
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class.")

    # Merging a message into itself would iterate and mutate the same
    # containers at once.
    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.iteritems():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Lazily created but never-set sub-messages are skipped.
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new object to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        fields[field] = value
  cls.MergeFrom = MergeFrom
904
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  Extension-related methods (ClearExtension, HasExtension) are only added
  when the message type is extendable.
  """
  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(message_descriptor, cls)
  _AddClearFieldMethod(message_descriptor, cls)
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  _AddClearMethod(message_descriptor, cls)
  _AddEqualsMethod(message_descriptor, cls)
  _AddStrMethod(message_descriptor, cls)
  _AddUnicodeMethod(message_descriptor, cls)
  _AddSetListenerMethod(cls)
  _AddByteSizeMethod(message_descriptor, cls)
  _AddSerializeToStringMethod(message_descriptor, cls)
  _AddSerializePartialToStringMethod(message_descriptor, cls)
  _AddMergeFromStringMethod(message_descriptor, cls)
  _AddIsInitializedMethod(message_descriptor, cls)
  _AddMergeFromMethod(cls)
925
def _AddPrivateHelperMethods(cls):
926
"""Adds implementation of private helper methods to cls."""
929
"""Sets the _cached_byte_size_dirty bit to true,
930
and propagates this to our listener iff this was a state change.
933
# Note: Some callers check _cached_byte_size_dirty before calling
934
# _Modified() as an extra optimization. So, if this method is ever
935
# changed such that it does stuff even when _cached_byte_size_dirty is
936
# already true, the callers need to be updated.
937
if not self._cached_byte_size_dirty:
938
self._cached_byte_size_dirty = True
939
self._listener_for_children.dirty = True
940
self._is_present_in_parent = True
941
self._listener.Modified()
943
cls._Modified = Modified
944
cls.SetInParent = Modified
947
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.  If it is already a weakref proxy it
        is stored as-is; otherwise a weak proxy to it is stored.
    """
    # This listener establishes a back reference from a child (contained) object
    # to its parent (containing) object.  We make this a weak reference to avoid
    # creating cyclic garbage when the client finishes with the 'parent' object
    # in the tree.
    # weakref.ProxyTypes covers both ProxyType and CallableProxyType.  Proxies
    # cannot themselves be weakly referenced, so re-wrapping either kind with
    # weakref.proxy() would raise TypeError.
    if isinstance(parent_message, weakref.ProxyTypes):
      self._parent_message_weakref = parent_message
    else:
      self._parent_message_weakref = weakref.proxy(parent_message)

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Forwards a modification to the parent unless it is already dirty."""
    if self.dirty:
      return
    try:
      # Propagate the signal to our parents iff this is the first field set.
      self._parent_message_weakref._Modified()
    except ReferenceError:
      # We can get here if a client has kept a reference to a child object,
      # and is now setting a field on it, but the child's parent has been
      # garbage-collected.  This is not an error.
      pass
993
# TODO(robinson): Move elsewhere?  This file is getting very large.
# TODO(robinson): Unify error handling for unknown extensions.
# TODO(robinson): Support iteritems()-style iteration over all
#   extensions with the "has" bits turned on?
997
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be field
  descriptors (objects exposing label, cpp_type, full_name, etc.).
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """
    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    Repeated and message-typed extensions are created on first access and
    stored in the extended message's _fields.  Singular scalar extensions
    that were never set return their default without being stored.
    """
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        # NOTE(review): presumably raised when the listener's weak reference
        # to its parent is dead; the new sub-message is still returned.
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Compares the set extension fields (and their values) of two dicts."""
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Get rid of non-extension fields.  ListFields() yields
    # (field_descriptor, value) pairs, so the extension flag lives on the
    # pair's first element; reading it off the pair itself raised
    # AttributeError whenever any field was set.
    my_fields = [item for item in my_fields if item[0].is_extension]
    other_fields = [item for item in other_fields if item[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: if the extension is repeated or message-typed, or (via the
        type checker) if value has the wrong type.
    """
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle.cpp_type, extension_handle.type)
    type_checker.CheckValue(value)
    self._extended_message._fields[extension_handle] = value
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if no extension has that name.
    """
    return self._extended_message._extensions_by_name.get(name, None)