~fallenpegasus/libdrizzle/ping

« back to all changes in this revision

Viewing changes to prototest/prototest/mysql.py

  • Committer: Eric Day
  • Date: 2010-03-17 00:20:06 UTC
  • Revision ID: eday@oddments.org-20100317002006-8vtwqrn9lvpkmq40
Updated prototest.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
#!/usr/bin/env python
2
 
#
3
 
# Drizzle Client & Protocol Library
4
 
5
 
# Copyright (C) 2008 Eric Day (eday@oddments.org)
6
 
# All rights reserved.
7
 
8
 
# Use and distribution licensed under the BSD license.  See
9
 
# the COPYING file in this directory for full text.
10
 
#
11
 
'''
12
 
MySQL Protocol Objects
13
 
 
14
 
Objects in this module can be initialized by passing in either a raw
15
 
buffer to parse or keywords. This allows you to use these objects
16
 
for either sending or receiving.
17
 
'''
18
 
 
19
 
import struct
20
 
import unittest
21
 
import bitfield
22
 
 
23
 
class Capabilities(bitfield.BitField):
24
 
  _fields = [
25
 
    'LONG_PASSWORD',
26
 
    'FOUND_ROWS',
27
 
    'LONG_FLAG',
28
 
    'CONNECT_WITH_DB',
29
 
    'NO_SCHEMA',
30
 
    'COMPRESS',
31
 
    'ODBC',
32
 
    'LOCAL_FILES',
33
 
    'IGNORE_SPACE',
34
 
    'PROTOCOL_41',
35
 
    'INTERACTIVE',
36
 
    'SSL',
37
 
    'IGNORE_SIGPIPE',
38
 
    'TRANSACTIONS',
39
 
    'RESERVED',
40
 
    'SECURE_CONNECTION',
41
 
    'MULTI_STATEMENTS',
42
 
    'MULTI_RESULTS',
43
 
    None,
44
 
    None,
45
 
    None,
46
 
    None,
47
 
    None,
48
 
    None,
49
 
    None,
50
 
    None,
51
 
    None,
52
 
    None,
53
 
    None,
54
 
    None,
55
 
    'SSL_VERIFY_SERVER_CERT',
56
 
    'REMEMBER_OPTIONS'
57
 
  ]
58
 
 
59
 
class Status(bitfield.BitField):
60
 
  _fields = [
61
 
    'IN_TRANS',
62
 
    'AUTOCOMMIT',
63
 
    'MORE_RESULTS_EXISTS',
64
 
    'QUERY_NO_GOOD_INDEX_USED',
65
 
    'QUERY_NO_INDEX_USED',
66
 
    'CURSOR_EXISTS',
67
 
    'LAST_ROW_SENT',
68
 
    'DB_DROPPED',
69
 
    'NO_BACKSLASH_ESCAPES',
70
 
    'QUERY_WAS_SLOW'
71
 
  ]
72
 
 
73
 
class PacketException(Exception):
74
 
  pass
75
 
 
76
 
class Packet:
77
 
  '''This class represents a packet header.'''
78
 
 
79
 
  def __init__(self, packed=None, size=0, sequence=0):
80
 
    if packed is None:
81
 
      self.size = size
82
 
      self.sequence = sequence
83
 
    else:
84
 
      data = struct.unpack('4B', packed)
85
 
      self.size = data[0] | (data[1] << 8) | (data[2] << 16)
86
 
      self.sequence = data[3]
87
 
 
88
 
    self.verify()
89
 
 
90
 
  def pack(self):
91
 
    self.verify()
92
 
    return struct.pack('4B',
93
 
                       self.size & 0xFF,
94
 
                       (self.size >> 8) & 0xFF,
95
 
                       (self.size >> 16) & 0xFF,
96
 
                       self.sequence % 256)
97
 
 
98
 
  def verify(self):
99
 
    if self.size >= 16777216:
100
 
      raise PacketException('Packet size cannot exceed 16777215 bytes (%d)' %
101
 
                            self.size)
102
 
 
103
 
  def __str__(self):
104
 
    return '''Packet
105
 
  size = %s
106
 
  sequence = %s
107
 
''' % (self.size, self.sequence)
108
 
 
109
 
class TestPacket(unittest.TestCase):
110
 
 
111
 
  def testDefaultInit(self):
112
 
    packet = Packet()
113
 
    self.assertEqual(packet.size, 0)
114
 
    self.assertEqual(packet.sequence, 0)
115
 
 
116
 
  def testKeywordInit(self):
117
 
    packet = Packet(size=1234, sequence=5)
118
 
    self.assertEqual(packet.size, 1234)
119
 
    self.assertEqual(packet.sequence, 5)
120
 
 
121
 
  def testUnpackInit(self):
122
 
    packet = Packet(struct.pack('4B', 210, 4, 0, 5))
123
 
    self.assertEqual(packet.size, 1234)
124
 
    self.assertEqual(packet.sequence, 5)
125
 
 
126
 
  def testPack(self):
127
 
    packet = Packet(Packet(size=1234, sequence=5).pack())
128
 
    self.assertEqual(packet.size, 1234)
129
 
 
130
 
  def testPackRange(self):
131
 
    for x in range(0, 300):
132
 
      packet = Packet(Packet(size=x, sequence=x).pack())
133
 
      self.assertEqual(packet.size, x)
134
 
      self.assertEqual(packet.sequence, x % 256)
135
 
 
136
 
    # 997 is a random prime number so we hit various increments
137
 
    for x in range(300, 16777216, 997):
138
 
      packet = Packet(Packet(size=x, sequence=x).pack())
139
 
      self.assertEqual(packet.size, x)
140
 
      self.assertEqual(packet.sequence, x % 256)
141
 
 
142
 
    packet = Packet(Packet(size=16777215).pack())
143
 
    self.assertEqual(packet.size, 16777215)
144
 
    self.assertEqual(packet.sequence, 0)
145
 
 
146
 
    self.assertRaises(PacketException, Packet, size=16777216)
147
 
    self.assertRaises(PacketException, Packet, size=16777217)
148
 
    self.assertRaises(PacketException, Packet, size=4294967295)
149
 
    self.assertRaises(PacketException, Packet, size=4294967296)
150
 
    self.assertRaises(PacketException, Packet, size=4294967297)
151
 
 
152
 
class ServerHandshake:
153
 
  '''This class represents the initial handshake sent from server to client.'''
154
 
 
155
 
  def __init__(self, packed=None, protocol_version=10, server_version='',
156
 
               thread_id=0, scramble=tuple([0] * 20), null1=0, capabilities=0,
157
 
               charset=0, status=0, unused=tuple([0] * 13), null2=0):
158
 
    if packed is None:
159
 
      self.protocol_version = protocol_version
160
 
      self.server_version = server_version
161
 
      self.thread_id = thread_id
162
 
      self.scramble = scramble
163
 
      self.null1 = null1
164
 
      self.capabilities = Capabilities(capabilities)
165
 
      self.charset = charset
166
 
      self.status = Status(status)
167
 
      self.unused = unused
168
 
      self.null2 = null2
169
 
    else:
170
 
      self.protocol_version = struct.unpack('B', packed[:1])[0]
171
 
      server_version_length = packed[1:].index('\x00')
172
 
      self.server_version = packed[1:1+server_version_length]
173
 
      data = struct.unpack('<I8BB2BB2B13B12BB', packed[2+server_version_length:])
174
 
      self.thread_id = data[0]
175
 
      self.scramble = data[1:9] + data[28:40]
176
 
      self.null1 = data[9]
177
 
      self.capabilities = Capabilities(data[10] | (data[11] << 8))
178
 
      self.charset = data[12]
179
 
      self.status = Status(data[13] | (data[14] << 8))
180
 
      self.unused = data[15:28]
181
 
      self.null2 = data[40]
182
 
 
183
 
  def pack(self):
184
 
    data = struct.pack('B', self.protocol_version)
185
 
    data += self.server_version + '\x00'
186
 
    data += struct.pack('<I', self.thread_id)
187
 
    data += ''.join(map(chr, self.scramble[:8]))
188
 
    data += struct.pack('B2BB2B',
189
 
                       self.null1,
190
 
                       self.capabilities.value() & 0xFF,
191
 
                       (self.capabilities.value() >> 8) & 0xFF,
192
 
                       self.charset,
193
 
                       self.status.value() & 0xFF,
194
 
                       (self.status.value() >> 8) & 0xFF)
195
 
    data += ''.join(map(chr, self.unused))
196
 
    data += ''.join(map(chr, self.scramble[8:]))
197
 
    data += struct.pack('B', self.null2)
198
 
    return data
199
 
 
200
 
  def __str__(self):
201
 
    return '''ServerHandshake
202
 
  protocol_version = %s
203
 
  server_version = %s
204
 
  thread_id = %s
205
 
  scramble = %s
206
 
  null1 = %s
207
 
  capabilities = %s
208
 
  charset = %s
209
 
  status = %s
210
 
  unused = %s
211
 
  null2 = %s
212
 
''' % (self.protocol_version, self.server_version, self.thread_id,
213
 
       self.scramble, self.null1, self.capabilities, self.charset,
214
 
       self.status, self.unused, self.null2)
215
 
 
216
 
class TestServerHandshake(unittest.TestCase):
217
 
 
218
 
  def testDefaultInit(self):
219
 
    handshake = ServerHandshake()
220
 
    self.verifyDefault(handshake)
221
 
 
222
 
  def testKeywordInit(self):
223
 
    handshake = ServerHandshake(protocol_version=11,
224
 
                                server_version='test',
225
 
                                thread_id=1234,
226
 
                                scramble=tuple([5] * 20),
227
 
                                null1=1,
228
 
                                capabilities=65279,
229
 
                                charset=253,
230
 
                                status=64508,
231
 
                                unused=tuple([6] * 13),
232
 
                                null2=2)
233
 
    self.verifyCustom(handshake)
234
 
 
235
 
  def testUnpackInit(self):
236
 
    data = struct.pack('B', 11)
237
 
    data += 'test\x00'
238
 
    data += struct.pack('<I', 1234)
239
 
    data += ''.join([chr(5)] * 8)
240
 
    data += struct.pack('B2BB2B', 1, 255, 254, 253, 252, 251)
241
 
    data += ''.join([chr(6)] * 13)
242
 
    data += ''.join([chr(5)] * 12)
243
 
    data += struct.pack('B', 2)
244
 
 
245
 
    handshake = ServerHandshake(data)
246
 
    self.verifyCustom(handshake)
247
 
 
248
 
  def testPack(self):
249
 
    handshake = ServerHandshake(ServerHandshake().pack())
250
 
    self.verifyDefault(handshake)
251
 
 
252
 
  def verifyDefault(self, handshake):
253
 
    self.assertEqual(handshake.protocol_version, 10)
254
 
    self.assertEqual(handshake.server_version, '')
255
 
    self.assertEqual(handshake.thread_id, 0)
256
 
    self.assertEqual(handshake.scramble, tuple([0] * 20))
257
 
    self.assertEqual(handshake.null1, 0)
258
 
    self.assertEqual(handshake.capabilities.value(), 0)
259
 
    self.assertEqual(handshake.charset, 0)
260
 
    self.assertEqual(handshake.status.value(), 0)
261
 
    self.assertEqual(handshake.unused, tuple([0] * 13))
262
 
    self.assertEqual(handshake.null2, 0)
263
 
 
264
 
  def verifyCustom(self, handshake):
265
 
    self.assertEqual(handshake.protocol_version, 11)
266
 
    self.assertEqual(handshake.server_version, 'test')
267
 
    self.assertEqual(handshake.thread_id, 1234)
268
 
    self.assertEqual(handshake.scramble, tuple([5] * 20))
269
 
    self.assertEqual(handshake.null1, 1)
270
 
    self.assertEqual(handshake.capabilities.value(), 65279)
271
 
    self.assertEqual(handshake.charset, 253)
272
 
    self.assertEqual(handshake.status.value(), 64508)
273
 
    self.assertEqual(handshake.unused, tuple([6] * 13))
274
 
    self.assertEqual(handshake.null2, 2)
275
 
 
276
 
class ClientHandshake:
277
 
  '''This class represents the client handshake sent back to the server.'''
278
 
 
279
 
  def __init__(self, packed=None, capabilities=0, max_packet_size=0, charset=0,
280
 
               unused=tuple([0] * 23), user='', scramble_size=0,
281
 
               scramble=None, db=''):
282
 
    if packed is None:
283
 
      self.capabilities = Capabilities(capabilities)
284
 
      self.max_packet_size = max_packet_size
285
 
      self.charset = charset
286
 
      self.unused = unused
287
 
      self.user = user
288
 
      self.scramble_size = scramble_size
289
 
      self.scramble = scramble
290
 
      self.db = db
291
 
    else:
292
 
      data = struct.unpack('<IIB23B', packed[:32])
293
 
      self.capabilities = Capabilities(data[0])
294
 
      self.max_packet_size = data[1]
295
 
      self.charset = data[2]
296
 
      self.unused = data[3:]
297
 
      packed = packed[32:]
298
 
      user_length = packed.index('\x00')
299
 
      self.user = packed[:user_length]
300
 
      packed = packed[1+user_length:]
301
 
      self.scramble_size = ord(packed[0])
302
 
      if self.scramble_size == 0:
303
 
        self.scramble = None
304
 
      else:
305
 
        self.scramble = tuple(map(ord, packed[1:21]))
306
 
      if packed[-1:] == '\x00':
307
 
        self.db = packed[21:-1]
308
 
      else:
309
 
        self.db = packed[21:]
310
 
 
311
 
  def pack(self):
312
 
    data = struct.pack('<IIB', 
313
 
                       self.capabilities.value(),
314
 
                       self.max_packet_size,
315
 
                       self.charset)
316
 
    data += ''.join(map(chr, self.unused))
317
 
    data += self.user + '\x00'
318
 
    data += chr(self.scramble_size)
319
 
    if self.scramble_size != 0:
320
 
      data += ''.join(map(chr, self.scramble))
321
 
    data += self.db + '\x00'
322
 
    return data
323
 
 
324
 
  def __str__(self):
325
 
    return '''ClientHandshake
326
 
  capabilities = %s
327
 
  max_packet_size = %s
328
 
  charset = %s
329
 
  unused = %s
330
 
  user = %s
331
 
  scramble_size = %s
332
 
  scramble = %s
333
 
  db = %s
334
 
''' % (self.capabilities, self.max_packet_size, self.charset, self.unused,
335
 
       self.user, self.scramble_size, self.scramble, self.db)
336
 
 
337
 
class TestClientHandshake(unittest.TestCase):
338
 
 
339
 
  def testDefaultInit(self):
340
 
    handshake = ClientHandshake()
341
 
    self.verifyDefault(handshake)
342
 
 
343
 
  def testKeywordInit(self):
344
 
    handshake = ClientHandshake(capabilities=65279,
345
 
                                max_packet_size=64508,
346
 
                                charset=253,
347
 
                                unused=tuple([6] * 23),
348
 
                                user='user',
349
 
                                scramble_size=20,
350
 
                                scramble=tuple([5] * 20),
351
 
                                db='db')
352
 
    self.verifyCustom(handshake)
353
 
 
354
 
  def testUnpackInit(self):
355
 
    data = struct.pack('<IIB', 65279, 64508, 253)
356
 
    data += ''.join([chr(6)] * 23)
357
 
    data += 'user\x00'
358
 
    data += chr(20)
359
 
    data += ''.join([chr(5)] * 20)
360
 
    data += 'db\x00'
361
 
 
362
 
    handshake = ClientHandshake(data)
363
 
    self.verifyCustom(handshake)
364
 
 
365
 
  def testPack(self):
366
 
    handshake = ClientHandshake(ClientHandshake().pack())
367
 
    self.verifyDefault(handshake)
368
 
 
369
 
  def verifyDefault(self, handshake):
370
 
    self.assertEqual(handshake.capabilities.value(), 0)
371
 
    self.assertEqual(handshake.max_packet_size, 0)
372
 
    self.assertEqual(handshake.charset, 0)
373
 
    self.assertEqual(handshake.unused, tuple([0] * 23))
374
 
    self.assertEqual(handshake.user, '')
375
 
    self.assertEqual(handshake.scramble_size, 0)
376
 
    self.assertEqual(handshake.scramble, None)
377
 
    self.assertEqual(handshake.db, '')
378
 
 
379
 
  def verifyCustom(self, handshake):
380
 
    self.assertEqual(handshake.capabilities.value(), 65279)
381
 
    self.assertEqual(handshake.max_packet_size, 64508)
382
 
    self.assertEqual(handshake.charset, 253)
383
 
    self.assertEqual(handshake.unused, tuple([6] * 23))
384
 
    self.assertEqual(handshake.user, 'user')
385
 
    self.assertEqual(handshake.scramble_size, 20)
386
 
    self.assertEqual(handshake.scramble, tuple([5] * 20))
387
 
    self.assertEqual(handshake.db, 'db')
388
 
 
389
 
class Result:
390
 
  '''This class represents a result packet sent from the server.'''
391
 
 
392
 
  def __init__(self, packed=None, field_count=0, affected_rows=0, insert_id=0,
393
 
               status=0, warning_count=0, message='', version_40=False):
394
 
    if packed is None:
395
 
      self.field_count = field_count
396
 
      self.affected_rows = affected_rows
397
 
      self.insert_id = insert_id
398
 
      self.status = status
399
 
      self.warning_count = warning_count
400
 
      self.message = message
401
 
      self.version_40 = version_40
402
 
    else:
403
 
      if version_40 is True:
404
 
        self.field_count = ord(packed[0])
405
 
        if self.field_count == 0:
406
 
          self.affected_rows = ord(packed[1])
407
 
          self.insert_id = ord(packed[2])
408
 
          if len(packed) == 3:
409
 
            self.status = 0
410
 
          else:
411
 
            data = struct.unpack('<H', packed[2:])
412
 
            self.status = data[0]
413
 
        elif self.field_count == 255:
414
 
          data = struct.unpack('<H', packed[1:3])
415
 
          self.error_code = data[0]
416
 
          self.message = packed[3:]
417
 
        else:
418
 
          self.affected_rows = ord(packed[1])
419
 
          self.insert_id = ord(packed[2])
420
 
          data = struct.unpack('<HH', packed[3:7])
421
 
          self.status = data[0]
422
 
          self.warning_count = data[1]
423
 
          self.message = packed[7:]
424
 
      else:
425
 
        self.field_count = ord(packed[0])
426
 
        if self.field_count == 255:
427
 
          data = struct.unpack('<H', packed[1:3])
428
 
          self.error_code = data[0]
429
 
          self.sqlstate_marker = packed[3]
430
 
          self.sqlstate = packed[4:9]
431
 
          self.message = packed[9:]
432
 
        else:
433
 
          self.affected_rows = ord(packed[1])
434
 
          self.insert_id = ord(packed[2])
435
 
          data = struct.unpack('<HH', packed[3:7])
436
 
          self.status = data[0]
437
 
          self.warning_count = data[1]
438
 
          self.message = packed[7:]
439
 
 
440
 
      self.version_40 = version_40
441
 
 
442
 
  def __str__(self):
443
 
    if self.version_40 is True:
444
 
      if self.field_count == 255:
445
 
        return '''Result
446
 
  field_count = %s
447
 
  error_code = %s
448
 
  message = %s
449
 
  version_40 = %s
450
 
''' % (self.field_count, self.error_code, self.message, self.version_40)
451
 
      else:
452
 
        return '''Result
453
 
  field_count = %s
454
 
  affected_rows = %s
455
 
  insert_id = %s
456
 
  status = %s
457
 
  version_40 = %s
458
 
''' % (self.field_count, self.affected_rows, self.insert_id, self.status,
459
 
       self.version_40)
460
 
    else:
461
 
      if self.field_count == 255:
462
 
        return '''Result
463
 
  field_count = %s
464
 
  error_code = %s
465
 
  sqlstate_marker = %s
466
 
  sqlstate = %s
467
 
  message = %s
468
 
  version_40 = %s
469
 
''' % (self.field_count, self.error_code, self.sqlstate_marker, sqlstate,
470
 
       self.message, self.version_40)
471
 
      else:
472
 
        return '''Result
473
 
  field_count = %s
474
 
  affected_rows = %s
475
 
  insert_id = %s
476
 
  status = %s
477
 
  warning_count = %s
478
 
  message = %s
479
 
  version_40 = %s
480
 
''' % (self.field_count, self.affected_rows, self.insert_id, self.status,
481
 
       self.warning_count, self.message, self.version_40)
482
 
 
483
 
if __name__ == '__main__':
484
 
  unittest.main()