~ubuntu-branches/ubuntu/trusty/drizzle/trusty

« back to all changes in this revision

Viewing changes to plugin/mysql_protocol/prototest/prototest/mysql/result.py

  • Committer: Bazaar Package Importer
  • Author(s): Monty Taylor
  • Date: 2010-10-02 14:17:48 UTC
  • mfrom: (1.1.1 upstream)
  • mto: (2.1.17 sid)
  • mto: This revision was merged to the branch mainline in revision 3.
  • Revision ID: james.westby@ubuntu.com-20101002141748-m6vbfbfjhrw1153e
Tags: 2010.09.1802-1
* New upstream release.
* Removed pid-file argument hack.
* Updated GPL-2 address to be new address.
* Directly copy in drizzledump.1 since debian doesn't have sphinx 1.0 yet.
* Link to jquery from libjs-jquery. Add it as a depend.
* Add drizzled.8 symlink to the install files.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
#
 
3
# Drizzle Client & Protocol Library
 
4
 
5
# Copyright (C) 2008 Eric Day (eday@oddments.org)
 
6
# All rights reserved.
 
7
 
8
# Use and distribution licensed under the BSD license.  See
 
9
# the COPYING.BSD file in the root source directory for full text.
 
10
#
 
11
'''
 
12
MySQL Protocol Result Objects
 
13
'''
 
14
 
 
15
import struct
 
16
import unittest
 
17
 
 
18
class BadFieldCount(Exception):
 
19
  pass
 
20
 
 
21
class OkResult(object):
 
22
  '''This class represents an OK result packet sent from the server.'''
 
23
 
 
24
  def __init__(self, packed=None, affected_rows=0, insert_id=0, status=0,
 
25
               warning_count=0, message='', version_40=False):
 
26
    if packed is None:
 
27
      self.affected_rows = affected_rows
 
28
      self.insert_id = insert_id
 
29
      self.status = status
 
30
      self.message = message
 
31
      self.version_40 = version_40
 
32
      if version_40 is False:
 
33
        self.warning_count = warning_count
 
34
    else:
 
35
      self.version_40 = version_40
 
36
      if ord(packed[0]) != 0:
 
37
        raise BadFieldCount('Expected 0, received ' + str(ord(packed[0])))
 
38
      self.affected_rows = ord(packed[1])
 
39
      self.insert_id = ord(packed[2])
 
40
      if version_40 is True:
 
41
        if len(packed) == 3:
 
42
          self.status = 0
 
43
          self.message = ''
 
44
        else:
 
45
          data = struct.unpack('<H', packed[3:5])
 
46
          self.status = data[0]
 
47
          self.message = packed[5:]
 
48
      else:
 
49
        data = struct.unpack('<HH', packed[3:7])
 
50
        self.status = data[0]
 
51
        self.warning_count = data[1]
 
52
        self.message = packed[7:]
 
53
 
 
54
  def __str__(self):
 
55
    if self.version_40 is True:
 
56
      return '''OkResult
 
57
  affected_rows = %s
 
58
  insert_id = %s
 
59
  status = %s
 
60
  message = %s
 
61
  version_40 = %s
 
62
''' % (self.affected_rows, self.insert_id, self.status, self.message,
 
63
       self.version_40)
 
64
    else:
 
65
      return '''OkResult
 
66
  affected_rows = %s
 
67
  insert_id = %s
 
68
  status = %s
 
69
  warning_count = %s
 
70
  message = %s
 
71
  version_40 = %s
 
72
''' % (self.affected_rows, self.insert_id, self.status, self.warning_count,
 
73
       self.message, self.version_40)
 
74
 
 
75
class TestOkResult(unittest.TestCase):
 
76
 
 
77
  def testDefaultInit(self):
 
78
    result = OkResult()
 
79
    self.assertEqual(result.affected_rows, 0)
 
80
    self.assertEqual(result.insert_id, 0)
 
81
    self.assertEqual(result.status, 0)
 
82
    self.assertEqual(result.warning_count, 0)
 
83
    self.assertEqual(result.message, '')
 
84
    self.assertEqual(result.version_40, False)
 
85
    result.__str__()
 
86
 
 
87
  def testDefaultInit40(self):
 
88
    result = OkResult(version_40=True)
 
89
    self.assertEqual(result.affected_rows, 0)
 
90
    self.assertEqual(result.insert_id, 0)
 
91
    self.assertEqual(result.status, 0)
 
92
    self.assertEqual(result.message, '')
 
93
    self.assertEqual(result.version_40, True)
 
94
    result.__str__()
 
95
 
 
96
  def testKeywordInit(self):
 
97
    result = OkResult(affected_rows=3, insert_id=5, status=2,
 
98
                      warning_count=7, message='test', version_40=False)
 
99
    self.assertEqual(result.affected_rows, 3)
 
100
    self.assertEqual(result.insert_id, 5)
 
101
    self.assertEqual(result.status, 2)
 
102
    self.assertEqual(result.warning_count, 7)
 
103
    self.assertEqual(result.message, 'test')
 
104
    self.assertEqual(result.version_40, False)
 
105
 
 
106
  def testUnpackInit(self):
 
107
    data = struct.pack('BBB', 0, 3, 5)
 
108
    data += struct.pack('<HH', 2, 7)
 
109
    data += 'test'
 
110
 
 
111
    result = OkResult(data)
 
112
    self.assertEqual(result.affected_rows, 3)
 
113
    self.assertEqual(result.insert_id, 5)
 
114
    self.assertEqual(result.status, 2)
 
115
    self.assertEqual(result.warning_count, 7)
 
116
    self.assertEqual(result.message, 'test')
 
117
    self.assertEqual(result.version_40, False)
 
118
    result.__str__()
 
119
 
 
120
  def testUnpackInit40(self):
 
121
    data = struct.pack('BBB', 0, 3, 5)
 
122
    data += struct.pack('<H', 2)
 
123
    data += 'test'
 
124
 
 
125
    result = OkResult(data, version_40=True)
 
126
    self.assertEqual(result.affected_rows, 3)
 
127
    self.assertEqual(result.insert_id, 5)
 
128
    self.assertEqual(result.status, 2)
 
129
    self.assertEqual(result.message, 'test')
 
130
    self.assertEqual(result.version_40, True)
 
131
    result.__str__()
 
132
 
 
133
class ErrorResult(object):
 
134
  '''This class represents an error result packet sent from the server.'''
 
135
 
 
136
  def __init__(self, packed=None, error_code=0, sqlstate_marker='#',
 
137
               sqlstate='XXXXX', message='', version_40=False):
 
138
    if packed is None:
 
139
      self.error_code = error_code
 
140
      self.message = message
 
141
      self.version_40 = version_40
 
142
      if version_40 is False:
 
143
        self.sqlstate_marker = sqlstate_marker
 
144
        self.sqlstate = sqlstate
 
145
    else:
 
146
      self.version_40 = version_40
 
147
      if ord(packed[0]) != 255:
 
148
        raise BadFieldCount('Expected 255, received ' + str(ord(packed[0])))
 
149
      data = struct.unpack('<H', packed[1:3])
 
150
      self.error_code = data[0]
 
151
      if version_40 is True:
 
152
        self.message = packed[3:]
 
153
      else:
 
154
        self.sqlstate_marker = packed[3]
 
155
        self.sqlstate = packed[4:9]
 
156
        self.message = packed[9:]
 
157
 
 
158
  def __str__(self):
 
159
    if self.version_40 is True:
 
160
      return '''ErrorResult
 
161
  error_code = %s
 
162
  message = %s
 
163
  version_40 = %s
 
164
''' % (self.error_code, self.message, self.version_40)
 
165
    else:
 
166
      return '''ErrorResult
 
167
  error_code = %s
 
168
  sqlstate_marker = %s
 
169
  sqlstate = %s
 
170
  message = %s
 
171
  version_40 = %s
 
172
''' % (self.error_code, self.sqlstate_marker, self.sqlstate, self.message,
 
173
       self.version_40)
 
174
 
 
175
class TestErrorResult(unittest.TestCase):
 
176
 
 
177
  def testDefaultInit(self):
 
178
    result = ErrorResult()
 
179
    self.assertEqual(result.error_code, 0)
 
180
    self.assertEqual(result.sqlstate_marker, '#')
 
181
    self.assertEqual(result.sqlstate, 'XXXXX')
 
182
    self.assertEqual(result.message, '')
 
183
    self.assertEqual(result.version_40, False)
 
184
    result.__str__()
 
185
 
 
186
  def testDefaultInit40(self):
 
187
    result = ErrorResult(version_40=True)
 
188
    self.assertEqual(result.error_code, 0)
 
189
    self.assertEqual(result.message, '')
 
190
    self.assertEqual(result.version_40, True)
 
191
    result.__str__()
 
192
 
 
193
  def testKeywordInit(self):
 
194
    result = ErrorResult(error_code=3, sqlstate_marker='@', sqlstate='ABCDE',
 
195
                         message='test', version_40=False)
 
196
    self.assertEqual(result.error_code, 3)
 
197
    self.assertEqual(result.sqlstate_marker, '@')
 
198
    self.assertEqual(result.sqlstate, 'ABCDE')
 
199
    self.assertEqual(result.message, 'test')
 
200
    self.assertEqual(result.version_40, False)
 
201
    result.__str__()
 
202
 
 
203
  def testUnpackInit(self):
 
204
    data = chr(255)
 
205
    data += struct.pack('<H', 1234)
 
206
    data += '#ABCDE'
 
207
    data += 'test'
 
208
 
 
209
    result = ErrorResult(data)
 
210
    self.assertEqual(result.error_code, 1234)
 
211
    self.assertEqual(result.sqlstate_marker, '#')
 
212
    self.assertEqual(result.sqlstate, 'ABCDE')
 
213
    self.assertEqual(result.message, 'test')
 
214
    self.assertEqual(result.version_40, False)
 
215
    result.__str__()
 
216
 
 
217
  def testUnpackInit40(self):
 
218
    data = chr(255)
 
219
    data += struct.pack('<H', 1234)
 
220
    data += 'test'
 
221
 
 
222
    result = ErrorResult(data, version_40=True)
 
223
    self.assertEqual(result.error_code, 1234)
 
224
    self.assertEqual(result.message, 'test')
 
225
    self.assertEqual(result.version_40, True)
 
226
    result.__str__()
 
227
 
 
228
class EofResult(object):
 
229
  '''This class represents an EOF result packet sent from the server.'''
 
230
 
 
231
  def __init__(self, packed=None, warning_count=0, status=0, version_40=False):
 
232
    if packed is None:
 
233
      self.version_40 = version_40
 
234
      if self.version_40 is False:
 
235
        self.warning_count = warning_count
 
236
        self.status = status
 
237
    else:
 
238
      self.version_40 = version_40
 
239
      if ord(packed[0]) != 254:
 
240
        raise BadFieldCount('Expected 254, received ' + str(ord(packed[0])))
 
241
      if version_40 is False:
 
242
        data = struct.unpack('<HH', packed[1:])
 
243
        self.warning_count = data[0]
 
244
        self.status = data[1]
 
245
 
 
246
  def __str__(self):
 
247
    if self.version_40 is True:
 
248
      return '''EofResult
 
249
  version_40 = %s
 
250
''' % self.version_40
 
251
    else:
 
252
      return '''EofResult
 
253
  warning_count = %s
 
254
  status = %s
 
255
  version_40 = %s
 
256
''' % (self.warning_count, self.status, self.version_40)
 
257
 
 
258
class TestEofResult(unittest.TestCase):
 
259
 
 
260
  def testDefaultInit(self):
 
261
    result = EofResult()
 
262
    self.assertEqual(result.warning_count, 0)
 
263
    self.assertEqual(result.status, 0)
 
264
    self.assertEqual(result.version_40, False)
 
265
    result.__str__()
 
266
 
 
267
  def testDefaultInit40(self):
 
268
    result = EofResult(version_40=True)
 
269
    self.assertEqual(result.version_40, True)
 
270
    result.__str__()
 
271
 
 
272
  def testKeywordInit(self):
 
273
    result = EofResult(warning_count=3, status=5, version_40=False)
 
274
    self.assertEqual(result.warning_count, 3)
 
275
    self.assertEqual(result.status, 5)
 
276
    self.assertEqual(result.version_40, False)
 
277
    result.__str__()
 
278
 
 
279
  def testUnpackInit(self):
 
280
    data = chr(254)
 
281
    data += struct.pack('<HH', 3, 5)
 
282
 
 
283
    result = EofResult(data)
 
284
    self.assertEqual(result.warning_count, 3)
 
285
    self.assertEqual(result.status, 5)
 
286
    self.assertEqual(result.version_40, False)
 
287
    result.__str__()
 
288
 
 
289
  def testUnpackInit40(self):
 
290
    result = EofResult(chr(254), version_40=True)
 
291
    self.assertEqual(result.version_40, True)
 
292
    result.__str__()
 
293
 
 
294
class CountResult(object):
 
295
  '''This class represents an count result packet sent from the server.'''
 
296
 
 
297
  def __init__(self, packed=None, count=0):
 
298
    if packed is None:
 
299
      self.count = count
 
300
    else:
 
301
      self.count = ord(packed[0])
 
302
      if self.count == 0 or self.count > 253:
 
303
        raise BadFieldCount('Expected 1-253, received ' + str(ord(packed[0])))
 
304
 
 
305
  def __str__(self):
 
306
    return '''CountResult
 
307
  count = %s
 
308
''' % self.count
 
309
 
 
310
class TestCountResult(unittest.TestCase):
 
311
 
 
312
  def testDefaultInit(self):
 
313
    result = CountResult()
 
314
    self.assertEqual(result.count, 0)
 
315
    result.__str__()
 
316
 
 
317
  def testKeywordInit(self):
 
318
    result = CountResult(count=3)
 
319
    self.assertEqual(result.count, 3)
 
320
    result.__str__()
 
321
 
 
322
  def testUnpackInit(self):
 
323
    result = CountResult("\x03")
 
324
    self.assertEqual(result.count, 3)
 
325
    result.__str__()
 
326
 
 
327
def create_result(packed, version_40=False):
 
328
  '''This function creates the appropriate result object instance depending on
 
329
     first byte.'''
 
330
  count = ord(packed[0])
 
331
  if count == 0:
 
332
    return OkResult(packed, version_40=version_40)
 
333
  if count == 254:
 
334
    return EofResult(packed, version_40=version_40)
 
335
  if count == 255:
 
336
    return ErrorResult(packed, version_40=version_40)
 
337
  return CountResult(packed)
 
338
 
 
339
if __name__ == '__main__':
 
340
  unittest.main()