566
568
proto = unittest_pb2.TestAllTypes()
567
569
self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
571
def testClearRemovesChildren(self):
572
# Make sure there aren't any implementation bugs that are only partially
573
# clearing the message (which can happen in the more complex C++
574
# implementation which has parallel message lists).
575
proto = unittest_pb2.TestRequiredForeign()
577
proto.repeated_message.add()
578
proto2 = unittest_pb2.TestRequiredForeign()
579
proto.CopyFrom(proto2)
580
self.assertRaises(IndexError, lambda: proto.repeated_message[5])
569
582
def testDisallowedAssignments(self):
570
583
# It's illegal to assign values directly to repeated fields
571
584
# or to nonrepeated composite fields. Ensure that this fails.
594
607
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
595
608
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
610
def testIntegerTypes(self):
611
def TestGetAndDeserialize(field_name, value, expected_type):
612
proto = unittest_pb2.TestAllTypes()
613
setattr(proto, field_name, value)
614
self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
615
proto2 = unittest_pb2.TestAllTypes()
616
proto2.ParseFromString(proto.SerializeToString())
617
self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))
619
TestGetAndDeserialize('optional_int32', 1, int)
620
TestGetAndDeserialize('optional_int32', 1 << 30, int)
621
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
622
if struct.calcsize('L') == 4:
623
# Python only has signed ints, so 32-bit python can't fit an uint32
625
TestGetAndDeserialize('optional_uint32', 1 << 31, long)
627
# 64-bit python can fit uint32 inside an int
628
TestGetAndDeserialize('optional_uint32', 1 << 31, int)
629
TestGetAndDeserialize('optional_int64', 1 << 30, long)
630
TestGetAndDeserialize('optional_int64', 1 << 60, long)
631
TestGetAndDeserialize('optional_uint64', 1 << 30, long)
632
TestGetAndDeserialize('optional_uint64', 1 << 60, long)
597
634
def testSingleScalarBoundsChecking(self):
598
635
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
599
636
pb = unittest_pb2.TestAllTypes()
613
650
pb.optional_nested_enum = 1
614
651
self.assertEqual(1, pb.optional_nested_enum)
616
# Invalid enum values.
617
pb.optional_nested_enum = 0
618
self.assertEqual(0, pb.optional_nested_enum)
620
bytes_size_before = pb.ByteSize()
622
pb.optional_nested_enum = 4
623
self.assertEqual(4, pb.optional_nested_enum)
625
pb.optional_nested_enum = 0
626
self.assertEqual(0, pb.optional_nested_enum)
628
# Make sure that setting the same enum field doesn't just add unknown
629
# fields (but overwrites them).
630
self.assertEqual(bytes_size_before, pb.ByteSize())
632
# Is the invalid value preserved after serialization?
633
serialized = pb.SerializeToString()
634
pb2 = unittest_pb2.TestAllTypes()
635
pb2.ParseFromString(serialized)
636
self.assertEqual(0, pb2.optional_nested_enum)
637
self.assertEqual(pb, pb2)
639
653
def testRepeatedScalarTypeSafety(self):
640
654
proto = unittest_pb2.TestAllTypes()
641
655
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
749
763
unittest_pb2.ForeignEnum.items())
751
765
proto = unittest_pb2.TestAllTypes()
752
self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys())
753
self.assertEqual([1, 2, 3], proto.NestedEnum.values())
754
self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)],
766
self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys())
767
self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values())
768
self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
755
769
proto.NestedEnum.items())
757
771
def testRepeatedScalars(self):
1155
1169
self.assertTrue(required is not extendee_proto.Extensions[extension])
1156
1170
self.assertTrue(not extendee_proto.HasExtension(extension))
1172
def testRegisteredExtensions(self):
1173
self.assertTrue('protobuf_unittest.optional_int32_extension' in
1174
unittest_pb2.TestAllExtensions._extensions_by_name)
1175
self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
1176
# Make sure extensions haven't been registered into types that shouldn't
1178
self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
1158
1180
# If message A directly contains message B, and
1159
1181
# a.HasField('b') is currently False, then mutating any
1160
1182
# extension in B should change a.HasField('b') to True
1451
1473
proto2 = unittest_pb2.TestAllExtensions()
1452
1474
self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1476
def testDeepCopy(self):
1477
proto1 = unittest_pb2.TestAllTypes()
1478
proto1.optional_int32 = 1
1479
proto2 = copy.deepcopy(proto1)
1480
self.assertEqual(1, proto2.optional_int32)
1482
proto1.repeated_int32.append(2)
1483
proto1.repeated_int32.append(3)
1484
container = copy.deepcopy(proto1.repeated_int32)
1485
self.assertEqual([2, 3], container)
1487
# TODO(anuraag): Implement deepcopy for repeated composite / extension dict
1454
1489
def testClear(self):
1455
1490
proto = unittest_pb2.TestAllTypes()
1456
1491
# C++ implementation does not support lazy fields right now so leave it
1496
1531
self.assertEqual(6, foreign.c)
1499
self.assertTrue(not proto.HasField('optional_nested_message'))
1534
self.assertFalse(proto.HasField('optional_nested_message'))
1500
1535
self.assertEqual(0, proto.optional_nested_message.bb)
1501
self.assertTrue(not proto.HasField('optional_foreign_message'))
1536
self.assertFalse(proto.HasField('optional_foreign_message'))
1502
1537
self.assertEqual(0, proto.optional_foreign_message.c)
1539
def testOneOf(self):
1540
proto = unittest_pb2.TestAllTypes()
1541
proto.oneof_uint32 = 10
1542
proto.oneof_nested_message.bb = 11
1543
self.assertEqual(11, proto.oneof_nested_message.bb)
1544
self.assertFalse(proto.HasField('oneof_uint32'))
1545
nested = proto.oneof_nested_message
1546
proto.oneof_string = 'abc'
1547
self.assertEqual('abc', proto.oneof_string)
1548
self.assertEqual(11, nested.bb)
1549
self.assertFalse(proto.HasField('oneof_nested_message'))
1504
1551
def assertInitialized(self, proto):
1505
1552
self.assertTrue(proto.IsInitialized())
1506
1553
# Neither method should raise an exception.
1571
1618
self.assertFalse(proto.IsInitialized(errors))
1572
1619
self.assertEqual(errors, ['a', 'b', 'c'])
1621
@basetest.unittest.skipIf(
1622
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
1623
'Errors are only available from the most recent C++ implementation.')
1624
def testFileDescriptorErrors(self):
1625
file_name = 'test_file_descriptor_errors.proto'
1626
package_name = 'test_file_descriptor_errors.proto'
1627
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1628
file_descriptor_proto.name = file_name
1629
file_descriptor_proto.package = package_name
1630
m1 = file_descriptor_proto.message_type.add()
1632
# Compiles the proto into the C++ descriptor pool
1633
descriptor.FileDescriptor(
1636
serialized_pb=file_descriptor_proto.SerializeToString())
1637
# Add a FileDescriptorProto that has duplicate symbols
1638
another_file_name = 'another_test_file_descriptor_errors.proto'
1639
file_descriptor_proto.name = another_file_name
1640
m2 = file_descriptor_proto.message_type.add()
1642
with self.assertRaises(TypeError) as cm:
1643
descriptor.FileDescriptor(
1646
serialized_pb=file_descriptor_proto.SerializeToString())
1647
self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
1648
getattr(cm.expected, '__name__', cm.expected))
1649
self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
1650
# Error message will say something about this definition being a
1651
# duplicate, though we don't check the message exactly to avoid a
1652
# dependency on the C++ logging code.
1653
self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
1574
1655
def testStringUTF8Encoding(self):
1575
1656
proto = unittest_pb2.TestAllTypes()
1588
1669
proto.optional_string = str('Testing')
1589
1670
self.assertEqual(proto.optional_string, unicode('Testing'))
1591
if api_implementation.Type() == 'python':
1592
# Values of type 'str' are also accepted as long as they can be
1594
self.assertEqual(type(proto.optional_string), str)
1596
1672
# Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
1597
1673
self.assertRaises(ValueError,
1598
setattr, proto, 'optional_string', str('a\x80a'))
1599
# Assign a 'str' object which contains a UTF-8 encoded string.
1600
self.assertRaises(ValueError,
1601
setattr, proto, 'optional_string', 'Тест')
1674
setattr, proto, 'optional_string', b'a\x80a')
1675
if str is bytes: # PY2
1676
# Assign a 'str' object which contains a UTF-8 encoded string.
1677
self.assertRaises(ValueError,
1678
setattr, proto, 'optional_string', 'Тест')
1680
proto.optional_string = 'Тест'
1602
1681
# No exception thrown.
1603
1682
proto.optional_string = 'abc'
1643
1724
# MergeFromString and thus has no way to throw the exception.
1645
1726
# The pure Python API always returns objects of type 'unicode' (UTF-8
1646
# encoded), or 'str' (in 7 bit ASCII).
1647
bytes = raw.item[0].message.replace(
1648
test_utf8_bytes, len(test_utf8_bytes) * '\xff')
1727
# encoded), or 'bytes' (in 7 bit ASCII).
1728
badbytes = raw.item[0].message.replace(
1729
test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
1650
1731
unicode_decode_failed = False
1652
message2.MergeFromString(bytes)
1653
except UnicodeDecodeError as e:
1733
message2.MergeFromString(badbytes)
1734
except UnicodeDecodeError:
1654
1735
unicode_decode_failed = True
1655
1736
string_field = message2.str
1656
self.assertTrue(unicode_decode_failed or type(string_field) == str)
1737
self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
1739
def testBytesInTextFormat(self):
1740
proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
1741
self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
1658
1744
def testEmptyNestedMessage(self):
1659
1745
proto = unittest_pb2.TestAllTypes()
1667
1753
self.assertTrue(proto.HasField('optional_nested_message'))
1669
1755
proto = unittest_pb2.TestAllTypes()
1670
proto.optional_nested_message.MergeFromString('')
1756
bytes_read = proto.optional_nested_message.MergeFromString(b'')
1757
self.assertEqual(0, bytes_read)
1671
1758
self.assertTrue(proto.HasField('optional_nested_message'))
1673
1760
proto = unittest_pb2.TestAllTypes()
1674
proto.optional_nested_message.ParseFromString('')
1761
proto.optional_nested_message.ParseFromString(b'')
1675
1762
self.assertTrue(proto.HasField('optional_nested_message'))
1677
1764
serialized = proto.SerializeToString()
1678
1765
proto2 = unittest_pb2.TestAllTypes()
1679
proto2.MergeFromString(serialized)
1768
proto2.MergeFromString(serialized))
1680
1769
self.assertTrue(proto2.HasField('optional_nested_message'))
1682
1771
def testSetInParent(self):
2133
2222
# * Handling of empty submessages (with and without "has"
2136
class SerializationTest(unittest.TestCase):
2225
class SerializationTest(basetest.TestCase):
2138
2227
def testSerializeEmtpyMessage(self):
2139
2228
first_proto = unittest_pb2.TestAllTypes()
2140
2229
second_proto = unittest_pb2.TestAllTypes()
2141
2230
serialized = first_proto.SerializeToString()
2142
2231
self.assertEqual(first_proto.ByteSize(), len(serialized))
2143
second_proto.MergeFromString(serialized)
2234
second_proto.MergeFromString(serialized))
2144
2235
self.assertEqual(first_proto, second_proto)
2146
2237
def testSerializeAllFields(self):
2157
2250
second_proto = unittest_pb2.TestAllExtensions()
2158
2251
test_util.SetAllExtensions(first_proto)
2159
2252
serialized = first_proto.SerializeToString()
2160
second_proto.MergeFromString(serialized)
2255
second_proto.MergeFromString(serialized))
2256
self.assertEqual(first_proto, second_proto)
2258
def testSerializeWithOptionalGroup(self):
2259
first_proto = unittest_pb2.TestAllTypes()
2260
second_proto = unittest_pb2.TestAllTypes()
2261
first_proto.optionalgroup.a = 242
2262
serialized = first_proto.SerializeToString()
2265
second_proto.MergeFromString(serialized))
2161
2266
self.assertEqual(first_proto, second_proto)
2163
2268
def testSerializeNegativeValues(self):
2274
2381
raw = unittest_mset_pb2.RawMessageSet()
2275
2382
self.assertEqual(False,
2276
2383
raw.DESCRIPTOR.GetOptions().message_set_wire_format)
2277
raw.MergeFromString(serialized)
2386
raw.MergeFromString(serialized))
2278
2387
self.assertEqual(2, len(raw.item))
2280
2389
message1 = unittest_mset_pb2.TestMessageSetExtension1()
2281
message1.MergeFromString(raw.item[0].message)
2391
len(raw.item[0].message),
2392
message1.MergeFromString(raw.item[0].message))
2282
2393
self.assertEqual(123, message1.i)
2284
2395
message2 = unittest_mset_pb2.TestMessageSetExtension2()
2285
message2.MergeFromString(raw.item[1].message)
2397
len(raw.item[1].message),
2398
message2.MergeFromString(raw.item[1].message))
2286
2399
self.assertEqual('foo', message2.str)
2288
2401
# Deserialize using the MessageSet wire format.
2289
2402
proto2 = unittest_mset_pb2.TestMessageSet()
2290
proto2.MergeFromString(serialized)
2405
proto2.MergeFromString(serialized))
2291
2406
self.assertEqual(123, proto2.Extensions[extension1].i)
2292
2407
self.assertEqual('foo', proto2.Extensions[extension2].str)
2406
2527
partial = proto.SerializePartialToString()
2408
2529
proto2 = unittest_pb2.TestRequired()
2409
proto2.MergeFromString(serialized)
2532
proto2.MergeFromString(serialized))
2410
2533
self.assertEqual(1, proto2.a)
2411
2534
self.assertEqual(2, proto2.b)
2412
2535
self.assertEqual(3, proto2.c)
2413
proto2.ParseFromString(partial)
2538
proto2.MergeFromString(partial))
2414
2539
self.assertEqual(1, proto2.a)
2415
2540
self.assertEqual(2, proto2.b)
2416
2541
self.assertEqual(3, proto2.c)
2803
class ClassAPITest(basetest.TestCase):
2805
def testMakeClassWithNestedDescriptor(self):
2806
leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
2807
containing_type=None, fields=[],
2808
nested_types=[], enum_types=[],
2810
child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
2811
containing_type=None, fields=[],
2812
nested_types=[leaf_desc], enum_types=[],
2814
sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
2815
'', containing_type=None, fields=[],
2816
nested_types=[], enum_types=[],
2818
parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
2819
containing_type=None, fields=[],
2820
nested_types=[child_desc, sibling_desc],
2821
enum_types=[], extensions=[])
2822
message_class = reflection.MakeClass(parent_desc)
2823
self.assertIn('child', message_class.__dict__)
2824
self.assertIn('sibling', message_class.__dict__)
2825
self.assertIn('leaf', message_class.child.__dict__)
2827
def _GetSerializedFileDescriptor(self, name):
2828
"""Get a serialized representation of a test FileDescriptorProto.
2831
name: All calls to this must use a unique message name, to avoid
2832
collisions in the cpp descriptor pool.
2834
A string containing the serialized form of a test FileDescriptorProto.
2836
file_descriptor_str = (
2838
' name: "' + name + '"'
2842
' label: LABEL_REPEATED'
2843
' type: TYPE_UINT32'
2848
' label: LABEL_OPTIONAL'
2849
' type: TYPE_MESSAGE'
2857
' label: LABEL_OPTIONAL'
2858
' type: TYPE_MESSAGE'
2864
' name: "deep_enum"'
2873
' label: LABEL_OPTIONAL'
2874
' type: TYPE_UINT32'
2879
file_descriptor = descriptor_pb2.FileDescriptorProto()
2880
text_format.Merge(file_descriptor_str, file_descriptor)
2881
return file_descriptor.SerializeToString()
2883
def testParsingFlatClassWithExplicitClassDeclaration(self):
2884
"""Test that the generated class can parse a flat message."""
2885
file_descriptor = descriptor_pb2.FileDescriptorProto()
2886
file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
2887
msg_descriptor = descriptor.MakeDescriptor(
2888
file_descriptor.message_type[0])
2890
class MessageClass(message.Message):
2891
__metaclass__ = reflection.GeneratedProtocolMessageType
2892
DESCRIPTOR = msg_descriptor
2893
msg = MessageClass()
2898
text_format.Merge(msg_str, msg)
2899
self.assertEqual(msg.flat, [0, 1, 2])
2901
def testParsingFlatClass(self):
2902
"""Test that the generated class can parse a flat message."""
2903
file_descriptor = descriptor_pb2.FileDescriptorProto()
2904
file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
2905
msg_descriptor = descriptor.MakeDescriptor(
2906
file_descriptor.message_type[0])
2907
msg_class = reflection.MakeClass(msg_descriptor)
2913
text_format.Merge(msg_str, msg)
2914
self.assertEqual(msg.flat, [0, 1, 2])
2916
def testParsingNestedClass(self):
2917
"""Test that the generated class can parse a nested message."""
2918
file_descriptor = descriptor_pb2.FileDescriptorProto()
2919
file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
2920
msg_descriptor = descriptor.MakeDescriptor(
2921
file_descriptor.message_type[0])
2922
msg_class = reflection.MakeClass(msg_descriptor)
2930
text_format.Merge(msg_str, msg)
2931
self.assertEqual(msg.bar.baz.deep, 4)
2670
2933
if __name__ == '__main__':