145
139
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
147
cls._decoders_by_tag = {}
148
cls._extensions_by_name = {}
149
cls._extensions_by_number = {}
150
if (descriptor.has_options and
151
descriptor.GetOptions().message_set_wire_format):
152
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
153
decoder.MessageSetItemDecoder(cls._extensions_by_number))
155
# We act as a "friend" class of the descriptor, setting
156
# its _concrete_class attribute the first time we use a
157
# given descriptor to initialize a concrete protocol message
158
# class. We also attach stuff to each FieldDescriptor for quick
160
concrete_class_attr_name = '_concrete_class'
161
if not hasattr(descriptor, concrete_class_attr_name):
162
setattr(descriptor, concrete_class_attr_name, cls)
163
for field in descriptor.fields:
164
_AttachFieldHelpers(cls, field)
166
_AddEnumValues(descriptor, cls)
167
_AddInitMethod(descriptor, cls)
168
_AddPropertiesForFields(descriptor, cls)
169
_AddPropertiesForExtensions(descriptor, cls)
170
_AddStaticMethods(cls)
171
_AddMessageMethods(descriptor, cls)
172
_AddPrivateHelperMethods(cls)
140
_InitMessage(descriptor, cls)
173
141
superclass = super(GeneratedProtocolMessageType, cls)
174
142
superclass.__init__(name, bases, dictionary)
177
# Stateless helpers for GeneratedProtocolMessageType below.
178
# Outside clients should not access these directly.
180
# I opted not to make any of these methods on the metaclass, to make it more
181
# clear that I'm not really using any state there and to keep clients from
182
# thinking that they have direct access to these construction helpers.
185
def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  This is currently the identity mapping: the property is named exactly
  like the field in the .proto file (see the rationale below).

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The attribute name to use for the field's property, which today is
    |proto_field_name| unchanged.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  #
  # FYI. See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  #
  # Kenton says: The above is a BAD IDEA. People rely on being able to use
  # getattr() and setattr() to reflectively manipulate field values. If we
  # rename the properties, then every such user has to also make sure to apply
  # the same transformation. Note that currently if you name a field "yield",
  # you can still access it just fine using getattr/setattr -- it's not even
  # that cumbersome to do so.
  # TODO(kenton): Remove this method entirely if/when everyone agrees with my
  # position.
  return proto_field_name
214
def _VerifyExtensionHandle(message, extension_handle):
215
"""Verify that the given extension handle is valid."""
217
if not isinstance(extension_handle, _FieldDescriptor):
218
raise KeyError('HasExtension() expects an extension handle, got: %s' %
221
if not extension_handle.is_extension:
222
raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
224
if extension_handle.containing_type is not message.DESCRIPTOR:
225
raise KeyError('Extension "%s" extends message type "%s", but this '
226
'message is of type "%s".' %
227
(extension_handle.full_name,
228
extension_handle.containing_type.full_name,
229
message.DESCRIPTOR.full_name))
232
def _AddSlots(message_descriptor, dictionary):
233
"""Adds a __slots__ entry to dictionary, containing the names of all valid
234
attributes for this message type.
237
message_descriptor: A Descriptor instance describing this message type.
238
dictionary: Class dictionary to which we'll add a '__slots__' entry.
240
dictionary['__slots__'] = ['_cached_byte_size',
241
'_cached_byte_size_dirty',
243
'_is_present_in_parent',
245
'_listener_for_children',
249
def _IsMessageSetExtension(field):
  """Tells whether |field| must be encoded as a MessageSet item.

  A field qualifies only if all of the following hold:
    * it is an extension;
    * its containing type has options that enable message_set_wire_format;
    * it is message-typed, and that message type is also the extension's
      declaring scope;
    * it is optional.

  Args:
    field: FieldDescriptor to inspect.

  Returns:
    The value of the short-circuit "and" chain below. It is truthy iff
    every condition above holds.
  """
  # The checks are chained with "and" deliberately. Later checks, such as
  # fetching the containing type's options, only run once every earlier
  # check has passed.
  return (
      field.is_extension
      and field.containing_type.has_options
      and field.containing_type.GetOptions().message_set_wire_format
      and field.type == _FieldDescriptor.TYPE_MESSAGE
      and field.message_type == field.extension_scope
      and field.label == _FieldDescriptor.LABEL_OPTIONAL
  )
258
def _AttachFieldHelpers(cls, field_descriptor):
259
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
260
is_packed = (field_descriptor.has_options and
261
field_descriptor.GetOptions().packed)
263
if _IsMessageSetExtension(field_descriptor):
264
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
265
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
267
field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
268
field_descriptor.number, is_repeated, is_packed)
269
sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
270
field_descriptor.number, is_repeated, is_packed)
272
field_descriptor._encoder = field_encoder
273
field_descriptor._sizer = sizer
274
field_descriptor._default_constructor = _DefaultValueConstructorForField(
277
def AddDecoder(wiretype, is_packed):
278
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
279
cls._decoders_by_tag[tag_bytes] = (
280
type_checkers.TYPE_TO_DECODER[field_descriptor.type](
281
field_descriptor.number, is_repeated, is_packed,
282
field_descriptor, field_descriptor._default_constructor))
284
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
287
if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
288
# To support wire compatibility of adding packed = true, add a decoder for
289
# packed values regardless of the field's options.
290
AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
293
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes every extension nested in |descriptor| as a class attribute.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is read.
    dictionary: Class dictionary that receives one entry per extension,
      keyed by the extension's name and holding its field descriptor.
  """
  for name, handle in descriptor.extensions_by_name.iteritems():
    # A nested extension must never shadow something already in the class.
    assert name not in dictionary
    dictionary[name] = handle
300
def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Each value of each enum nested in the message becomes an attribute of
  |cls| named after the value, holding its number. When names repeat, the
  value visited last wins.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  # Lazily flatten (enum type -> values) so that values are visited in the
  # same order as a nested loop would visit them.
  every_value = (value
                 for enum_type in descriptor.enum_types
                 for value in enum_type.values)
  for value in every_value:
    setattr(cls, value.name, value.number)
312
def _DefaultValueConstructorForField(field):
313
"""Returns a function which returns a default value for a field.
316
field: FieldDescriptor object for this field.
318
The returned function has one argument:
319
message: Message instance containing this field, or a weakref proxy
322
That function in turn returns a default value for this field. The default
323
value may refer back to |message| via a weak reference.
326
if field.label == _FieldDescriptor.LABEL_REPEATED:
327
if field.default_value != []:
328
raise ValueError('Repeated field default value not empty list: %s' % (
329
field.default_value))
330
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
331
# We can't look at _concrete_class yet since it might not have
332
# been set. (Depends on order in which we initialize the classes).
333
message_type = field.message_type
334
def MakeRepeatedMessageDefault(message):
335
return containers.RepeatedCompositeFieldContainer(
336
message._listener_for_children, field.message_type)
337
return MakeRepeatedMessageDefault
339
type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
340
def MakeRepeatedScalarDefault(message):
341
return containers.RepeatedScalarFieldContainer(
342
message._listener_for_children, type_checker)
343
return MakeRepeatedScalarDefault
345
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
346
# _concrete_class may not yet be initialized.
347
message_type = field.message_type
348
def MakeSubMessageDefault(message):
349
result = message_type._concrete_class()
350
result._SetListener(message._listener_for_children)
352
return MakeSubMessageDefault
354
def MakeScalarDefault(message):
355
return field.default_value
356
return MakeScalarDefault
359
def _AddInitMethod(message_descriptor, cls):
360
"""Adds an __init__ method to cls."""
361
fields = message_descriptor.fields
362
def init(self, **kwargs):
363
self._cached_byte_size = 0
364
self._cached_byte_size_dirty = False
366
self._is_present_in_parent = False
367
self._listener = message_listener_mod.NullMessageListener()
368
self._listener_for_children = _Listener(self)
369
for field_name, field_value in kwargs.iteritems():
370
field = _GetFieldByName(message_descriptor, field_name)
372
raise TypeError("%s() got an unexpected keyword argument '%s'" %
373
(message_descriptor.name, field_name))
374
if field.label == _FieldDescriptor.LABEL_REPEATED:
375
copy = field._default_constructor(self)
376
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
377
for val in field_value:
378
copy.add().MergeFrom(val)
380
copy.extend(field_value)
381
self._fields[field] = copy
382
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
383
copy = field._default_constructor(self)
384
copy.MergeFrom(field_value)
385
self._fields[field] = copy
387
self._fields[field] = field_value
389
init.__module__ = None
394
def _GetFieldByName(message_descriptor, field_name):
395
"""Returns a field descriptor by field name.
398
message_descriptor: A Descriptor describing all fields in message.
399
field_name: The name of the field to retrieve.
401
The field descriptor associated with the field name.
404
return message_descriptor.fields_by_name[field_name]
406
raise ValueError('Protocol message has no "%s" field.' % field_name)
409
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally get an "Extensions" property.

  Args:
    descriptor: Descriptor for the message type being built.
    cls: The class under construction; properties are installed on it.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensions(self):
    # _ExtensionDict is just an adaptor with no state, so a fresh one is
    # allocated on every access instead of being cached.
    return _ExtensionDict(self)

  cls.Extensions = property(_GetExtensions)
420
def _AddPropertiesForField(field, cls):
421
"""Adds a public property for a protocol message field.
422
Clients can use this property to get and (in the case
423
of non-repeated scalar fields) directly set the value
424
of a protocol message field.
427
field: A FieldDescriptor for this field.
428
cls: The class we're constructing.
430
# Catch it if we add other types that we should
431
# handle specially here.
432
assert _FieldDescriptor.MAX_CPPTYPE == 10
434
constant_name = field.name.upper() + "_FIELD_NUMBER"
435
setattr(cls, constant_name, field.number)
437
if field.label == _FieldDescriptor.LABEL_REPEATED:
438
_AddPropertiesForRepeatedField(field, cls)
439
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
440
_AddPropertiesForNonRepeatedCompositeField(field, cls)
442
_AddPropertiesForNonRepeatedScalarField(field, cls)
445
def _AddPropertiesForRepeatedField(field, cls):
446
"""Adds a public property for a "repeated" protocol message field. Clients
447
can use this property to get the value of the field, which will be either a
448
_RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
451
Note that when clients add values to these containers, we perform
452
type-checking in the case of repeated scalar fields, and we also set any
453
necessary "has" bits as a side-effect.
456
field: A FieldDescriptor for this field.
457
cls: The class we're constructing.
459
proto_field_name = field.name
460
property_name = _PropertyName(proto_field_name)
463
field_value = self._fields.get(field)
464
if field_value is None:
465
# Construct a new object to represent this field.
466
field_value = field._default_constructor(self)
468
# Atomically check if another thread has preempted us and, if not, swap
469
# in the new object we just created. If someone has preempted us, we
470
# take that object and discard ours.
471
# WARNING: We are relying on setdefault() being atomic. This is true
472
# in CPython but we haven't investigated others. This warning appears
473
# in several other locations in this file.
474
field_value = self._fields.setdefault(field, field_value)
476
getter.__module__ = None
477
getter.__doc__ = 'Getter for %s.' % proto_field_name
479
# We define a setter just so we can throw an exception with a more
480
# helpful error message.
481
def setter(self, new_value):
482
raise AttributeError('Assignment not allowed to repeated field '
483
'"%s" in protocol message object.' % proto_field_name)
485
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
486
setattr(cls, property_name, property(getter, setter, doc=doc))
489
def _AddPropertiesForNonRepeatedScalarField(field, cls):
490
"""Adds a public property for a nonrepeated, scalar protocol message field.
491
Clients can use this property to get and directly set the value of the field.
492
Note that when the client sets the value of a field by using this property,
493
all necessary "has" bits are set as a side-effect, and we also perform
497
field: A FieldDescriptor for this field.
498
cls: The class we're constructing.
500
proto_field_name = field.name
501
property_name = _PropertyName(proto_field_name)
502
type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
503
default_value = field.default_value
506
return self._fields.get(field, default_value)
507
getter.__module__ = None
508
getter.__doc__ = 'Getter for %s.' % proto_field_name
509
def setter(self, new_value):
510
type_checker.CheckValue(new_value)
511
self._fields[field] = new_value
512
# Check _cached_byte_size_dirty inline to improve performance, since scalar
513
# setters are called frequently.
514
if not self._cached_byte_size_dirty:
516
setter.__module__ = None
517
setter.__doc__ = 'Setter for %s.' % proto_field_name
519
# Add a property to encapsulate the getter/setter.
520
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
521
setattr(cls, property_name, property(getter, setter, doc=doc))
524
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
525
"""Adds a public property for a nonrepeated, composite protocol message field.
526
A composite field is a "group" or "message" field.
528
Clients can use this property to get the value of the field, but cannot
529
assign to the property directly.
532
field: A FieldDescriptor for this field.
533
cls: The class we're constructing.
535
# TODO(robinson): Remove duplication with similar method
536
# for non-repeated scalars.
537
proto_field_name = field.name
538
property_name = _PropertyName(proto_field_name)
539
message_type = field.message_type
542
field_value = self._fields.get(field)
543
if field_value is None:
544
# Construct a new object to represent this field.
545
field_value = message_type._concrete_class()
546
field_value._SetListener(self._listener_for_children)
548
# Atomically check if another thread has preempted us and, if not, swap
549
# in the new object we just created. If someone has preempted us, we
550
# take that object and discard ours.
551
# WARNING: We are relying on setdefault() being atomic. This is true
552
# in CPython but we haven't investigated others. This warning appears
553
# in several other locations in this file.
554
field_value = self._fields.setdefault(field, field_value)
556
getter.__module__ = None
557
getter.__doc__ = 'Getter for %s.' % proto_field_name
559
# We define a setter just so we can throw an exception with a more
560
# helpful error message.
561
def setter(self, new_value):
562
raise AttributeError('Assignment not allowed to composite field '
563
'"%s" in protocol message object.' % proto_field_name)
565
# Add a property to encapsulate the getter.
566
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
567
setattr(cls, property_name, property(getter, setter, doc=doc))
570
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each declared extension.

  For every extension in |descriptor|, this sets an attribute on |cls|. The
  attribute is named after the upper-cased extension name plus
  "_FIELD_NUMBER" and holds the extension's field number. This mirrors the
  constants generated for regular fields.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is read.
    cls: Class we're constructing for this message type.
  """
  for name, handle in descriptor.extensions_by_name.iteritems():
    setattr(cls, name.upper() + "_FIELD_NUMBER", handle.number)
578
def _AddStaticMethods(cls):
579
# TODO(robinson): This probably needs to be thread-safe(?)
580
def RegisterExtension(extension_handle):
581
extension_handle.containing_type = cls.DESCRIPTOR
582
_AttachFieldHelpers(cls, extension_handle)
584
# Try to insert our extension, failing if an extension with the same number
586
actual_handle = cls._extensions_by_number.setdefault(
587
extension_handle.number, extension_handle)
588
if actual_handle is not extension_handle:
589
raise AssertionError(
590
'Extensions "%s" and "%s" both try to extend message type "%s" with '
592
(extension_handle.full_name, actual_handle.full_name,
593
cls.DESCRIPTOR.full_name, extension_handle.number))
595
cls._extensions_by_name[extension_handle.full_name] = extension_handle
597
handle = extension_handle # avoid line wrapping
598
if _IsMessageSetExtension(handle):
599
# MessageSet extension. Also register under type name.
600
cls._extensions_by_name[
601
extension_handle.message_type.full_name] = extension_handle
603
cls.RegisterExtension = staticmethod(RegisterExtension)
607
message.MergeFromString(s)
609
cls.FromString = staticmethod(FromString)
612
def _IsPresent(item):
613
"""Given a (FieldDescriptor, value) tuple from _fields, return true if the
614
value should be included in the list returned by ListFields()."""
616
if item[0].label == _FieldDescriptor.LABEL_REPEATED:
618
elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
619
return item[1]._is_present_in_parent
624
def _AddListFieldsMethod(message_descriptor, cls):
625
"""Helper for _AddMessageMethods()."""
627
def ListFields(self):
628
all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
629
all_fields.sort(key = lambda item: item[0].number)
632
cls.ListFields = ListFields
635
def _AddHasFieldMethod(message_descriptor, cls):
636
"""Helper for _AddMessageMethods()."""
639
for field in message_descriptor.fields:
640
if field.label != _FieldDescriptor.LABEL_REPEATED:
641
singular_fields[field.name] = field
643
def HasField(self, field_name):
645
field = singular_fields[field_name]
648
'Protocol message has no singular "%s" field.' % field_name)
650
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
651
value = self._fields.get(field)
652
return value is not None and value._is_present_in_parent
654
return field in self._fields
655
cls.HasField = HasField
658
def _AddClearFieldMethod(message_descriptor, cls):
659
"""Helper for _AddMessageMethods()."""
660
def ClearField(self, field_name):
662
field = message_descriptor.fields_by_name[field_name]
664
raise ValueError('Protocol message has no "%s" field.' % field_name)
666
if field in self._fields:
667
# Note: If the field is a sub-message, its listener will still point
668
# at us. That's fine, because the worst than can happen is that it
669
# will call _Modified() and invalidate our byte size. Big deal.
670
del self._fields[field]
672
# Always call _Modified() -- even if nothing was changed, this is
673
# a mutating method, and thus calling it should cause the field to become
674
# present in the parent message.
677
cls.ClearField = ClearField
680
def _AddClearExtensionMethod(cls):
681
"""Helper for _AddMessageMethods()."""
682
def ClearExtension(self, extension_handle):
683
_VerifyExtensionHandle(self, extension_handle)
685
# Similar to ClearField(), above.
686
if extension_handle in self._fields:
687
del self._fields[extension_handle]
689
cls.ClearExtension = ClearExtension
692
def _AddClearMethod(message_descriptor, cls):
693
"""Helper for _AddMessageMethods()."""
701
def _AddHasExtensionMethod(cls):
702
"""Helper for _AddMessageMethods()."""
703
def HasExtension(self, extension_handle):
704
_VerifyExtensionHandle(self, extension_handle)
705
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
706
raise KeyError('"%s" is repeated.' % extension_handle.full_name)
708
if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
709
value = self._fields.get(extension_handle)
710
return value is not None and value._is_present_in_parent
712
return extension_handle in self._fields
713
cls.HasExtension = HasExtension
716
def _AddEqualsMethod(message_descriptor, cls):
717
"""Helper for _AddMessageMethods()."""
718
def __eq__(self, other):
719
if (not isinstance(other, message_mod.Message) or
720
other.DESCRIPTOR != self.DESCRIPTOR):
726
return self.ListFields() == other.ListFields()
731
def _AddStrMethod(message_descriptor, cls):
732
"""Helper for _AddMessageMethods()."""
734
return text_format.MessageToString(self)
735
cls.__str__ = __str__
738
def _AddSetListenerMethod(cls):
739
"""Helper for _AddMessageMethods()."""
740
def SetListener(self, listener):
742
self._listener = message_listener_mod.NullMessageListener()
744
self._listener = listener
745
cls._SetListener = SetListener
748
def _BytesForNonRepeatedElement(value, field_number, field_type):
749
"""Returns the number of bytes needed to serialize a non-repeated element.
750
The returned byte count includes space for tag information and any
751
other additional space associated with serializing value.
754
value: Value we're serializing.
755
field_number: Field number of this value. (Since the field number
756
is stored as part of a varint-encoded tag, this has an impact
757
on the total bytes required to serialize the value).
758
field_type: The type of the field. One of the TYPE_* constants
759
within FieldDescriptor.
762
fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
763
return fn(field_number, value)
765
raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
768
def _AddByteSizeMethod(message_descriptor, cls):
769
"""Helper for _AddMessageMethods()."""
772
if not self._cached_byte_size_dirty:
773
return self._cached_byte_size
776
for field_descriptor, field_value in self.ListFields():
777
size += field_descriptor._sizer(field_value)
779
self._cached_byte_size = size
780
self._cached_byte_size_dirty = False
781
self._listener_for_children.dirty = False
784
cls.ByteSize = ByteSize
787
def _AddSerializeToStringMethod(message_descriptor, cls):
788
"""Helper for _AddMessageMethods()."""
790
def SerializeToString(self):
791
# Check if the message has all of its required fields set.
793
if not self.IsInitialized():
794
raise message_mod.EncodeError(
795
'Message is missing required fields: ' +
796
','.join(self.FindInitializationErrors()))
797
return self.SerializePartialToString()
798
cls.SerializeToString = SerializeToString
801
def _AddSerializePartialToStringMethod(message_descriptor, cls):
802
"""Helper for _AddMessageMethods()."""
804
def SerializePartialToString(self):
806
self._InternalSerialize(out.write)
807
return out.getvalue()
808
cls.SerializePartialToString = SerializePartialToString
810
def InternalSerialize(self, write_bytes):
811
for field_descriptor, field_value in self.ListFields():
812
field_descriptor._encoder(write_bytes, field_value)
813
cls._InternalSerialize = InternalSerialize
816
def _AddMergeFromStringMethod(message_descriptor, cls):
817
"""Helper for _AddMessageMethods()."""
818
def MergeFromString(self, serialized):
819
length = len(serialized)
821
if self._InternalParse(serialized, 0, length) != length:
822
# The only reason _InternalParse would return early is if it
823
# encountered an end-group tag.
824
raise message_mod.DecodeError('Unexpected end-group tag.')
826
raise message_mod.DecodeError('Truncated message.')
827
except struct.error, e:
828
raise message_mod.DecodeError(e)
829
return length # Return this for legacy reasons.
830
cls.MergeFromString = MergeFromString
832
local_ReadTag = decoder.ReadTag
833
local_SkipField = decoder.SkipField
834
decoders_by_tag = cls._decoders_by_tag
836
def InternalParse(self, buffer, pos, end):
838
field_dict = self._fields
840
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
841
field_decoder = decoders_by_tag.get(tag_bytes)
842
if field_decoder is None:
843
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
848
pos = field_decoder(buffer, new_pos, end, self, field_dict)
850
cls._InternalParse = InternalParse
853
def _AddIsInitializedMethod(message_descriptor, cls):
854
"""Adds the IsInitialized and FindInitializationError methods to the
855
protocol message class."""
857
required_fields = [field for field in message_descriptor.fields
858
if field.label == _FieldDescriptor.LABEL_REQUIRED]
860
def IsInitialized(self, errors=None):
861
"""Checks if all required fields of a message are set.
864
errors: A list which, if provided, will be populated with the field
865
paths of all missing required fields.
868
True iff the specified message has all required fields set.
871
# Performance is critical so we avoid HasField() and ListFields().
873
for field in required_fields:
874
if (field not in self._fields or
875
(field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
876
not self._fields[field]._is_present_in_parent)):
877
if errors is not None:
878
errors.extend(self.FindInitializationErrors())
881
for field, value in self._fields.iteritems():
882
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
883
if field.label == _FieldDescriptor.LABEL_REPEATED:
884
for element in value:
885
if not element.IsInitialized():
886
if errors is not None:
887
errors.extend(self.FindInitializationErrors())
889
elif value._is_present_in_parent and not value.IsInitialized():
890
if errors is not None:
891
errors.extend(self.FindInitializationErrors())
896
cls.IsInitialized = IsInitialized
898
def FindInitializationErrors(self):
899
"""Finds required fields which are not initialized.
902
A list of strings. Each string is a path to an uninitialized field from
903
the top-level message, e.g. "foo.bar[5].baz".
906
errors = [] # simplify things
908
for field in required_fields:
909
if not self.HasField(field.name):
910
errors.append(field.name)
912
for field, value in self.ListFields():
913
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
914
if field.is_extension:
915
name = "(%s)" % field.full_name
919
if field.label == _FieldDescriptor.LABEL_REPEATED:
920
for i in xrange(len(value)):
922
prefix = "%s[%d]." % (name, i)
923
sub_errors = element.FindInitializationErrors()
924
errors += [ prefix + error for error in sub_errors ]
927
sub_errors = value.FindInitializationErrors()
928
errors += [ prefix + error for error in sub_errors ]
932
cls.FindInitializationErrors = FindInitializationErrors
935
def _AddMergeFromMethod(cls):
936
LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
937
CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
939
def MergeFrom(self, msg):
940
assert msg is not self
943
fields = self._fields
945
for field, value in msg._fields.iteritems():
946
if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE:
947
field_value = fields.get(field)
948
if field_value is None:
949
# Construct a new object to represent this field.
950
field_value = field._default_constructor(self)
951
fields[field] = field_value
952
field_value.MergeFrom(value)
954
self._fields[field] = value
955
cls.MergeFrom = MergeFrom
958
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The installers run in a fixed order. The two extension-only installers
  (ClearExtension / HasExtension) run only for extendable message types.

  Args:
    message_descriptor: Descriptor for the message type being built.
    cls: The class under construction; it is modified in place.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  if message_descriptor.is_extendable:
    for add_method in (_AddClearExtensionMethod, _AddHasExtensionMethod):
      add_method(cls)

  for add_method in (_AddClearMethod, _AddEqualsMethod, _AddStrMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
978
def _AddPrivateHelperMethods(cls):
979
"""Adds implementation of private helper methods to cls."""
982
"""Sets the _cached_byte_size_dirty bit to true,
983
and propagates this to our listener iff this was a state change.
986
# Note: Some callers check _cached_byte_size_dirty before calling
987
# _Modified() as an extra optimization. So, if this method is ever
988
# changed such that it does stuff even when _cached_byte_size_dirty is
989
# already true, the callers need to be updated.
990
if not self._cached_byte_size_dirty:
991
self._cached_byte_size_dirty = True
992
self._listener_for_children.dirty = True
993
self._is_present_in_parent = True
994
self._listener.Modified()
996
cls._Modified = Modified
997
cls.SetInParent = Modified
1000
class _Listener(object):
1002
"""MessageListener implementation that a parent message registers with its
1005
In order to support semantics like:
1007
foo.bar.baz.qux = 23
1008
assert foo.HasField('bar')
1010
...child objects must have back references to their parents.
1011
This helper class is at the heart of this support.
1014
def __init__(self, parent_message):
1016
parent_message: The message whose _Modified() method we should call when
1017
we receive Modified() messages.
1019
# This listener establishes a back reference from a child (contained) object
1020
# to its parent (containing) object. We make this a weak reference to avoid
1021
# creating cyclic garbage when the client finishes with the 'parent' object
1023
if isinstance(parent_message, weakref.ProxyType):
1024
self._parent_message_weakref = parent_message
1026
self._parent_message_weakref = weakref.proxy(parent_message)
1028
# As an optimization, we also indicate directly on the listener whether
1029
# or not the parent message is dirty. This way we can avoid traversing
1030
# up the tree in the common case.
1037
# Propagate the signal to our parents iff this is the first field set.
1038
self._parent_message_weakref._Modified()
1039
except ReferenceError:
1040
# We can get here if a client has kept a reference to a child object,
1041
# and is now setting a field on it, but the child's parent has been
1042
# garbage-collected. This is not an error.
1046
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1047
# TODO(robinson): Unify error handling of "unknown extension" crap.
1048
# TODO(robinson): Support iteritems()-style iteration over all
1049
# extensions with the "has" bits turned on?
1050
class _ExtensionDict(object):
1052
"""Dict-like container for supporting an indexable "Extensions"
1053
field on proto instances.
1055
Note that in all cases we expect extension handles to be
1059
def __init__(self, extended_message):
1060
"""extended_message: Message instance for which we are the Extensions dict.
1063
self._extended_message = extended_message
1065
def __getitem__(self, extension_handle):
1066
"""Returns the current value of the given extension handle."""
1068
_VerifyExtensionHandle(self._extended_message, extension_handle)
1070
result = self._extended_message._fields.get(extension_handle)
1071
if result is not None:
1074
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1075
result = extension_handle._default_constructor(self._extended_message)
1076
elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1077
result = extension_handle.message_type._concrete_class()
1079
result._SetListener(self._extended_message._listener_for_children)
1080
except ReferenceError:
1083
# Singular scalar -- just return the default without inserting into the
1085
return extension_handle.default_value
1087
# Atomically check if another thread has preempted us and, if not, swap
1088
# in the new object we just created. If someone has preempted us, we
1089
# take that object and discard ours.
1090
# WARNING: We are relying on setdefault() being atomic. This is true
1091
# in CPython but we haven't investigated others. This warning appears
1092
# in several other locations in this file.
1093
result = self._extended_message._fields.setdefault(
1094
extension_handle, result)
1098
def __eq__(self, other):
1099
if not isinstance(other, self.__class__):
1102
my_fields = self._extended_message.ListFields()
1103
other_fields = other._extended_message.ListFields()
1105
# Get rid of non-extension fields.
1106
my_fields = [ field for field in my_fields if field.is_extension ]
1107
other_fields = [ field for field in other_fields if field.is_extension ]
1109
return my_fields == other_fields
1111
def __ne__(self, other):
1112
return not self == other
1114
# Note that this is only meaningful for non-repeated, scalar extension
1115
# fields. Note also that we may have to call _Modified() when we do
1116
# successfully set a field this way, to set any necssary "has" bits in the
1117
# ancestors of the extended message.
1118
def __setitem__(self, extension_handle, value):
1119
"""If extension_handle specifies a non-repeated, scalar extension
1120
field, sets the value of that field.
1123
_VerifyExtensionHandle(self._extended_message, extension_handle)
1125
if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1126
extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1128
'Cannot assign to extension "%s" because it is a repeated or '
1129
'composite type.' % extension_handle.full_name)
1131
# It's slightly wasteful to lookup the type checker each time,
1132
# but we expect this to be a vanishingly uncommon case anyway.
1133
type_checker = type_checkers.GetTypeChecker(
1134
extension_handle.cpp_type, extension_handle.type)
1135
type_checker.CheckValue(value)
1136
self._extended_message._fields[extension_handle] = value
1137
self._extended_message._Modified()
1139
def _FindExtensionByName(self, name):
1140
"""Tries to find a known extension with the specified name.
1143
name: Extension full name.
1146
Extension field descriptor.
1148
return self._extended_message._extensions_by_name.get(name, None)