3
# Drizzle Client & Protocol Library
5
# Copyright (C) 2008 Eric Day (eday@oddments.org)
8
# Use and distribution licensed under the BSD license. See
9
# the COPYING.BSD file in the root source directory for full text.
12
MySQL Protocol Result Objects
18
class BadFieldCount(Exception):
21
class OkResult(object):
22
'''This class represents an OK result packet sent from the server.'''
24
def __init__(self, packed=None, affected_rows=0, insert_id=0, status=0,
25
warning_count=0, message='', version_40=False):
27
self.affected_rows = affected_rows
28
self.insert_id = insert_id
30
self.message = message
31
self.version_40 = version_40
32
if version_40 is False:
33
self.warning_count = warning_count
35
self.version_40 = version_40
36
if ord(packed[0]) != 0:
37
raise BadFieldCount('Expected 0, received ' + str(ord(packed[0])))
38
self.affected_rows = ord(packed[1])
39
self.insert_id = ord(packed[2])
40
if version_40 is True:
45
data = struct.unpack('<H', packed[3:5])
47
self.message = packed[5:]
49
data = struct.unpack('<HH', packed[3:7])
51
self.warning_count = data[1]
52
self.message = packed[7:]
55
if self.version_40 is True:
62
''' % (self.affected_rows, self.insert_id, self.status, self.message,
72
''' % (self.affected_rows, self.insert_id, self.status, self.warning_count,
73
self.message, self.version_40)
75
class TestOkResult(unittest.TestCase):
77
def testDefaultInit(self):
79
self.assertEqual(result.affected_rows, 0)
80
self.assertEqual(result.insert_id, 0)
81
self.assertEqual(result.status, 0)
82
self.assertEqual(result.warning_count, 0)
83
self.assertEqual(result.message, '')
84
self.assertEqual(result.version_40, False)
87
def testDefaultInit40(self):
88
result = OkResult(version_40=True)
89
self.assertEqual(result.affected_rows, 0)
90
self.assertEqual(result.insert_id, 0)
91
self.assertEqual(result.status, 0)
92
self.assertEqual(result.message, '')
93
self.assertEqual(result.version_40, True)
96
def testKeywordInit(self):
97
result = OkResult(affected_rows=3, insert_id=5, status=2,
98
warning_count=7, message='test', version_40=False)
99
self.assertEqual(result.affected_rows, 3)
100
self.assertEqual(result.insert_id, 5)
101
self.assertEqual(result.status, 2)
102
self.assertEqual(result.warning_count, 7)
103
self.assertEqual(result.message, 'test')
104
self.assertEqual(result.version_40, False)
106
def testUnpackInit(self):
107
data = struct.pack('BBB', 0, 3, 5)
108
data += struct.pack('<HH', 2, 7)
111
result = OkResult(data)
112
self.assertEqual(result.affected_rows, 3)
113
self.assertEqual(result.insert_id, 5)
114
self.assertEqual(result.status, 2)
115
self.assertEqual(result.warning_count, 7)
116
self.assertEqual(result.message, 'test')
117
self.assertEqual(result.version_40, False)
120
def testUnpackInit40(self):
121
data = struct.pack('BBB', 0, 3, 5)
122
data += struct.pack('<H', 2)
125
result = OkResult(data, version_40=True)
126
self.assertEqual(result.affected_rows, 3)
127
self.assertEqual(result.insert_id, 5)
128
self.assertEqual(result.status, 2)
129
self.assertEqual(result.message, 'test')
130
self.assertEqual(result.version_40, True)
133
class ErrorResult(object):
134
'''This class represents an error result packet sent from the server.'''
136
def __init__(self, packed=None, error_code=0, sqlstate_marker='#',
137
sqlstate='XXXXX', message='', version_40=False):
139
self.error_code = error_code
140
self.message = message
141
self.version_40 = version_40
142
if version_40 is False:
143
self.sqlstate_marker = sqlstate_marker
144
self.sqlstate = sqlstate
146
self.version_40 = version_40
147
if ord(packed[0]) != 255:
148
raise BadFieldCount('Expected 255, received ' + str(ord(packed[0])))
149
data = struct.unpack('<H', packed[1:3])
150
self.error_code = data[0]
151
if version_40 is True:
152
self.message = packed[3:]
154
self.sqlstate_marker = packed[3]
155
self.sqlstate = packed[4:9]
156
self.message = packed[9:]
159
if self.version_40 is True:
160
return '''ErrorResult
164
''' % (self.error_code, self.message, self.version_40)
166
return '''ErrorResult
172
''' % (self.error_code, self.sqlstate_marker, self.sqlstate, self.message,
175
class TestErrorResult(unittest.TestCase):
177
def testDefaultInit(self):
178
result = ErrorResult()
179
self.assertEqual(result.error_code, 0)
180
self.assertEqual(result.sqlstate_marker, '#')
181
self.assertEqual(result.sqlstate, 'XXXXX')
182
self.assertEqual(result.message, '')
183
self.assertEqual(result.version_40, False)
186
def testDefaultInit40(self):
187
result = ErrorResult(version_40=True)
188
self.assertEqual(result.error_code, 0)
189
self.assertEqual(result.message, '')
190
self.assertEqual(result.version_40, True)
193
def testKeywordInit(self):
194
result = ErrorResult(error_code=3, sqlstate_marker='@', sqlstate='ABCDE',
195
message='test', version_40=False)
196
self.assertEqual(result.error_code, 3)
197
self.assertEqual(result.sqlstate_marker, '@')
198
self.assertEqual(result.sqlstate, 'ABCDE')
199
self.assertEqual(result.message, 'test')
200
self.assertEqual(result.version_40, False)
203
def testUnpackInit(self):
205
data += struct.pack('<H', 1234)
209
result = ErrorResult(data)
210
self.assertEqual(result.error_code, 1234)
211
self.assertEqual(result.sqlstate_marker, '#')
212
self.assertEqual(result.sqlstate, 'ABCDE')
213
self.assertEqual(result.message, 'test')
214
self.assertEqual(result.version_40, False)
217
def testUnpackInit40(self):
219
data += struct.pack('<H', 1234)
222
result = ErrorResult(data, version_40=True)
223
self.assertEqual(result.error_code, 1234)
224
self.assertEqual(result.message, 'test')
225
self.assertEqual(result.version_40, True)
228
class EofResult(object):
229
'''This class represents an EOF result packet sent from the server.'''
231
def __init__(self, packed=None, warning_count=0, status=0, version_40=False):
233
self.version_40 = version_40
234
if self.version_40 is False:
235
self.warning_count = warning_count
238
self.version_40 = version_40
239
if ord(packed[0]) != 254:
240
raise BadFieldCount('Expected 254, received ' + str(ord(packed[0])))
241
if version_40 is False:
242
data = struct.unpack('<HH', packed[1:])
243
self.warning_count = data[0]
244
self.status = data[1]
247
if self.version_40 is True:
250
''' % self.version_40
256
''' % (self.warning_count, self.status, self.version_40)
258
class TestEofResult(unittest.TestCase):
260
def testDefaultInit(self):
262
self.assertEqual(result.warning_count, 0)
263
self.assertEqual(result.status, 0)
264
self.assertEqual(result.version_40, False)
267
def testDefaultInit40(self):
268
result = EofResult(version_40=True)
269
self.assertEqual(result.version_40, True)
272
def testKeywordInit(self):
273
result = EofResult(warning_count=3, status=5, version_40=False)
274
self.assertEqual(result.warning_count, 3)
275
self.assertEqual(result.status, 5)
276
self.assertEqual(result.version_40, False)
279
def testUnpackInit(self):
281
data += struct.pack('<HH', 3, 5)
283
result = EofResult(data)
284
self.assertEqual(result.warning_count, 3)
285
self.assertEqual(result.status, 5)
286
self.assertEqual(result.version_40, False)
289
def testUnpackInit40(self):
290
result = EofResult(chr(254), version_40=True)
291
self.assertEqual(result.version_40, True)
294
class CountResult(object):
295
'''This class represents an count result packet sent from the server.'''
297
def __init__(self, packed=None, count=0):
301
self.count = ord(packed[0])
302
if self.count == 0 or self.count > 253:
303
raise BadFieldCount('Expected 1-253, received ' + str(ord(packed[0])))
306
return '''CountResult
310
class TestCountResult(unittest.TestCase):
312
def testDefaultInit(self):
313
result = CountResult()
314
self.assertEqual(result.count, 0)
317
def testKeywordInit(self):
318
result = CountResult(count=3)
319
self.assertEqual(result.count, 3)
322
def testUnpackInit(self):
323
result = CountResult("\x03")
324
self.assertEqual(result.count, 3)
327
def create_result(packed, version_40=False):
328
'''This function creates the appropriate result object instance depending on
330
count = ord(packed[0])
332
return OkResult(packed, version_40=version_40)
334
return EofResult(packed, version_40=version_40)
336
return ErrorResult(packed, version_40=version_40)
337
return CountResult(packed)
339
if __name__ == '__main__':