3
# Drizzle Client & Protocol Library
5
# Copyright (C) 2008 Eric Day (eday@oddments.org)
8
# Use and distribution licensed under the BSD license. See
9
# the COPYING file in this directory for full text.
12
Drizzle and MySQL Protocol Test Suite
14
This tool requires anonymous authentication to be enabled on the
21
from prototest import mysql
23
parser = optparse.OptionParser(add_help_option=False)

# "-h" is taken by --host below, so optparse's built-in help option is
# disabled above and re-added here under the long name only.
parser.add_option("--help", action="help", help="Print out help details")
parser.add_option("-h", "--host", dest="host", default="localhost",
                  help="Host or IP to test", metavar="HOST")
# type="int" makes optparse reject a non-numeric port at parse time and keeps
# options.port an int whether or not the flag is given (previously it was a
# str when passed on the command line and an int when defaulted).
parser.add_option("-p", "--port", dest="port", default=3306, type="int",
                  help="Port to test", metavar="PORT")

(options, args) = parser.parse_args()
33
# TestHandshake: checks how the server answers malformed, truncated or
# oversized client handshake packets. Each test first reads the server
# greeting with verifyServerHandshake(), then writes a hand-built client
# packet with sequence 1 and inspects the reply read by verifyPacket().
# NOTE(review): this copy has lost its indentation and several lines (the
# bare numbers are original line numbers, and they skip values), so it cannot
# run as-is. The notes below flag the gaps that change meaning.
class TestHandshake(unittest.TestCase):
35
# NOTE(review): the two statements below read like the body of setUp(), but
# its `def setUp(self):` header appears to be missing from this copy. They
# open a fresh TCP connection to --host/--port for every test.
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
36
self.s.connect((options.host, int(options.port)))
41
# Send a default-constructed mysql.Packet header straight after the server
# greeting and expect an error reply (field_count 255) with sequence 1.
def testNullClientHandshake(self):
42
self.verifyServerHandshake()
43
self.s.send(mysql.Packet().pack())
44
(packet, data) = self.verifyPacket(1)
45
result = mysql.Result(data, version_40=True)
46
# Got packets out of order
47
self.assertEqual(result.field_count, 255)
48
#self.assertEqual(result.error_code, 1156)
50
# Sweep payload sizes 0..1023, then 1024 up to just under 1 MiB in steps of
# 997. For each size x, send a header (sequence 1) claiming x bytes, followed
# by x zero bytes, and expect a reply with sequence 2.
# NOTE(review): `range(...) + range(...)` concatenates lists, which works only
# on Python 2. On Python 3 it raises TypeError.
def testEmptyRangeClientHandshake(self):
51
for x in range(0, 1024) + range(1024, 1024*1024, 997):
52
self.verifyServerHandshake()
53
self.s.send(mysql.Packet(size=x, sequence=1).pack())
56
self.s.send('\x00' * x)
58
(packet, data) = self.verifyPacket(2)
61
# NOTE(review): as flattened here, the checks below cannot all pass for one
# reply. They require an error 1043, then an OK-branch (field_count == 0),
# then field_count 255 again. The if/else lines that chose between them
# (presumably by size x) are missing from this copy, so reconstruct them from
# the original file before use.
result = mysql.Result(data, version_40=True)
63
self.assertEqual(result.field_count, 255)
64
self.assertEqual(result.error_code, 1043)
66
result = mysql.Result(data, version_40=True)
67
if result.field_count == 0:
68
self.assertEqual(result.affected_rows, 0)
69
self.assertEqual(result.insert_id, 0)
70
self.assertEqual(result.status, 0)
72
# Got a packet bigger than 'max_allowed_packet' bytes
73
self.assertEqual(result.field_count, 255)
74
#self.assertEqual(result.error_code, 1153)
77
# Reconnect for the next size, since the server has already answered this
# connection's handshake.
# NOTE(review): the previous socket is not closed in this copy. Unless a
# close() line was lost, each of the ~2000 iterations leaks a descriptor.
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
78
self.s.connect((options.host, int(options.port)))
80
# Same size sweep as above, but with 0xff filler bytes. Every reply must be
# an error (field_count 255): either 1043 or, failing that, 1153.
# NOTE(review): Python 2-only list concatenation and no visible close() of
# the old socket, as in the previous test.
def testMaxRangeClientHandshake(self):
81
for x in range(0, 1024) + range(1024, 1024*1024, 997):
82
self.verifyServerHandshake()
83
self.s.send(mysql.Packet(size=x, sequence=1).pack())
86
self.s.send('\xff' * x)
88
(packet, data) = self.verifyPacket(2)
90
result = mysql.Result(data, version_40=True)
91
self.assertEqual(result.field_count, 255)
93
if result.error_code != 1043:
94
# Got a packet bigger than 'max_allowed_packet' bytes
95
self.assertEqual(result.error_code, 1153)
98
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
99
self.s.connect((options.host, int(options.port)))
101
# Send the first 32 bytes of a packed client handshake, then x bytes of 0xff
# (the header claims 32+x bytes). The field that follows the 32-byte prefix
# (presumably the user name; confirm against mysql.ClientHandshake.pack) is
# never terminated by a zero byte. Any reply is accepted as long as its
# error_code is 1043 or 1153.
def testUserOverrun(self):
102
for x in range(0, 1024) + range(1024, 1024*1024, 997):
103
server_handshake = self.verifyServerHandshake()
104
client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
106
self.s.send(mysql.Packet(size=32+x, sequence=1).pack())
107
self.s.send(client_handshake.pack()[:32])
109
self.s.send('\xff' * x)
111
(packet, data) = self.verifyPacket(2)
112
result = mysql.Result(data, version_40=True)
114
if result.error_code != 1043:
115
# Got a packet bigger than 'max_allowed_packet' bytes
116
self.assertEqual(result.error_code, 1153)
119
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
120
self.s.connect((options.host, int(options.port)))
122
# For x in 0..255, send a client handshake prefix followed by x bytes of 0xff
# and accept OK, access denied (1045), or otherwise require Bad handshake
# (1043).
# NOTE(review): the header claims 34+x bytes, but the visible sends total only
# 33 + x bytes. A one-byte send between them appears to be lost from this
# copy. As written, the server would wait for that byte and verifyPacket()
# would block.
def testScrambleOverrun(self):
123
for x in range(0, 256):
124
server_handshake = self.verifyServerHandshake()
125
client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
127
self.s.send(mysql.Packet(size=34+x, sequence=1).pack())
128
self.s.send(client_handshake.pack()[:33])
131
self.s.send('\xff' * x)
133
(packet, data) = self.verifyPacket(2)
134
result = mysql.Result(data)
135
if result.field_count == 0:
136
self.assertEqual(result.affected_rows, 0)
137
self.assertEqual(result.insert_id, 0)
138
self.assertEqual(result.warning_count, 0)
139
elif result.error_code != 1045:
140
# Not access denied, Bad handshake
141
self.assertEqual(result.field_count, 255)
142
self.assertEqual(result.error_code, 1043)
145
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
146
self.s.connect((options.host, int(options.port)))
148
# For x in 0..255, send the first 34 bytes of a packed client handshake plus
# x bytes of 0xff as the trailing field (the header claims 34+x bytes). Accept
# OK or error 1044, and otherwise require error 1102 (bad database name, per
# the comment below).
def testDBOverrun(self):
149
for x in range(0, 256):
150
server_handshake = self.verifyServerHandshake()
151
client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
153
self.s.send(mysql.Packet(size=34+x, sequence=1).pack())
154
self.s.send(client_handshake.pack()[:34])
156
self.s.send('\xff' * x)
158
(packet, data) = self.verifyPacket(2)
159
result = mysql.Result(data)
160
if result.field_count == 0:
161
self.assertEqual(result.affected_rows, 0)
162
self.assertEqual(result.insert_id, 0)
163
self.assertEqual(result.warning_count, 0)
164
elif result.error_code != 1044:
165
# Not access denied, Bad database name
166
self.assertEqual(result.field_count, 255)
167
self.assertEqual(result.error_code, 1102)
170
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
171
self.s.connect((options.host, int(options.port)))
173
# Complete a well-formed handshake (server capabilities echoed, compression
# disabled) and expect an OK reply with zero affected rows, insert id and
# warnings.
# NOTE(review): the header announces len(data) bytes, but this copy never
# sends `data` itself, so verifyPacket(2) would block. A `self.s.send(data)`
# line looks missing.
def testSimple(self):
174
server_handshake = self.verifyServerHandshake()
176
client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
177
client_handshake.capabilities.compress = False
178
data = client_handshake.pack()
179
self.s.send(mysql.Packet(size=len(data), sequence=1).pack())
182
(packet, data) = self.verifyPacket(2)
183
result = mysql.Result(data)
184
self.assertEqual(result.field_count, 0)
185
self.assertEqual(result.affected_rows, 0)
186
self.assertEqual(result.insert_id, 0)
187
self.assertEqual(result.warning_count, 0)
189
# Read one packet: a 4-byte header parsed by mysql.Packet, then packet.size
# payload bytes. Asserts a non-empty payload and the expected sequence number,
# and returns (packet, payload).
# NOTE(review): socket.recv(n) may return fewer than n bytes. A payload split
# across TCP segments would fail the length assertion instead of being read
# in a loop.
def verifyPacket(self, sequence):
190
data = self.s.recv(4)
191
self.assertEqual(len(data), 4)
193
packet = mysql.Packet(data)
194
self.assertTrue(packet.size > 0)
195
self.assertEqual(packet.sequence, sequence)
197
data = self.s.recv(packet.size)
198
self.assertEqual(len(data), packet.size)
200
return (packet, data)
202
# Read and sanity-check the server greeting (sequence 0). It must have
# protocol version 10, zero filler fields null1/null2, status flags equal to 2
# and 13 zero "unused" bytes. Returns the parsed mysql.ServerHandshake so
# callers can reuse its capabilities.
def verifyServerHandshake(self):
203
(packet, data) = self.verifyPacket(0)
205
server_handshake = mysql.ServerHandshake(data)
206
self.assertEqual(server_handshake.protocol_version, 10)
207
self.assertEqual(server_handshake.null1, 0)
208
self.assertEqual(server_handshake.status.value(), 2)
209
self.assertEqual(server_handshake.unused, tuple([0] * 13))
210
self.assertEqual(server_handshake.null2, 0)
211
return server_handshake
213
# TestCommand: sends commands over an already-handshaken connection. The
# setup statements below perform a full client handshake with compression
# disabled and require an OK reply (field_count 0) before a test runs.
# NOTE(review): this copy has lost its indentation and several lines (the
# bare numbers are original line numbers, and they skip values), so it cannot
# run as-is.
class TestCommand(unittest.TestCase):
215
# NOTE(review): the statements from here down to the field_count assertion
# read like the body of setUp(), but its `def setUp(self):` header appears to
# be missing from this copy.
# Read server handshake
216
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
217
self.s.connect((options.host, int(options.port)))
218
data = self.s.recv(4)
219
packet = mysql.Packet(data)
220
data = self.s.recv(packet.size)
221
server_handshake = mysql.ServerHandshake(data)
223
# Send client handshake
224
client_handshake = mysql.ClientHandshake(capabilities=server_handshake.capabilities.value())
225
client_handshake.capabilities.compress = False
226
data = client_handshake.pack()
# NOTE(review): only the header announcing len(data) bytes is sent. The
# handshake payload itself is never written in this copy, so the recv() below
# would block. A `self.s.send(data)` line looks missing.
227
self.s.send(mysql.Packet(size=len(data), sequence=1).pack())
230
# Read server response
231
data = self.s.recv(4)
232
packet = mysql.Packet(data)
233
data = self.s.recv(packet.size)
234
result = mysql.Result(data)
235
self.assertEqual(result.field_count, 0)
# NOTE(review): nothing visible here closes self.s after a test. Lines are
# missing at this point, so check whether the original had a tearDown().
240
# Send one command packet (sequence 0, 11 bytes): a 0x03 command byte
# (presumably COM_QUERY; confirm) followed by the 10-character text
# "SELECT 1+1". The reply is read into `data` but never checked, so this test
# only fails on socket errors.
def testSimple(self):
241
self.s.send(mysql.Packet(size=11, sequence=0).pack())
242
self.s.send('\x03SELECT 1+1')
243
data = self.s.recv(1024)
246
#self.s.send(mysql.Packet().pack())
247
#data = self.s.recv(1024)
250
#self.s.send(mysql.Packet(sequence=1).pack())
251
#data = self.s.recv(1024)
254
if __name__ == '__main__':
    import sys

    # Collect every TestCase defined in this script and run them verbosely.
    suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    # Propagate failures to the process exit status so shells and CI can
    # detect them; previously the script always exited 0.
    sys.exit(0 if result.wasSuccessful() else 1)