~fallenpegasus/libdrizzle/ping

« back to all changes in this revision

Viewing changes to prototest/mysql_test

  • Committer: Eric Day
  • Date: 2010-02-16 07:17:11 UTC
  • Revision ID: eday@oddments.org-20100216071711-lj5t8g40iq4fcp16
Added protocol testing tool. This includes just initial handshake testing for now.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
#
 
3
# Drizzle Client & Protocol Library
 
4
 
5
# Copyright (C) 2008 Eric Day (eday@oddments.org)
 
6
# All rights reserved.
 
7
 
8
# Use and distribution licensed under the BSD license.  See
 
9
# the COPYING file in this directory for full text.
 
10
#
 
11
'''
 
12
Drizzle and MySQL Protocol Test Suite
 
13
 
 
14
This tool requires for anonymous authentication to be open on the
 
15
MySQL server.
 
16
'''
 
17
 
 
18
import optparse
 
19
import socket
 
20
import unittest
 
21
from prototest import mysql
 
22
 
 
23
parser = optparse.OptionParser(add_help_option=False)
 
24
 
 
25
parser.add_option("--help", action="help", help="Print out help details")
 
26
parser.add_option("-h", "--host", dest="host", default="localhost",
 
27
                  help="Host or IP to test", metavar="HOST")
 
28
parser.add_option("-p", "--port", dest="port", default=3306,
 
29
                  help="Port to test", metavar="PORT")
 
30
 
 
31
(options, args) = parser.parse_args()
 
32
 
 
33
class TestHandshake(unittest.TestCase):
 
34
  def setUp(self):
 
35
    self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
36
    self.s.connect((options.host, int(options.port)))
 
37
 
 
38
  def tearDown(self):
 
39
    self.s.close()
 
40
 
 
41
  def testNullClientHandshake(self):
 
42
    self.verifyServerHandshake()
 
43
    self.s.send(mysql.Packet().pack())
 
44
    (packet, data) = self.verifyPacket(1)
 
45
    result = mysql.Result(data, version_40=True)
 
46
    # Got packets out of order
 
47
    self.assertEqual(result.field_count, 255)
 
48
    #self.assertEqual(result.error_code, 1156)
 
49
 
 
50
  def testEmptyRangeClientHandshake(self):
 
51
    for x in range(0, 1024) + range(1024, 1024*1024, 997):
 
52
      self.verifyServerHandshake()
 
53
      self.s.send(mysql.Packet(size=x, sequence=1).pack())
 
54
 
 
55
      if x > 0:
 
56
        self.s.send('\x00' * x)
 
57
 
 
58
      (packet, data) = self.verifyPacket(2)
 
59
 
 
60
      if x < 6:
 
61
        result = mysql.Result(data, version_40=True)
 
62
        # Bad handshake
 
63
        self.assertEqual(result.field_count, 255)
 
64
        self.assertEqual(result.error_code, 1043)
 
65
      else:
 
66
        result = mysql.Result(data, version_40=True)
 
67
        if result.field_count == 0:
 
68
          self.assertEqual(result.affected_rows, 0)
 
69
          self.assertEqual(result.insert_id, 0)
 
70
          self.assertEqual(result.status, 0)
 
71
        else:
 
72
          # Got a packet bigger than 'max_allowed_packet' bytes
 
73
          self.assertEqual(result.field_count, 255)
 
74
          #self.assertEqual(result.error_code, 1153)
 
75
 
 
76
      self.s.close()
 
77
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
78
      self.s.connect((options.host, int(options.port)))
 
79
 
 
80
  def testMaxRangeClientHandshake(self):
 
81
    for x in range(0, 1024) + range(1024, 1024*1024, 997):
 
82
      self.verifyServerHandshake()
 
83
      self.s.send(mysql.Packet(size=x, sequence=1).pack())
 
84
 
 
85
      if x > 0:
 
86
        self.s.send('\xff' * x)
 
87
 
 
88
      (packet, data) = self.verifyPacket(2)
 
89
 
 
90
      result = mysql.Result(data, version_40=True)
 
91
      self.assertEqual(result.field_count, 255)
 
92
      # Bad handshake?
 
93
      if result.error_code != 1043:
 
94
        # Got a packet bigger than 'max_allowed_packet' bytes
 
95
        self.assertEqual(result.error_code, 1153)
 
96
 
 
97
      self.s.close()
 
98
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
99
      self.s.connect((options.host, int(options.port)))
 
100
 
 
101
  def testUserOverrun(self):
 
102
    for x in range(0, 1024) + range(1024, 1024*1024, 997):
 
103
      server_handshake = self.verifyServerHandshake()
 
104
      client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
 
105
 
 
106
      self.s.send(mysql.Packet(size=32+x, sequence=1).pack())
 
107
      self.s.send(client_handshake.pack()[:32])
 
108
      if x > 0:
 
109
        self.s.send('\xff' * x)
 
110
 
 
111
      (packet, data) = self.verifyPacket(2)
 
112
      result = mysql.Result(data, version_40=True)
 
113
      # Bad handshake?
 
114
      if result.error_code != 1043:
 
115
        # Got a packet bigger than 'max_allowed_packet' bytes
 
116
        self.assertEqual(result.error_code, 1153)
 
117
 
 
118
      self.s.close()
 
119
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
120
      self.s.connect((options.host, int(options.port)))
 
121
 
 
122
  def testScrambleOverrun(self):
 
123
    for x in range(0, 256):
 
124
      server_handshake = self.verifyServerHandshake()
 
125
      client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
 
126
 
 
127
      self.s.send(mysql.Packet(size=34+x, sequence=1).pack())
 
128
      self.s.send(client_handshake.pack()[:33])
 
129
      self.s.send(chr(x))
 
130
      if x > 0:
 
131
        self.s.send('\xff' * x)
 
132
 
 
133
      (packet, data) = self.verifyPacket(2)
 
134
      result = mysql.Result(data)
 
135
      if result.field_count == 0:
 
136
        self.assertEqual(result.affected_rows, 0)
 
137
        self.assertEqual(result.insert_id, 0)
 
138
        self.assertEqual(result.warning_count, 0)
 
139
      elif result.error_code != 1045:
 
140
        # Not access denied, Bad handshake
 
141
        self.assertEqual(result.field_count, 255)
 
142
        self.assertEqual(result.error_code, 1043)
 
143
 
 
144
      self.s.close()
 
145
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
146
      self.s.connect((options.host, int(options.port)))
 
147
 
 
148
  def testDBOverrun(self):
 
149
    for x in range(0, 256):
 
150
      server_handshake = self.verifyServerHandshake()
 
151
      client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
 
152
 
 
153
      self.s.send(mysql.Packet(size=34+x, sequence=1).pack())
 
154
      self.s.send(client_handshake.pack()[:34])
 
155
      if x > 0:
 
156
        self.s.send('\xff' * x)
 
157
 
 
158
      (packet, data) = self.verifyPacket(2)
 
159
      result = mysql.Result(data)
 
160
      if result.field_count == 0:
 
161
        self.assertEqual(result.affected_rows, 0)
 
162
        self.assertEqual(result.insert_id, 0)
 
163
        self.assertEqual(result.warning_count, 0)
 
164
      elif result.error_code != 1044:
 
165
        # Not access denied, Bad database name
 
166
        self.assertEqual(result.field_count, 255)
 
167
        self.assertEqual(result.error_code, 1102)
 
168
 
 
169
      self.s.close()
 
170
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
171
      self.s.connect((options.host, int(options.port)))
 
172
 
 
173
  def testSimple(self):
 
174
    server_handshake = self.verifyServerHandshake()
 
175
 
 
176
    client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
 
177
    client_handshake.capabilities.compress = False
 
178
    data = client_handshake.pack()
 
179
    self.s.send(mysql.Packet(size=len(data), sequence=1).pack())
 
180
    self.s.send(data)
 
181
 
 
182
    (packet, data) = self.verifyPacket(2)
 
183
    result = mysql.Result(data)
 
184
    self.assertEqual(result.field_count, 0)
 
185
    self.assertEqual(result.affected_rows, 0)
 
186
    self.assertEqual(result.insert_id, 0)
 
187
    self.assertEqual(result.warning_count, 0)
 
188
 
 
189
  def verifyPacket(self, sequence):
 
190
    data = self.s.recv(4)
 
191
    self.assertEqual(len(data), 4)
 
192
 
 
193
    packet = mysql.Packet(data)
 
194
    self.assertTrue(packet.size > 0)
 
195
    self.assertEqual(packet.sequence, sequence)
 
196
 
 
197
    data = self.s.recv(packet.size)
 
198
    self.assertEqual(len(data), packet.size)
 
199
 
 
200
    return (packet, data)
 
201
    
 
202
  def verifyServerHandshake(self):
 
203
    (packet, data) = self.verifyPacket(0)
 
204
 
 
205
    server_handshake = mysql.ServerHandshake(data)
 
206
    self.assertEqual(server_handshake.protocol_version, 10)
 
207
    self.assertEqual(server_handshake.null1, 0)
 
208
    self.assertEqual(server_handshake.status.value(), 2)
 
209
    self.assertEqual(server_handshake.unused, tuple([0] * 13))
 
210
    self.assertEqual(server_handshake.null2, 0)
 
211
    return server_handshake
 
212
 
 
213
class TestCommand(unittest.TestCase):
 
214
  def setUp(self):
 
215
    # Read server handshake
 
216
    self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
217
    self.s.connect((options.host, int(options.port)))
 
218
    data = self.s.recv(4)
 
219
    packet = mysql.Packet(data)
 
220
    data = self.s.recv(packet.size)
 
221
    server_handshake = mysql.ServerHandshake(data)
 
222
 
 
223
    # Send client handshake
 
224
    client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
 
225
    client_handshake.capabilities.compress = False
 
226
    data = client_handshake.pack()
 
227
    self.s.send(mysql.Packet(size=len(data), sequence=1).pack())
 
228
    self.s.send(data)
 
229
 
 
230
    # Read server response
 
231
    data = self.s.recv(4)
 
232
    packet = mysql.Packet(data)
 
233
    data = self.s.recv(packet.size)
 
234
    result = mysql.Result(data)
 
235
    self.assertEqual(result.field_count, 0)
 
236
 
 
237
  def tearDown(self):
 
238
    self.s.close()
 
239
 
 
240
  def testSimple(self):
 
241
    self.s.send(mysql.Packet(size=11, sequence=0).pack())
 
242
    self.s.send('\x03SELECT 1+1')
 
243
    data = self.s.recv(1024)
 
244
    #print list(data)
 
245
 
 
246
    #self.s.send(mysql.Packet().pack())
 
247
    #data = self.s.recv(1024)
 
248
    #print list(data)
 
249
 
 
250
    #self.s.send(mysql.Packet(sequence=1).pack())
 
251
    #data = self.s.recv(1024)
 
252
    #print list(data)
 
253
 
 
254
if __name__ == '__main__':
 
255
  suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
 
256
  unittest.TextTestRunner(verbosity=2).run(suite)