~ubuntu-branches/ubuntu/trusty/protobuf/trusty-proposed

« back to all changes in this revision

Viewing changes to python/google/protobuf/reflection.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2011-05-31 14:41:47 UTC
  • mfrom: (2.2.8 sid)
  • Revision ID: james.westby@ubuntu.com-20110531144147-s41g5fozgvyo462l
Tags: 2.4.0a-2ubuntu1
* Merge with Debian; remaining changes:
  - Fix linking with -lpthread.

Show diffs side-by-side

added added

removed removed

Lines of Context:
29
29
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
30
 
31
31
# This code is meant to work on Python 2.4 and above only.
32
 
#
33
 
# TODO(robinson): Helpers for verbose, common checks like seeing if a
34
 
# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
32
 
36
33
"""Contains a metaclass and helper functions used to create
37
34
protocol message classes from Descriptor objects at runtime.
50
47
 
51
48
__author__ = 'robinson@google.com (Will Robinson)'
52
49
 
53
 
try:
54
 
  from cStringIO import StringIO
55
 
except ImportError:
56
 
  from StringIO import StringIO
57
 
import struct
58
 
import weakref
59
50
 
60
 
# We use "as" to avoid name collisions with variables.
61
 
from google.protobuf.internal import containers
62
 
from google.protobuf.internal import decoder
63
 
from google.protobuf.internal import encoder
64
 
from google.protobuf.internal import message_listener as message_listener_mod
65
 
from google.protobuf.internal import type_checkers
66
 
from google.protobuf.internal import wire_format
 
51
from google.protobuf.internal import api_implementation
67
52
from google.protobuf import descriptor as descriptor_mod
68
 
from google.protobuf import message as message_mod
69
 
from google.protobuf import text_format
70
 
 
71
53
_FieldDescriptor = descriptor_mod.FieldDescriptor
72
54
 
73
55
 
 
56
if api_implementation.Type() == 'cpp':
 
57
  from google.protobuf.internal import cpp_message
 
58
  _NewMessage = cpp_message.NewMessage
 
59
  _InitMessage = cpp_message.InitMessage
 
60
else:
 
61
  from google.protobuf.internal import python_message
 
62
  _NewMessage = python_message.NewMessage
 
63
  _InitMessage = python_message.InitMessage
 
64
 
 
65
 
74
66
class GeneratedProtocolMessageType(type):
75
67
 
76
68
  """Metaclass for protocol message classes created at runtime from Descriptors.
120
112
      Newly-allocated class.
121
113
    """
122
114
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
123
 
    _AddSlots(descriptor, dictionary)
124
 
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
 
115
    _NewMessage(descriptor, dictionary)
125
116
    superclass = super(GeneratedProtocolMessageType, cls)
126
 
    return superclass.__new__(cls, name, bases, dictionary)
 
117
 
 
118
    new_class = superclass.__new__(cls, name, bases, dictionary)
 
119
    setattr(descriptor, '_concrete_class', new_class)
 
120
    return new_class
127
121
 
128
122
  def __init__(cls, name, bases, dictionary):
129
123
    """Here we perform the majority of our work on the class.
143
137
        type.
144
138
    """
145
139
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
146
 
 
147
 
    cls._decoders_by_tag = {}
148
 
    cls._extensions_by_name = {}
149
 
    cls._extensions_by_number = {}
150
 
    if (descriptor.has_options and
151
 
        descriptor.GetOptions().message_set_wire_format):
152
 
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
153
 
          decoder.MessageSetItemDecoder(cls._extensions_by_number))
154
 
 
155
 
    # We act as a "friend" class of the descriptor, setting
156
 
    # its _concrete_class attribute the first time we use a
157
 
    # given descriptor to initialize a concrete protocol message
158
 
    # class.  We also attach stuff to each FieldDescriptor for quick
159
 
    # lookup later on.
160
 
    concrete_class_attr_name = '_concrete_class'
161
 
    if not hasattr(descriptor, concrete_class_attr_name):
162
 
      setattr(descriptor, concrete_class_attr_name, cls)
163
 
      for field in descriptor.fields:
164
 
        _AttachFieldHelpers(cls, field)
165
 
 
166
 
    _AddEnumValues(descriptor, cls)
167
 
    _AddInitMethod(descriptor, cls)
168
 
    _AddPropertiesForFields(descriptor, cls)
169
 
    _AddPropertiesForExtensions(descriptor, cls)
170
 
    _AddStaticMethods(cls)
171
 
    _AddMessageMethods(descriptor, cls)
172
 
    _AddPrivateHelperMethods(cls)
 
140
    _InitMessage(descriptor, cls)
173
141
    superclass = super(GeneratedProtocolMessageType, cls)
174
142
    superclass.__init__(name, bases, dictionary)
175
 
 
176
 
 
177
 
# Stateless helpers for GeneratedProtocolMessageType below.
178
 
# Outside clients should not access these directly.
179
 
#
180
 
# I opted not to make any of these methods on the metaclass, to make it more
181
 
# clear that I'm not really using any state there and to keep clients from
182
 
# thinking that they have direct access to these construction helpers.
183
 
 
184
 
 
185
 
def _PropertyName(proto_field_name):
186
 
  """Returns the name of the public property attribute which
187
 
  clients can use to get and (in some cases) set the value
188
 
  of a protocol message field.
189
 
 
190
 
  Args:
191
 
    proto_field_name: The protocol message field name, exactly
192
 
      as it appears (or would appear) in a .proto file.
193
 
  """
194
 
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
195
 
  # nnorwitz makes my day by writing:
196
 
  # """
197
 
  # FYI.  See the keyword module in the stdlib. This could be as simple as:
198
 
  #
199
 
  # if keyword.iskeyword(proto_field_name):
200
 
  #   return proto_field_name + "_"
201
 
  # return proto_field_name
202
 
  # """
203
 
  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
204
 
  #   getattr() and setattr() to reflectively manipulate field values.  If we
205
 
  #   rename the properties, then every such user has to also make sure to apply
206
 
  #   the same transformation.  Note that currently if you name a field "yield",
207
 
  #   you can still access it just fine using getattr/setattr -- it's not even
208
 
  #   that cumbersome to do so.
209
 
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
210
 
  #   position.
211
 
  return proto_field_name
212
 
 
213
 
 
214
 
def _VerifyExtensionHandle(message, extension_handle):
215
 
  """Verify that the given extension handle is valid."""
216
 
 
217
 
  if not isinstance(extension_handle, _FieldDescriptor):
218
 
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
219
 
                   extension_handle)
220
 
 
221
 
  if not extension_handle.is_extension:
222
 
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
223
 
 
224
 
  if extension_handle.containing_type is not message.DESCRIPTOR:
225
 
    raise KeyError('Extension "%s" extends message type "%s", but this '
226
 
                   'message is of type "%s".' %
227
 
                   (extension_handle.full_name,
228
 
                    extension_handle.containing_type.full_name,
229
 
                    message.DESCRIPTOR.full_name))
230
 
 
231
 
 
232
 
def _AddSlots(message_descriptor, dictionary):
233
 
  """Adds a __slots__ entry to dictionary, containing the names of all valid
234
 
  attributes for this message type.
235
 
 
236
 
  Args:
237
 
    message_descriptor: A Descriptor instance describing this message type.
238
 
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
239
 
  """
240
 
  dictionary['__slots__'] = ['_cached_byte_size',
241
 
                             '_cached_byte_size_dirty',
242
 
                             '_fields',
243
 
                             '_is_present_in_parent',
244
 
                             '_listener',
245
 
                             '_listener_for_children',
246
 
                             '__weakref__']
247
 
 
248
 
 
249
 
def _IsMessageSetExtension(field):
250
 
  return (field.is_extension and
251
 
          field.containing_type.has_options and
252
 
          field.containing_type.GetOptions().message_set_wire_format and
253
 
          field.type == _FieldDescriptor.TYPE_MESSAGE and
254
 
          field.message_type == field.extension_scope and
255
 
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
256
 
 
257
 
 
258
 
def _AttachFieldHelpers(cls, field_descriptor):
259
 
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
260
 
  is_packed = (field_descriptor.has_options and
261
 
               field_descriptor.GetOptions().packed)
262
 
 
263
 
  if _IsMessageSetExtension(field_descriptor):
264
 
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
265
 
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
266
 
  else:
267
 
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
268
 
        field_descriptor.number, is_repeated, is_packed)
269
 
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
270
 
        field_descriptor.number, is_repeated, is_packed)
271
 
 
272
 
  field_descriptor._encoder = field_encoder
273
 
  field_descriptor._sizer = sizer
274
 
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
275
 
      field_descriptor)
276
 
 
277
 
  def AddDecoder(wiretype, is_packed):
278
 
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
279
 
    cls._decoders_by_tag[tag_bytes] = (
280
 
        type_checkers.TYPE_TO_DECODER[field_descriptor.type](
281
 
            field_descriptor.number, is_repeated, is_packed,
282
 
            field_descriptor, field_descriptor._default_constructor))
283
 
 
284
 
  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
285
 
             False)
286
 
 
287
 
  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
288
 
    # To support wire compatibility of adding packed = true, add a decoder for
289
 
    # packed values regardless of the field's options.
290
 
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
291
 
 
292
 
 
293
 
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
294
 
  extension_dict = descriptor.extensions_by_name
295
 
  for extension_name, extension_field in extension_dict.iteritems():
296
 
    assert extension_name not in dictionary
297
 
    dictionary[extension_name] = extension_field
298
 
 
299
 
 
300
 
def _AddEnumValues(descriptor, cls):
301
 
  """Sets class-level attributes for all enum fields defined in this message.
302
 
 
303
 
  Args:
304
 
    descriptor: Descriptor object for this message type.
305
 
    cls: Class we're constructing for this message type.
306
 
  """
307
 
  for enum_type in descriptor.enum_types:
308
 
    for enum_value in enum_type.values:
309
 
      setattr(cls, enum_value.name, enum_value.number)
310
 
 
311
 
 
312
 
def _DefaultValueConstructorForField(field):
313
 
  """Returns a function which returns a default value for a field.
314
 
 
315
 
  Args:
316
 
    field: FieldDescriptor object for this field.
317
 
 
318
 
  The returned function has one argument:
319
 
    message: Message instance containing this field, or a weakref proxy
320
 
      of same.
321
 
 
322
 
  That function in turn returns a default value for this field.  The default
323
 
    value may refer back to |message| via a weak reference.
324
 
  """
325
 
 
326
 
  if field.label == _FieldDescriptor.LABEL_REPEATED:
327
 
    if field.default_value != []:
328
 
      raise ValueError('Repeated field default value not empty list: %s' % (
329
 
          field.default_value))
330
 
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
331
 
      # We can't look at _concrete_class yet since it might not have
332
 
      # been set.  (Depends on order in which we initialize the classes).
333
 
      message_type = field.message_type
334
 
      def MakeRepeatedMessageDefault(message):
335
 
        return containers.RepeatedCompositeFieldContainer(
336
 
            message._listener_for_children, field.message_type)
337
 
      return MakeRepeatedMessageDefault
338
 
    else:
339
 
      type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
340
 
      def MakeRepeatedScalarDefault(message):
341
 
        return containers.RepeatedScalarFieldContainer(
342
 
            message._listener_for_children, type_checker)
343
 
      return MakeRepeatedScalarDefault
344
 
 
345
 
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
346
 
    # _concrete_class may not yet be initialized.
347
 
    message_type = field.message_type
348
 
    def MakeSubMessageDefault(message):
349
 
      result = message_type._concrete_class()
350
 
      result._SetListener(message._listener_for_children)
351
 
      return result
352
 
    return MakeSubMessageDefault
353
 
 
354
 
  def MakeScalarDefault(message):
355
 
    return field.default_value
356
 
  return MakeScalarDefault
357
 
 
358
 
 
359
 
def _AddInitMethod(message_descriptor, cls):
360
 
  """Adds an __init__ method to cls."""
361
 
  fields = message_descriptor.fields
362
 
  def init(self, **kwargs):
363
 
    self._cached_byte_size = 0
364
 
    self._cached_byte_size_dirty = False
365
 
    self._fields = {}
366
 
    self._is_present_in_parent = False
367
 
    self._listener = message_listener_mod.NullMessageListener()
368
 
    self._listener_for_children = _Listener(self)
369
 
    for field_name, field_value in kwargs.iteritems():
370
 
      field = _GetFieldByName(message_descriptor, field_name)
371
 
      if field is None:
372
 
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
373
 
                        (message_descriptor.name, field_name))
374
 
      if field.label == _FieldDescriptor.LABEL_REPEATED:
375
 
        copy = field._default_constructor(self)
376
 
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
377
 
          for val in field_value:
378
 
            copy.add().MergeFrom(val)
379
 
        else:  # Scalar
380
 
          copy.extend(field_value)
381
 
        self._fields[field] = copy
382
 
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
383
 
        copy = field._default_constructor(self)
384
 
        copy.MergeFrom(field_value)
385
 
        self._fields[field] = copy
386
 
      else:
387
 
        self._fields[field] = field_value
388
 
 
389
 
  init.__module__ = None
390
 
  init.__doc__ = None
391
 
  cls.__init__ = init
392
 
 
393
 
 
394
 
def _GetFieldByName(message_descriptor, field_name):
395
 
  """Returns a field descriptor by field name.
396
 
 
397
 
  Args:
398
 
    message_descriptor: A Descriptor describing all fields in message.
399
 
    field_name: The name of the field to retrieve.
400
 
  Returns:
401
 
    The field descriptor associated with the field name.
402
 
  """
403
 
  try:
404
 
    return message_descriptor.fields_by_name[field_name]
405
 
  except KeyError:
406
 
    raise ValueError('Protocol message has no "%s" field.' % field_name)
407
 
 
408
 
 
409
 
def _AddPropertiesForFields(descriptor, cls):
410
 
  """Adds properties for all fields in this protocol message type."""
411
 
  for field in descriptor.fields:
412
 
    _AddPropertiesForField(field, cls)
413
 
 
414
 
  if descriptor.is_extendable:
415
 
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
416
 
    # every time it is accessed.
417
 
    cls.Extensions = property(lambda self: _ExtensionDict(self))
418
 
 
419
 
 
420
 
def _AddPropertiesForField(field, cls):
421
 
  """Adds a public property for a protocol message field.
422
 
  Clients can use this property to get and (in the case
423
 
  of non-repeated scalar fields) directly set the value
424
 
  of a protocol message field.
425
 
 
426
 
  Args:
427
 
    field: A FieldDescriptor for this field.
428
 
    cls: The class we're constructing.
429
 
  """
430
 
  # Catch it if we add other types that we should
431
 
  # handle specially here.
432
 
  assert _FieldDescriptor.MAX_CPPTYPE == 10
433
 
 
434
 
  constant_name = field.name.upper() + "_FIELD_NUMBER"
435
 
  setattr(cls, constant_name, field.number)
436
 
 
437
 
  if field.label == _FieldDescriptor.LABEL_REPEATED:
438
 
    _AddPropertiesForRepeatedField(field, cls)
439
 
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
440
 
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
441
 
  else:
442
 
    _AddPropertiesForNonRepeatedScalarField(field, cls)
443
 
 
444
 
 
445
 
def _AddPropertiesForRepeatedField(field, cls):
446
 
  """Adds a public property for a "repeated" protocol message field.  Clients
447
 
  can use this property to get the value of the field, which will be either a
448
 
  _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
449
 
  below).
450
 
 
451
 
  Note that when clients add values to these containers, we perform
452
 
  type-checking in the case of repeated scalar fields, and we also set any
453
 
  necessary "has" bits as a side-effect.
454
 
 
455
 
  Args:
456
 
    field: A FieldDescriptor for this field.
457
 
    cls: The class we're constructing.
458
 
  """
459
 
  proto_field_name = field.name
460
 
  property_name = _PropertyName(proto_field_name)
461
 
 
462
 
  def getter(self):
463
 
    field_value = self._fields.get(field)
464
 
    if field_value is None:
465
 
      # Construct a new object to represent this field.
466
 
      field_value = field._default_constructor(self)
467
 
 
468
 
      # Atomically check if another thread has preempted us and, if not, swap
469
 
      # in the new object we just created.  If someone has preempted us, we
470
 
      # take that object and discard ours.
471
 
      # WARNING:  We are relying on setdefault() being atomic.  This is true
472
 
      #   in CPython but we haven't investigated others.  This warning appears
473
 
      #   in several other locations in this file.
474
 
      field_value = self._fields.setdefault(field, field_value)
475
 
    return field_value
476
 
  getter.__module__ = None
477
 
  getter.__doc__ = 'Getter for %s.' % proto_field_name
478
 
 
479
 
  # We define a setter just so we can throw an exception with a more
480
 
  # helpful error message.
481
 
  def setter(self, new_value):
482
 
    raise AttributeError('Assignment not allowed to repeated field '
483
 
                         '"%s" in protocol message object.' % proto_field_name)
484
 
 
485
 
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
486
 
  setattr(cls, property_name, property(getter, setter, doc=doc))
487
 
 
488
 
 
489
 
def _AddPropertiesForNonRepeatedScalarField(field, cls):
490
 
  """Adds a public property for a nonrepeated, scalar protocol message field.
491
 
  Clients can use this property to get and directly set the value of the field.
492
 
  Note that when the client sets the value of a field by using this property,
493
 
  all necessary "has" bits are set as a side-effect, and we also perform
494
 
  type-checking.
495
 
 
496
 
  Args:
497
 
    field: A FieldDescriptor for this field.
498
 
    cls: The class we're constructing.
499
 
  """
500
 
  proto_field_name = field.name
501
 
  property_name = _PropertyName(proto_field_name)
502
 
  type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
503
 
  default_value = field.default_value
504
 
 
505
 
  def getter(self):
506
 
    return self._fields.get(field, default_value)
507
 
  getter.__module__ = None
508
 
  getter.__doc__ = 'Getter for %s.' % proto_field_name
509
 
  def setter(self, new_value):
510
 
    type_checker.CheckValue(new_value)
511
 
    self._fields[field] = new_value
512
 
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
513
 
    # setters are called frequently.
514
 
    if not self._cached_byte_size_dirty:
515
 
      self._Modified()
516
 
  setter.__module__ = None
517
 
  setter.__doc__ = 'Setter for %s.' % proto_field_name
518
 
 
519
 
  # Add a property to encapsulate the getter/setter.
520
 
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
521
 
  setattr(cls, property_name, property(getter, setter, doc=doc))
522
 
 
523
 
 
524
 
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
525
 
  """Adds a public property for a nonrepeated, composite protocol message field.
526
 
  A composite field is a "group" or "message" field.
527
 
 
528
 
  Clients can use this property to get the value of the field, but cannot
529
 
  assign to the property directly.
530
 
 
531
 
  Args:
532
 
    field: A FieldDescriptor for this field.
533
 
    cls: The class we're constructing.
534
 
  """
535
 
  # TODO(robinson): Remove duplication with similar method
536
 
  # for non-repeated scalars.
537
 
  proto_field_name = field.name
538
 
  property_name = _PropertyName(proto_field_name)
539
 
  message_type = field.message_type
540
 
 
541
 
  def getter(self):
542
 
    field_value = self._fields.get(field)
543
 
    if field_value is None:
544
 
      # Construct a new object to represent this field.
545
 
      field_value = message_type._concrete_class()
546
 
      field_value._SetListener(self._listener_for_children)
547
 
 
548
 
      # Atomically check if another thread has preempted us and, if not, swap
549
 
      # in the new object we just created.  If someone has preempted us, we
550
 
      # take that object and discard ours.
551
 
      # WARNING:  We are relying on setdefault() being atomic.  This is true
552
 
      #   in CPython but we haven't investigated others.  This warning appears
553
 
      #   in several other locations in this file.
554
 
      field_value = self._fields.setdefault(field, field_value)
555
 
    return field_value
556
 
  getter.__module__ = None
557
 
  getter.__doc__ = 'Getter for %s.' % proto_field_name
558
 
 
559
 
  # We define a setter just so we can throw an exception with a more
560
 
  # helpful error message.
561
 
  def setter(self, new_value):
562
 
    raise AttributeError('Assignment not allowed to composite field '
563
 
                         '"%s" in protocol message object.' % proto_field_name)
564
 
 
565
 
  # Add a property to encapsulate the getter.
566
 
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
567
 
  setattr(cls, property_name, property(getter, setter, doc=doc))
568
 
 
569
 
 
570
 
def _AddPropertiesForExtensions(descriptor, cls):
571
 
  """Adds properties for all fields in this protocol message type."""
572
 
  extension_dict = descriptor.extensions_by_name
573
 
  for extension_name, extension_field in extension_dict.iteritems():
574
 
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
575
 
    setattr(cls, constant_name, extension_field.number)
576
 
 
577
 
 
578
 
def _AddStaticMethods(cls):
579
 
  # TODO(robinson): This probably needs to be thread-safe(?)
580
 
  def RegisterExtension(extension_handle):
581
 
    extension_handle.containing_type = cls.DESCRIPTOR
582
 
    _AttachFieldHelpers(cls, extension_handle)
583
 
 
584
 
    # Try to insert our extension, failing if an extension with the same number
585
 
    # already exists.
586
 
    actual_handle = cls._extensions_by_number.setdefault(
587
 
        extension_handle.number, extension_handle)
588
 
    if actual_handle is not extension_handle:
589
 
      raise AssertionError(
590
 
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
591
 
          'field number %d.' %
592
 
          (extension_handle.full_name, actual_handle.full_name,
593
 
           cls.DESCRIPTOR.full_name, extension_handle.number))
594
 
 
595
 
    cls._extensions_by_name[extension_handle.full_name] = extension_handle
596
 
 
597
 
    handle = extension_handle  # avoid line wrapping
598
 
    if _IsMessageSetExtension(handle):
599
 
      # MessageSet extension.  Also register under type name.
600
 
      cls._extensions_by_name[
601
 
          extension_handle.message_type.full_name] = extension_handle
602
 
 
603
 
  cls.RegisterExtension = staticmethod(RegisterExtension)
604
 
 
605
 
  def FromString(s):
606
 
    message = cls()
607
 
    message.MergeFromString(s)
608
 
    return message
609
 
  cls.FromString = staticmethod(FromString)
610
 
 
611
 
 
612
 
def _IsPresent(item):
613
 
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
614
 
  value should be included in the list returned by ListFields()."""
615
 
 
616
 
  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
617
 
    return bool(item[1])
618
 
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
619
 
    return item[1]._is_present_in_parent
620
 
  else:
621
 
    return True
622
 
 
623
 
 
624
 
def _AddListFieldsMethod(message_descriptor, cls):
625
 
  """Helper for _AddMessageMethods()."""
626
 
 
627
 
  def ListFields(self):
628
 
    all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
629
 
    all_fields.sort(key = lambda item: item[0].number)
630
 
    return all_fields
631
 
 
632
 
  cls.ListFields = ListFields
633
 
 
634
 
 
635
 
def _AddHasFieldMethod(message_descriptor, cls):
636
 
  """Helper for _AddMessageMethods()."""
637
 
 
638
 
  singular_fields = {}
639
 
  for field in message_descriptor.fields:
640
 
    if field.label != _FieldDescriptor.LABEL_REPEATED:
641
 
      singular_fields[field.name] = field
642
 
 
643
 
  def HasField(self, field_name):
644
 
    try:
645
 
      field = singular_fields[field_name]
646
 
    except KeyError:
647
 
      raise ValueError(
648
 
          'Protocol message has no singular "%s" field.' % field_name)
649
 
 
650
 
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
651
 
      value = self._fields.get(field)
652
 
      return value is not None and value._is_present_in_parent
653
 
    else:
654
 
      return field in self._fields
655
 
  cls.HasField = HasField
656
 
 
657
 
 
658
 
def _AddClearFieldMethod(message_descriptor, cls):
659
 
  """Helper for _AddMessageMethods()."""
660
 
  def ClearField(self, field_name):
661
 
    try:
662
 
      field = message_descriptor.fields_by_name[field_name]
663
 
    except KeyError:
664
 
      raise ValueError('Protocol message has no "%s" field.' % field_name)
665
 
 
666
 
    if field in self._fields:
667
 
      # Note:  If the field is a sub-message, its listener will still point
668
 
      #   at us.  That's fine, because the worst than can happen is that it
669
 
      #   will call _Modified() and invalidate our byte size.  Big deal.
670
 
      del self._fields[field]
671
 
 
672
 
    # Always call _Modified() -- even if nothing was changed, this is
673
 
    # a mutating method, and thus calling it should cause the field to become
674
 
    # present in the parent message.
675
 
    self._Modified()
676
 
 
677
 
  cls.ClearField = ClearField
678
 
 
679
 
 
680
 
def _AddClearExtensionMethod(cls):
681
 
  """Helper for _AddMessageMethods()."""
682
 
  def ClearExtension(self, extension_handle):
683
 
    _VerifyExtensionHandle(self, extension_handle)
684
 
 
685
 
    # Similar to ClearField(), above.
686
 
    if extension_handle in self._fields:
687
 
      del self._fields[extension_handle]
688
 
    self._Modified()
689
 
  cls.ClearExtension = ClearExtension
690
 
 
691
 
 
692
 
def _AddClearMethod(message_descriptor, cls):
693
 
  """Helper for _AddMessageMethods()."""
694
 
  def Clear(self):
695
 
    # Clear fields.
696
 
    self._fields = {}
697
 
    self._Modified()
698
 
  cls.Clear = Clear
699
 
 
700
 
 
701
 
def _AddHasExtensionMethod(cls):
702
 
  """Helper for _AddMessageMethods()."""
703
 
  def HasExtension(self, extension_handle):
704
 
    _VerifyExtensionHandle(self, extension_handle)
705
 
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
706
 
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)
707
 
 
708
 
    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
709
 
      value = self._fields.get(extension_handle)
710
 
      return value is not None and value._is_present_in_parent
711
 
    else:
712
 
      return extension_handle in self._fields
713
 
  cls.HasExtension = HasExtension
714
 
 
715
 
 
716
 
def _AddEqualsMethod(message_descriptor, cls):
717
 
  """Helper for _AddMessageMethods()."""
718
 
  def __eq__(self, other):
719
 
    if (not isinstance(other, message_mod.Message) or
720
 
        other.DESCRIPTOR != self.DESCRIPTOR):
721
 
      return False
722
 
 
723
 
    if self is other:
724
 
      return True
725
 
 
726
 
    return self.ListFields() == other.ListFields()
727
 
 
728
 
  cls.__eq__ = __eq__
729
 
 
730
 
 
731
 
def _AddStrMethod(message_descriptor, cls):
732
 
  """Helper for _AddMessageMethods()."""
733
 
  def __str__(self):
734
 
    return text_format.MessageToString(self)
735
 
  cls.__str__ = __str__
736
 
 
737
 
 
738
 
def _AddSetListenerMethod(cls):
739
 
  """Helper for _AddMessageMethods()."""
740
 
  def SetListener(self, listener):
741
 
    if listener is None:
742
 
      self._listener = message_listener_mod.NullMessageListener()
743
 
    else:
744
 
      self._listener = listener
745
 
  cls._SetListener = SetListener
746
 
 
747
 
 
748
 
def _BytesForNonRepeatedElement(value, field_number, field_type):
749
 
  """Returns the number of bytes needed to serialize a non-repeated element.
750
 
  The returned byte count includes space for tag information and any
751
 
  other additional space associated with serializing value.
752
 
 
753
 
  Args:
754
 
    value: Value we're serializing.
755
 
    field_number: Field number of this value.  (Since the field number
756
 
      is stored as part of a varint-encoded tag, this has an impact
757
 
      on the total bytes required to serialize the value).
758
 
    field_type: The type of the field.  One of the TYPE_* constants
759
 
      within FieldDescriptor.
760
 
  """
761
 
  try:
762
 
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
763
 
    return fn(field_number, value)
764
 
  except KeyError:
765
 
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
766
 
 
767
 
 
768
 
def _AddByteSizeMethod(message_descriptor, cls):
769
 
  """Helper for _AddMessageMethods()."""
770
 
 
771
 
  def ByteSize(self):
772
 
    if not self._cached_byte_size_dirty:
773
 
      return self._cached_byte_size
774
 
 
775
 
    size = 0
776
 
    for field_descriptor, field_value in self.ListFields():
777
 
      size += field_descriptor._sizer(field_value)
778
 
 
779
 
    self._cached_byte_size = size
780
 
    self._cached_byte_size_dirty = False
781
 
    self._listener_for_children.dirty = False
782
 
    return size
783
 
 
784
 
  cls.ByteSize = ByteSize
785
 
 
786
 
 
787
 
def _AddSerializeToStringMethod(message_descriptor, cls):
788
 
  """Helper for _AddMessageMethods()."""
789
 
 
790
 
  def SerializeToString(self):
791
 
    # Check if the message has all of its required fields set.
792
 
    errors = []
793
 
    if not self.IsInitialized():
794
 
      raise message_mod.EncodeError(
795
 
          'Message is missing required fields: ' +
796
 
          ','.join(self.FindInitializationErrors()))
797
 
    return self.SerializePartialToString()
798
 
  cls.SerializeToString = SerializeToString
799
 
 
800
 
 
801
 
def _AddSerializePartialToStringMethod(message_descriptor, cls):
802
 
  """Helper for _AddMessageMethods()."""
803
 
 
804
 
  def SerializePartialToString(self):
805
 
    out = StringIO()
806
 
    self._InternalSerialize(out.write)
807
 
    return out.getvalue()
808
 
  cls.SerializePartialToString = SerializePartialToString
809
 
 
810
 
  def InternalSerialize(self, write_bytes):
811
 
    for field_descriptor, field_value in self.ListFields():
812
 
      field_descriptor._encoder(write_bytes, field_value)
813
 
  cls._InternalSerialize = InternalSerialize
814
 
 
815
 
 
816
 
def _AddMergeFromStringMethod(message_descriptor, cls):
817
 
  """Helper for _AddMessageMethods()."""
818
 
  def MergeFromString(self, serialized):
819
 
    length = len(serialized)
820
 
    try:
821
 
      if self._InternalParse(serialized, 0, length) != length:
822
 
        # The only reason _InternalParse would return early is if it
823
 
        # encountered an end-group tag.
824
 
        raise message_mod.DecodeError('Unexpected end-group tag.')
825
 
    except IndexError:
826
 
      raise message_mod.DecodeError('Truncated message.')
827
 
    except struct.error, e:
828
 
      raise message_mod.DecodeError(e)
829
 
    return length   # Return this for legacy reasons.
830
 
  cls.MergeFromString = MergeFromString
831
 
 
832
 
  local_ReadTag = decoder.ReadTag
833
 
  local_SkipField = decoder.SkipField
834
 
  decoders_by_tag = cls._decoders_by_tag
835
 
 
836
 
  def InternalParse(self, buffer, pos, end):
837
 
    self._Modified()
838
 
    field_dict = self._fields
839
 
    while pos != end:
840
 
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
841
 
      field_decoder = decoders_by_tag.get(tag_bytes)
842
 
      if field_decoder is None:
843
 
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
844
 
        if new_pos == -1:
845
 
          return pos
846
 
        pos = new_pos
847
 
      else:
848
 
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
849
 
    return pos
850
 
  cls._InternalParse = InternalParse
851
 
 
852
 
 
853
 
def _AddIsInitializedMethod(message_descriptor, cls):
854
 
  """Adds the IsInitialized and FindInitializationError methods to the
855
 
  protocol message class."""
856
 
 
857
 
  required_fields = [field for field in message_descriptor.fields
858
 
                           if field.label == _FieldDescriptor.LABEL_REQUIRED]
859
 
 
860
 
  def IsInitialized(self, errors=None):
861
 
    """Checks if all required fields of a message are set.
862
 
 
863
 
    Args:
864
 
      errors:  A list which, if provided, will be populated with the field
865
 
               paths of all missing required fields.
866
 
 
867
 
    Returns:
868
 
      True iff the specified message has all required fields set.
869
 
    """
870
 
 
871
 
    # Performance is critical so we avoid HasField() and ListFields().
872
 
 
873
 
    for field in required_fields:
874
 
      if (field not in self._fields or
875
 
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
876
 
           not self._fields[field]._is_present_in_parent)):
877
 
        if errors is not None:
878
 
          errors.extend(self.FindInitializationErrors())
879
 
        return False
880
 
 
881
 
    for field, value in self._fields.iteritems():
882
 
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
883
 
        if field.label == _FieldDescriptor.LABEL_REPEATED:
884
 
          for element in value:
885
 
            if not element.IsInitialized():
886
 
              if errors is not None:
887
 
                errors.extend(self.FindInitializationErrors())
888
 
              return False
889
 
        elif value._is_present_in_parent and not value.IsInitialized():
890
 
          if errors is not None:
891
 
            errors.extend(self.FindInitializationErrors())
892
 
          return False
893
 
 
894
 
    return True
895
 
 
896
 
  cls.IsInitialized = IsInitialized
897
 
 
898
 
  def FindInitializationErrors(self):
899
 
    """Finds required fields which are not initialized.
900
 
 
901
 
    Returns:
902
 
      A list of strings.  Each string is a path to an uninitialized field from
903
 
      the top-level message, e.g. "foo.bar[5].baz".
904
 
    """
905
 
 
906
 
    errors = []  # simplify things
907
 
 
908
 
    for field in required_fields:
909
 
      if not self.HasField(field.name):
910
 
        errors.append(field.name)
911
 
 
912
 
    for field, value in self.ListFields():
913
 
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
914
 
        if field.is_extension:
915
 
          name = "(%s)" % field.full_name
916
 
        else:
917
 
          name = field.name
918
 
 
919
 
        if field.label == _FieldDescriptor.LABEL_REPEATED:
920
 
          for i in xrange(len(value)):
921
 
            element = value[i]
922
 
            prefix = "%s[%d]." % (name, i)
923
 
            sub_errors = element.FindInitializationErrors()
924
 
            errors += [ prefix + error for error in sub_errors ]
925
 
        else:
926
 
          prefix = name + "."
927
 
          sub_errors = value.FindInitializationErrors()
928
 
          errors += [ prefix + error for error in sub_errors ]
929
 
 
930
 
    return errors
931
 
 
932
 
  cls.FindInitializationErrors = FindInitializationErrors
933
 
 
934
 
 
935
 
def _AddMergeFromMethod(cls):
936
 
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
937
 
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
938
 
 
939
 
  def MergeFrom(self, msg):
940
 
    assert msg is not self
941
 
    self._Modified()
942
 
 
943
 
    fields = self._fields
944
 
 
945
 
    for field, value in msg._fields.iteritems():
946
 
      if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE:
947
 
        field_value = fields.get(field)
948
 
        if field_value is None:
949
 
          # Construct a new object to represent this field.
950
 
          field_value = field._default_constructor(self)
951
 
          fields[field] = field_value
952
 
        field_value.MergeFrom(value)
953
 
      else:
954
 
        self._fields[field] = value
955
 
  cls.MergeFrom = MergeFrom
956
 
 
957
 
 
958
 
def _AddMessageMethods(message_descriptor, cls):
959
 
  """Adds implementations of all Message methods to cls."""
960
 
  _AddListFieldsMethod(message_descriptor, cls)
961
 
  _AddHasFieldMethod(message_descriptor, cls)
962
 
  _AddClearFieldMethod(message_descriptor, cls)
963
 
  if message_descriptor.is_extendable:
964
 
    _AddClearExtensionMethod(cls)
965
 
    _AddHasExtensionMethod(cls)
966
 
  _AddClearMethod(message_descriptor, cls)
967
 
  _AddEqualsMethod(message_descriptor, cls)
968
 
  _AddStrMethod(message_descriptor, cls)
969
 
  _AddSetListenerMethod(cls)
970
 
  _AddByteSizeMethod(message_descriptor, cls)
971
 
  _AddSerializeToStringMethod(message_descriptor, cls)
972
 
  _AddSerializePartialToStringMethod(message_descriptor, cls)
973
 
  _AddMergeFromStringMethod(message_descriptor, cls)
974
 
  _AddIsInitializedMethod(message_descriptor, cls)
975
 
  _AddMergeFromMethod(cls)
976
 
 
977
 
 
978
 
def _AddPrivateHelperMethods(cls):
979
 
  """Adds implementation of private helper methods to cls."""
980
 
 
981
 
  def Modified(self):
982
 
    """Sets the _cached_byte_size_dirty bit to true,
983
 
    and propagates this to our listener iff this was a state change.
984
 
    """
985
 
 
986
 
    # Note:  Some callers check _cached_byte_size_dirty before calling
987
 
    #   _Modified() as an extra optimization.  So, if this method is ever
988
 
    #   changed such that it does stuff even when _cached_byte_size_dirty is
989
 
    #   already true, the callers need to be updated.
990
 
    if not self._cached_byte_size_dirty:
991
 
      self._cached_byte_size_dirty = True
992
 
      self._listener_for_children.dirty = True
993
 
      self._is_present_in_parent = True
994
 
      self._listener.Modified()
995
 
 
996
 
  cls._Modified = Modified
997
 
  cls.SetInParent = Modified
998
 
 
999
 
 
1000
 
class _Listener(object):
1001
 
 
1002
 
  """MessageListener implementation that a parent message registers with its
1003
 
  child message.
1004
 
 
1005
 
  In order to support semantics like:
1006
 
 
1007
 
    foo.bar.baz.qux = 23
1008
 
    assert foo.HasField('bar')
1009
 
 
1010
 
  ...child objects must have back references to their parents.
1011
 
  This helper class is at the heart of this support.
1012
 
  """
1013
 
 
1014
 
  def __init__(self, parent_message):
1015
 
    """Args:
1016
 
      parent_message: The message whose _Modified() method we should call when
1017
 
        we receive Modified() messages.
1018
 
    """
1019
 
    # This listener establishes a back reference from a child (contained) object
1020
 
    # to its parent (containing) object.  We make this a weak reference to avoid
1021
 
    # creating cyclic garbage when the client finishes with the 'parent' object
1022
 
    # in the tree.
1023
 
    if isinstance(parent_message, weakref.ProxyType):
1024
 
      self._parent_message_weakref = parent_message
1025
 
    else:
1026
 
      self._parent_message_weakref = weakref.proxy(parent_message)
1027
 
 
1028
 
    # As an optimization, we also indicate directly on the listener whether
1029
 
    # or not the parent message is dirty.  This way we can avoid traversing
1030
 
    # up the tree in the common case.
1031
 
    self.dirty = False
1032
 
 
1033
 
  def Modified(self):
1034
 
    if self.dirty:
1035
 
      return
1036
 
    try:
1037
 
      # Propagate the signal to our parents iff this is the first field set.
1038
 
      self._parent_message_weakref._Modified()
1039
 
    except ReferenceError:
1040
 
      # We can get here if a client has kept a reference to a child object,
1041
 
      # and is now setting a field on it, but the child's parent has been
1042
 
      # garbage-collected.  This is not an error.
1043
 
      pass
1044
 
 
1045
 
 
1046
 
# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1047
 
# TODO(robinson): Unify error handling of "unknown extension" crap.
1048
 
# TODO(robinson): Support iteritems()-style iteration over all
1049
 
# extensions with the "has" bits turned on?
1050
 
class _ExtensionDict(object):
1051
 
 
1052
 
  """Dict-like container for supporting an indexable "Extensions"
1053
 
  field on proto instances.
1054
 
 
1055
 
  Note that in all cases we expect extension handles to be
1056
 
  FieldDescriptors.
1057
 
  """
1058
 
 
1059
 
  def __init__(self, extended_message):
1060
 
    """extended_message: Message instance for which we are the Extensions dict.
1061
 
    """
1062
 
 
1063
 
    self._extended_message = extended_message
1064
 
 
1065
 
  def __getitem__(self, extension_handle):
1066
 
    """Returns the current value of the given extension handle."""
1067
 
 
1068
 
    _VerifyExtensionHandle(self._extended_message, extension_handle)
1069
 
 
1070
 
    result = self._extended_message._fields.get(extension_handle)
1071
 
    if result is not None:
1072
 
      return result
1073
 
 
1074
 
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1075
 
      result = extension_handle._default_constructor(self._extended_message)
1076
 
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1077
 
      result = extension_handle.message_type._concrete_class()
1078
 
      try:
1079
 
        result._SetListener(self._extended_message._listener_for_children)
1080
 
      except ReferenceError:
1081
 
        pass
1082
 
    else:
1083
 
      # Singular scalar -- just return the default without inserting into the
1084
 
      # dict.
1085
 
      return extension_handle.default_value
1086
 
 
1087
 
    # Atomically check if another thread has preempted us and, if not, swap
1088
 
    # in the new object we just created.  If someone has preempted us, we
1089
 
    # take that object and discard ours.
1090
 
    # WARNING:  We are relying on setdefault() being atomic.  This is true
1091
 
    #   in CPython but we haven't investigated others.  This warning appears
1092
 
    #   in several other locations in this file.
1093
 
    result = self._extended_message._fields.setdefault(
1094
 
        extension_handle, result)
1095
 
 
1096
 
    return result
1097
 
 
1098
 
  def __eq__(self, other):
1099
 
    if not isinstance(other, self.__class__):
1100
 
      return False
1101
 
 
1102
 
    my_fields = self._extended_message.ListFields()
1103
 
    other_fields = other._extended_message.ListFields()
1104
 
 
1105
 
    # Get rid of non-extension fields.
1106
 
    my_fields    = [ field for field in my_fields    if field.is_extension ]
1107
 
    other_fields = [ field for field in other_fields if field.is_extension ]
1108
 
 
1109
 
    return my_fields == other_fields
1110
 
 
1111
 
  def __ne__(self, other):
1112
 
    return not self == other
1113
 
 
1114
 
  # Note that this is only meaningful for non-repeated, scalar extension
1115
 
  # fields.  Note also that we may have to call _Modified() when we do
1116
 
  # successfully set a field this way, to set any necssary "has" bits in the
1117
 
  # ancestors of the extended message.
1118
 
  def __setitem__(self, extension_handle, value):
1119
 
    """If extension_handle specifies a non-repeated, scalar extension
1120
 
    field, sets the value of that field.
1121
 
    """
1122
 
 
1123
 
    _VerifyExtensionHandle(self._extended_message, extension_handle)
1124
 
 
1125
 
    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1126
 
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1127
 
      raise TypeError(
1128
 
          'Cannot assign to extension "%s" because it is a repeated or '
1129
 
          'composite type.' % extension_handle.full_name)
1130
 
 
1131
 
    # It's slightly wasteful to lookup the type checker each time,
1132
 
    # but we expect this to be a vanishingly uncommon case anyway.
1133
 
    type_checker = type_checkers.GetTypeChecker(
1134
 
        extension_handle.cpp_type, extension_handle.type)
1135
 
    type_checker.CheckValue(value)
1136
 
    self._extended_message._fields[extension_handle] = value
1137
 
    self._extended_message._Modified()
1138
 
 
1139
 
  def _FindExtensionByName(self, name):
1140
 
    """Tries to find a known extension with the specified name.
1141
 
 
1142
 
    Args:
1143
 
      name: Extension full name.
1144
 
 
1145
 
    Returns:
1146
 
      Extension field descriptor.
1147
 
    """
1148
 
    return self._extended_message._extensions_by_name.get(name, None)