3
# Drizzle Client & Protocol Library
5
# Copyright (C) 2008 Eric Day (eday@oddments.org)
8
# Use and distribution licensed under the BSD license. See
9
# the COPYING file in this directory for full text.
12
MySQL Protocol Objects
14
Objects in this module can be initialized by passing in either a raw
15
buffer to parse or keywords. This allows you to use these objects
16
for either sending or receiving.
23
class Capabilities(bitfield.BitField):
55
'SSL_VERIFY_SERVER_CERT',
59
class Status(bitfield.BitField):
63
'MORE_RESULTS_EXISTS',
64
'QUERY_NO_GOOD_INDEX_USED',
65
'QUERY_NO_INDEX_USED',
69
'NO_BACKSLASH_ESCAPES',
73
class PacketException(Exception):
77
'''This class represents a packet header.'''
79
def __init__(self, packed=None, size=0, sequence=0):
82
self.sequence = sequence
84
data = struct.unpack('4B', packed)
85
self.size = data[0] | (data[1] << 8) | (data[2] << 16)
86
self.sequence = data[3]
92
return struct.pack('4B',
94
(self.size >> 8) & 0xFF,
95
(self.size >> 16) & 0xFF,
99
if self.size >= 16777216:
100
raise PacketException('Packet size cannot exceed 16777215 bytes (%d)' %
107
''' % (self.size, self.sequence)
109
class TestPacket(unittest.TestCase):
111
def testDefaultInit(self):
113
self.assertEqual(packet.size, 0)
114
self.assertEqual(packet.sequence, 0)
116
def testKeywordInit(self):
117
packet = Packet(size=1234, sequence=5)
118
self.assertEqual(packet.size, 1234)
119
self.assertEqual(packet.sequence, 5)
121
def testUnpackInit(self):
122
packet = Packet(struct.pack('4B', 210, 4, 0, 5))
123
self.assertEqual(packet.size, 1234)
124
self.assertEqual(packet.sequence, 5)
127
packet = Packet(Packet(size=1234, sequence=5).pack())
128
self.assertEqual(packet.size, 1234)
130
def testPackRange(self):
131
for x in range(0, 300):
132
packet = Packet(Packet(size=x, sequence=x).pack())
133
self.assertEqual(packet.size, x)
134
self.assertEqual(packet.sequence, x % 256)
136
# 997 is a random prime number so we hit various increments
137
for x in range(300, 16777216, 997):
138
packet = Packet(Packet(size=x, sequence=x).pack())
139
self.assertEqual(packet.size, x)
140
self.assertEqual(packet.sequence, x % 256)
142
packet = Packet(Packet(size=16777215).pack())
143
self.assertEqual(packet.size, 16777215)
144
self.assertEqual(packet.sequence, 0)
146
self.assertRaises(PacketException, Packet, size=16777216)
147
self.assertRaises(PacketException, Packet, size=16777217)
148
self.assertRaises(PacketException, Packet, size=4294967295)
149
self.assertRaises(PacketException, Packet, size=4294967296)
150
self.assertRaises(PacketException, Packet, size=4294967297)
152
class ServerHandshake:
153
'''This class represents the initial handshake sent from server to client.'''
155
def __init__(self, packed=None, protocol_version=10, server_version='',
156
thread_id=0, scramble=tuple([0] * 20), null1=0, capabilities=0,
157
charset=0, status=0, unused=tuple([0] * 13), null2=0):
159
self.protocol_version = protocol_version
160
self.server_version = server_version
161
self.thread_id = thread_id
162
self.scramble = scramble
164
self.capabilities = Capabilities(capabilities)
165
self.charset = charset
166
self.status = Status(status)
170
self.protocol_version = struct.unpack('B', packed[:1])[0]
171
server_version_length = packed[1:].index('\x00')
172
self.server_version = packed[1:1+server_version_length]
173
data = struct.unpack('<I8BB2BB2B13B12BB', packed[2+server_version_length:])
174
self.thread_id = data[0]
175
self.scramble = data[1:9] + data[28:40]
177
self.capabilities = Capabilities(data[10] | (data[11] << 8))
178
self.charset = data[12]
179
self.status = Status(data[13] | (data[14] << 8))
180
self.unused = data[15:28]
181
self.null2 = data[40]
184
data = struct.pack('B', self.protocol_version)
185
data += self.server_version + '\x00'
186
data += struct.pack('<I', self.thread_id)
187
data += ''.join(map(chr, self.scramble[:8]))
188
data += struct.pack('B2BB2B',
190
self.capabilities.value() & 0xFF,
191
(self.capabilities.value() >> 8) & 0xFF,
193
self.status.value() & 0xFF,
194
(self.status.value() >> 8) & 0xFF)
195
data += ''.join(map(chr, self.unused))
196
data += ''.join(map(chr, self.scramble[8:]))
197
data += struct.pack('B', self.null2)
201
return '''ServerHandshake
202
protocol_version = %s
212
''' % (self.protocol_version, self.server_version, self.thread_id,
213
self.scramble, self.null1, self.capabilities, self.charset,
214
self.status, self.unused, self.null2)
216
class TestServerHandshake(unittest.TestCase):
218
def testDefaultInit(self):
219
handshake = ServerHandshake()
220
self.verifyDefault(handshake)
222
def testKeywordInit(self):
223
handshake = ServerHandshake(protocol_version=11,
224
server_version='test',
226
scramble=tuple([5] * 20),
231
unused=tuple([6] * 13),
233
self.verifyCustom(handshake)
235
def testUnpackInit(self):
236
data = struct.pack('B', 11)
238
data += struct.pack('<I', 1234)
239
data += ''.join([chr(5)] * 8)
240
data += struct.pack('B2BB2B', 1, 255, 254, 253, 252, 251)
241
data += ''.join([chr(6)] * 13)
242
data += ''.join([chr(5)] * 12)
243
data += struct.pack('B', 2)
245
handshake = ServerHandshake(data)
246
self.verifyCustom(handshake)
249
handshake = ServerHandshake(ServerHandshake().pack())
250
self.verifyDefault(handshake)
252
def verifyDefault(self, handshake):
253
self.assertEqual(handshake.protocol_version, 10)
254
self.assertEqual(handshake.server_version, '')
255
self.assertEqual(handshake.thread_id, 0)
256
self.assertEqual(handshake.scramble, tuple([0] * 20))
257
self.assertEqual(handshake.null1, 0)
258
self.assertEqual(handshake.capabilities.value(), 0)
259
self.assertEqual(handshake.charset, 0)
260
self.assertEqual(handshake.status.value(), 0)
261
self.assertEqual(handshake.unused, tuple([0] * 13))
262
self.assertEqual(handshake.null2, 0)
264
def verifyCustom(self, handshake):
265
self.assertEqual(handshake.protocol_version, 11)
266
self.assertEqual(handshake.server_version, 'test')
267
self.assertEqual(handshake.thread_id, 1234)
268
self.assertEqual(handshake.scramble, tuple([5] * 20))
269
self.assertEqual(handshake.null1, 1)
270
self.assertEqual(handshake.capabilities.value(), 65279)
271
self.assertEqual(handshake.charset, 253)
272
self.assertEqual(handshake.status.value(), 64508)
273
self.assertEqual(handshake.unused, tuple([6] * 13))
274
self.assertEqual(handshake.null2, 2)
276
class ClientHandshake:
277
'''This class represents the client handshake sent back to the server.'''
279
def __init__(self, packed=None, capabilities=0, max_packet_size=0, charset=0,
280
unused=tuple([0] * 23), user='', scramble_size=0,
281
scramble=None, db=''):
283
self.capabilities = Capabilities(capabilities)
284
self.max_packet_size = max_packet_size
285
self.charset = charset
288
self.scramble_size = scramble_size
289
self.scramble = scramble
292
data = struct.unpack('<IIB23B', packed[:32])
293
self.capabilities = Capabilities(data[0])
294
self.max_packet_size = data[1]
295
self.charset = data[2]
296
self.unused = data[3:]
298
user_length = packed.index('\x00')
299
self.user = packed[:user_length]
300
packed = packed[1+user_length:]
301
self.scramble_size = ord(packed[0])
302
if self.scramble_size == 0:
305
self.scramble = tuple(map(ord, packed[1:21]))
306
if packed[-1:] == '\x00':
307
self.db = packed[21:-1]
309
self.db = packed[21:]
312
data = struct.pack('<IIB',
313
self.capabilities.value(),
314
self.max_packet_size,
316
data += ''.join(map(chr, self.unused))
317
data += self.user + '\x00'
318
data += chr(self.scramble_size)
319
if self.scramble_size != 0:
320
data += ''.join(map(chr, self.scramble))
321
data += self.db + '\x00'
325
return '''ClientHandshake
334
''' % (self.capabilities, self.max_packet_size, self.charset, self.unused,
335
self.user, self.scramble_size, self.scramble, self.db)
337
class TestClientHandshake(unittest.TestCase):
339
def testDefaultInit(self):
340
handshake = ClientHandshake()
341
self.verifyDefault(handshake)
343
def testKeywordInit(self):
344
handshake = ClientHandshake(capabilities=65279,
345
max_packet_size=64508,
347
unused=tuple([6] * 23),
350
scramble=tuple([5] * 20),
352
self.verifyCustom(handshake)
354
def testUnpackInit(self):
355
data = struct.pack('<IIB', 65279, 64508, 253)
356
data += ''.join([chr(6)] * 23)
359
data += ''.join([chr(5)] * 20)
362
handshake = ClientHandshake(data)
363
self.verifyCustom(handshake)
366
handshake = ClientHandshake(ClientHandshake().pack())
367
self.verifyDefault(handshake)
369
def verifyDefault(self, handshake):
370
self.assertEqual(handshake.capabilities.value(), 0)
371
self.assertEqual(handshake.max_packet_size, 0)
372
self.assertEqual(handshake.charset, 0)
373
self.assertEqual(handshake.unused, tuple([0] * 23))
374
self.assertEqual(handshake.user, '')
375
self.assertEqual(handshake.scramble_size, 0)
376
self.assertEqual(handshake.scramble, None)
377
self.assertEqual(handshake.db, '')
379
def verifyCustom(self, handshake):
380
self.assertEqual(handshake.capabilities.value(), 65279)
381
self.assertEqual(handshake.max_packet_size, 64508)
382
self.assertEqual(handshake.charset, 253)
383
self.assertEqual(handshake.unused, tuple([6] * 23))
384
self.assertEqual(handshake.user, 'user')
385
self.assertEqual(handshake.scramble_size, 20)
386
self.assertEqual(handshake.scramble, tuple([5] * 20))
387
self.assertEqual(handshake.db, 'db')
390
'''This class represents a result packet sent from the server.'''
392
def __init__(self, packed=None, field_count=0, affected_rows=0, insert_id=0,
393
status=0, warning_count=0, message='', version_40=False):
395
self.field_count = field_count
396
self.affected_rows = affected_rows
397
self.insert_id = insert_id
399
self.warning_count = warning_count
400
self.message = message
401
self.version_40 = version_40
403
if version_40 is True:
404
self.field_count = ord(packed[0])
405
if self.field_count == 0:
406
self.affected_rows = ord(packed[1])
407
self.insert_id = ord(packed[2])
411
data = struct.unpack('<H', packed[2:])
412
self.status = data[0]
413
elif self.field_count == 255:
414
data = struct.unpack('<H', packed[1:3])
415
self.error_code = data[0]
416
self.message = packed[3:]
418
self.affected_rows = ord(packed[1])
419
self.insert_id = ord(packed[2])
420
data = struct.unpack('<HH', packed[3:7])
421
self.status = data[0]
422
self.warning_count = data[1]
423
self.message = packed[7:]
425
self.field_count = ord(packed[0])
426
if self.field_count == 255:
427
data = struct.unpack('<H', packed[1:3])
428
self.error_code = data[0]
429
self.sqlstate_marker = packed[3]
430
self.sqlstate = packed[4:9]
431
self.message = packed[9:]
433
self.affected_rows = ord(packed[1])
434
self.insert_id = ord(packed[2])
435
data = struct.unpack('<HH', packed[3:7])
436
self.status = data[0]
437
self.warning_count = data[1]
438
self.message = packed[7:]
440
self.version_40 = version_40
443
if self.version_40 is True:
444
if self.field_count == 255:
450
''' % (self.field_count, self.error_code, self.message, self.version_40)
458
''' % (self.field_count, self.affected_rows, self.insert_id, self.status,
461
if self.field_count == 255:
469
''' % (self.field_count, self.error_code, self.sqlstate_marker, sqlstate,
470
self.message, self.version_40)
480
''' % (self.field_count, self.affected_rows, self.insert_id, self.status,
481
self.warning_count, self.message, self.version_40)
483
if __name__ == '__main__':