~ubuntu-branches/debian/sid/pytds/sid

« back to all changes in this revision

Viewing changes to pytds/tds.py

  • Committer: Package Import Robot
  • Author(s): Christopher Hoskin
  • Date: 2017-03-11 20:12:33 UTC
  • Revision ID: package-import@ubuntu.com-20170311201233-voewiramv2n5i4uj
Tags: upstream-1.8.2
ImportĀ upstreamĀ versionĀ 1.8.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
import struct
 
2
import codecs
 
3
from contextlib import contextmanager
 
4
import logging
 
5
import socket
 
6
import sys
 
7
from datetime import datetime, date, time, timedelta
 
8
from decimal import Decimal, localcontext
 
9
from . import tz
 
10
import re
 
11
import uuid
 
12
import six
 
13
import types
 
14
from six.moves import reduce
 
15
from six.moves import xrange
 
16
try:
 
17
    import ssl
 
18
except:
 
19
    encryption_supported = False
 
20
else:
 
21
    encryption_supported = True
 
22
from .collate import ucs2_codec, Collation, lcid2charset, raw_collation
 
23
 
 
24
logger = logging.getLogger()
 
25
 
 
26
ENCRYPTION_ENABLED = False
 
27
 
 
28
 
 
29
# tds protocol versions
 
30
TDS70 = 0x70000000
 
31
TDS71 = 0x71000000
 
32
TDS71rev1 = 0x71000001
 
33
TDS72 = 0x72090002
 
34
TDS73A = 0x730A0003
 
35
TDS73 = TDS73A
 
36
TDS73B = 0x730B0003
 
37
TDS74 = 0x74000004
 
38
 
 
39
IS_TDS7_PLUS = lambda x: x.tds_version >= TDS70
 
40
IS_TDS71_PLUS = lambda x: x.tds_version >= TDS71
 
41
IS_TDS72_PLUS = lambda x: x.tds_version >= TDS72
 
42
IS_TDS73_PLUS = lambda x: x.tds_version >= TDS73A
 
43
 
 
44
# packet types
 
45
TDS_QUERY = 1
 
46
TDS_LOGIN = 2
 
47
TDS_RPC = 3
 
48
TDS_REPLY = 4
 
49
TDS_CANCEL = 6
 
50
TDS_BULK = 7
 
51
TDS7_TRANS = 14  # transaction management
 
52
TDS_NORMAL = 15
 
53
TDS7_LOGIN = 16
 
54
TDS7_AUTH = 17
 
55
TDS71_PRELOGIN = 18
 
56
 
 
57
# mssql login options flags
 
58
# option_flag1_values
 
59
TDS_BYTE_ORDER_X86 = 0
 
60
TDS_CHARSET_ASCII = 0
 
61
TDS_DUMPLOAD_ON = 0
 
62
TDS_FLOAT_IEEE_754 = 0
 
63
TDS_INIT_DB_WARN = 0
 
64
TDS_SET_LANG_OFF = 0
 
65
TDS_USE_DB_SILENT = 0
 
66
TDS_BYTE_ORDER_68000 = 0x01
 
67
TDS_CHARSET_EBDDIC = 0x02
 
68
TDS_FLOAT_VAX = 0x04
 
69
TDS_FLOAT_ND5000 = 0x08
 
70
TDS_DUMPLOAD_OFF = 0x10  # prevent BCP
 
71
TDS_USE_DB_NOTIFY = 0x20
 
72
TDS_INIT_DB_FATAL = 0x40
 
73
TDS_SET_LANG_ON = 0x80
 
74
 
 
75
#enum option_flag2_values {
 
76
TDS_INIT_LANG_WARN = 0
 
77
TDS_INTEGRATED_SECURTY_OFF = 0
 
78
TDS_ODBC_OFF = 0
 
79
TDS_USER_NORMAL = 0  # SQL Server login
 
80
TDS_INIT_LANG_REQUIRED = 0x01
 
81
TDS_ODBC_ON = 0x02
 
82
TDS_TRANSACTION_BOUNDARY71 = 0x04  # removed in TDS 7.2
 
83
TDS_CACHE_CONNECT71 = 0x08  # removed in TDS 7.2
 
84
TDS_USER_SERVER = 0x10  # reserved
 
85
TDS_USER_REMUSER = 0x20  # DQ login
 
86
TDS_USER_SQLREPL = 0x40  # replication login
 
87
TDS_INTEGRATED_SECURITY_ON = 0x80
 
88
 
 
89
#enum option_flag3_values TDS 7.3+
 
90
TDS_RESTRICTED_COLLATION = 0
 
91
TDS_CHANGE_PASSWORD = 0x01
 
92
TDS_SEND_YUKON_BINARY_XML = 0x02
 
93
TDS_REQUEST_USER_INSTANCE = 0x04
 
94
TDS_UNKNOWN_COLLATION_HANDLING = 0x08
 
95
TDS_ANY_COLLATION = 0x10
 
96
 
 
97
TDS5_PARAMFMT2_TOKEN = 32  # 0x20
 
98
TDS_LANGUAGE_TOKEN = 33  # 0x20    TDS 5.0 only
 
99
TDS_ORDERBY2_TOKEN = 34  # 0x22
 
100
TDS_ROWFMT2_TOKEN = 97  # 0x61    TDS 5.0 only
 
101
TDS_LOGOUT_TOKEN = 113  # 0x71    TDS 5.0 only?
 
102
TDS_RETURNSTATUS_TOKEN = 121  # 0x79
 
103
TDS_PROCID_TOKEN = 124  # 0x7C    TDS 4.2 only
 
104
TDS7_RESULT_TOKEN = 129  # 0x81    TDS 7.0 only
 
105
TDS7_COMPUTE_RESULT_TOKEN = 136  # 0x88    TDS 7.0 only
 
106
TDS_COLNAME_TOKEN = 160  # 0xA0    TDS 4.2 only
 
107
TDS_COLFMT_TOKEN = 161  # 0xA1    TDS 4.2 only
 
108
TDS_DYNAMIC2_TOKEN = 163  # 0xA3
 
109
TDS_TABNAME_TOKEN = 164  # 0xA4
 
110
TDS_COLINFO_TOKEN = 165  # 0xA5
 
111
TDS_OPTIONCMD_TOKEN = 166  # 0xA6
 
112
TDS_COMPUTE_NAMES_TOKEN = 167  # 0xA7
 
113
TDS_COMPUTE_RESULT_TOKEN = 168  # 0xA8
 
114
TDS_ORDERBY_TOKEN = 169  # 0xA9
 
115
TDS_ERROR_TOKEN = 170  # 0xAA
 
116
TDS_INFO_TOKEN = 171  # 0xAB
 
117
TDS_PARAM_TOKEN = 172  # 0xAC
 
118
TDS_LOGINACK_TOKEN = 173  # 0xAD
 
119
TDS_CONTROL_TOKEN = 174  # 0xAE
 
120
TDS_ROW_TOKEN = 209  # 0xD1
 
121
TDS_NBC_ROW_TOKEN = 210  # 0xD2    as of TDS 7.3.B
 
122
TDS_CMP_ROW_TOKEN = 211  # 0xD3
 
123
TDS5_PARAMS_TOKEN = 215  # 0xD7    TDS 5.0 only
 
124
TDS_CAPABILITY_TOKEN = 226  # 0xE2
 
125
TDS_ENVCHANGE_TOKEN = 227  # 0xE3
 
126
TDS_EED_TOKEN = 229  # 0xE5
 
127
TDS_DBRPC_TOKEN = 230  # 0xE6
 
128
TDS5_DYNAMIC_TOKEN = 231  # 0xE7    TDS 5.0 only
 
129
TDS5_PARAMFMT_TOKEN = 236  # 0xEC    TDS 5.0 only
 
130
TDS_AUTH_TOKEN = 237  # 0xED    TDS 7.0 only
 
131
TDS_RESULT_TOKEN = 238  # 0xEE
 
132
TDS_DONE_TOKEN = 253  # 0xFD
 
133
TDS_DONEPROC_TOKEN = 254  # 0xFE
 
134
TDS_DONEINPROC_TOKEN = 255  # 0xFF
 
135
 
 
136
# CURSOR support: TDS 5.0 only
 
137
TDS_CURCLOSE_TOKEN = 128  # 0x80    TDS 5.0 only
 
138
TDS_CURDELETE_TOKEN = 129  # 0x81    TDS 5.0 only
 
139
TDS_CURFETCH_TOKEN = 130  # 0x82    TDS 5.0 only
 
140
TDS_CURINFO_TOKEN = 131  # 0x83    TDS 5.0 only
 
141
TDS_CUROPEN_TOKEN = 132  # 0x84    TDS 5.0 only
 
142
TDS_CURDECLARE_TOKEN = 134  # 0x86    TDS 5.0 only
 
143
 
 
144
# environment type field
 
145
TDS_ENV_DATABASE = 1
 
146
TDS_ENV_LANG = 2
 
147
TDS_ENV_CHARSET = 3
 
148
TDS_ENV_PACKSIZE = 4
 
149
TDS_ENV_LCID = 5
 
150
TDS_ENV_SQLCOLLATION = 7
 
151
TDS_ENV_BEGINTRANS = 8
 
152
TDS_ENV_COMMITTRANS = 9
 
153
TDS_ENV_ROLLBACKTRANS = 10
 
154
TDS_ENV_ENLIST_DTC_TRANS = 11
 
155
TDS_ENV_DEFECT_TRANS = 12
 
156
TDS_ENV_DB_MIRRORING_PARTNER = 13
 
157
TDS_ENV_PROMOTE_TRANS = 15
 
158
TDS_ENV_TRANS_MANAGER_ADDR = 16
 
159
TDS_ENV_TRANS_ENDED = 17
 
160
TDS_ENV_RESET_COMPLETION_ACK = 18
 
161
TDS_ENV_INSTANCE_INFO = 19
 
162
TDS_ENV_ROUTING = 20
 
163
 
 
164
# Microsoft internal stored procedure id's
 
165
TDS_SP_CURSOR = 1
 
166
TDS_SP_CURSOROPEN = 2
 
167
TDS_SP_CURSORPREPARE = 3
 
168
TDS_SP_CURSOREXECUTE = 4
 
169
TDS_SP_CURSORPREPEXEC = 5
 
170
TDS_SP_CURSORUNPREPARE = 6
 
171
TDS_SP_CURSORFETCH = 7
 
172
TDS_SP_CURSOROPTION = 8
 
173
TDS_SP_CURSORCLOSE = 9
 
174
TDS_SP_EXECUTESQL = 10
 
175
TDS_SP_PREPARE = 11
 
176
TDS_SP_EXECUTE = 12
 
177
TDS_SP_PREPEXEC = 13
 
178
TDS_SP_PREPEXECRPC = 14
 
179
TDS_SP_UNPREPARE = 15
 
180
 
 
181
# Flags returned in TDS_DONE token
 
182
TDS_DONE_FINAL = 0
 
183
TDS_DONE_MORE_RESULTS = 0x01  # more results follow
 
184
TDS_DONE_ERROR = 0x02  # error occurred
 
185
TDS_DONE_INXACT = 0x04  # transaction in progress
 
186
TDS_DONE_PROC = 0x08  # results are from a stored procedure
 
187
TDS_DONE_COUNT = 0x10  # count field in packet is valid
 
188
TDS_DONE_CANCELLED = 0x20  # acknowledging an attention command (usually a cancel)
 
189
TDS_DONE_EVENT = 0x40  # part of an event notification.
 
190
TDS_DONE_SRVERROR = 0x100  # SQL server server error
 
191
 
 
192
 
 
193
SYBVOID = 31  # 0x1F
 
194
IMAGETYPE = SYBIMAGE = 34  # 0x22
 
195
TEXTTYPE = SYBTEXT = 35  # 0x23
 
196
SYBVARBINARY = 37  # 0x25
 
197
INTNTYPE = SYBINTN = 38  # 0x26
 
198
SYBVARCHAR = 39         # 0x27
 
199
BINARYTYPE = SYBBINARY = 45  # 0x2D
 
200
SYBCHAR = 47  # 0x2F
 
201
INT1TYPE = SYBINT1 = 48  # 0x30
 
202
BITTYPE = SYBBIT = 50  # 0x32
 
203
INT2TYPE = SYBINT2 = 52  # 0x34
 
204
INT4TYPE = SYBINT4 = 56  # 0x38
 
205
DATETIM4TYPE = SYBDATETIME4 = 58  # 0x3A
 
206
FLT4TYPE = SYBREAL = 59  # 0x3B
 
207
MONEYTYPE = SYBMONEY = 60  # 0x3C
 
208
DATETIMETYPE = SYBDATETIME = 61  # 0x3D
 
209
FLT8TYPE = SYBFLT8 = 62  # 0x3E
 
210
NTEXTTYPE = SYBNTEXT = 99  # 0x63
 
211
SYBNVARCHAR = 103  # 0x67
 
212
BITNTYPE = SYBBITN = 104  # 0x68
 
213
NUMERICNTYPE = SYBNUMERIC = 108  # 0x6C
 
214
DECIMALNTYPE = SYBDECIMAL = 106  # 0x6A
 
215
FLTNTYPE = SYBFLTN = 109  # 0x6D
 
216
MONEYNTYPE = SYBMONEYN = 110  # 0x6E
 
217
DATETIMNTYPE = SYBDATETIMN = 111  # 0x6F
 
218
MONEY4TYPE = SYBMONEY4 = 122  # 0x7A
 
219
 
 
220
INT8TYPE = SYBINT8 = 127  # 0x7F
 
221
BIGCHARTYPE = XSYBCHAR = 175  # 0xAF
 
222
BIGVARCHRTYPE = XSYBVARCHAR = 167  # 0xA7
 
223
NVARCHARTYPE = XSYBNVARCHAR = 231  # 0xE7
 
224
NCHARTYPE = XSYBNCHAR = 239  # 0xEF
 
225
BIGVARBINTYPE = XSYBVARBINARY = 165  # 0xA5
 
226
BIGBINARYTYPE = XSYBBINARY = 173  # 0xAD
 
227
GUIDTYPE = SYBUNIQUE = 36  # 0x24
 
228
SSVARIANTTYPE = SYBVARIANT = 98  # 0x62
 
229
UDTTYPE = SYBMSUDT = 240  # 0xF0
 
230
XMLTYPE = SYBMSXML = 241  # 0xF1
 
231
DATENTYPE = SYBMSDATE = 40  # 0x28
 
232
TIMENTYPE = SYBMSTIME = 41  # 0x29
 
233
DATETIME2NTYPE = SYBMSDATETIME2 = 42  # 0x2a
 
234
DATETIMEOFFSETNTYPE = SYBMSDATETIMEOFFSET = 43  # 0x2b
 
235
 
 
236
#
 
237
# Sybase only types
 
238
#
 
239
SYBLONGBINARY = 225  # 0xE1
 
240
SYBUINT1 = 64  # 0x40
 
241
SYBUINT2 = 65  # 0x41
 
242
SYBUINT4 = 66  # 0x42
 
243
SYBUINT8 = 67  # 0x43
 
244
SYBBLOB = 36  # 0x24
 
245
SYBBOUNDARY = 104  # 0x68
 
246
SYBDATE = 49  # 0x31
 
247
SYBDATEN = 123  # 0x7B
 
248
SYB5INT8 = 191  # 0xBF
 
249
SYBINTERVAL = 46  # 0x2E
 
250
SYBLONGCHAR = 175  # 0xAF
 
251
SYBSENSITIVITY = 103  # 0x67
 
252
SYBSINT1 = 176  # 0xB0
 
253
SYBTIME = 51  # 0x33
 
254
SYBTIMEN = 147  # 0x93
 
255
SYBUINTN = 68  # 0x44
 
256
SYBUNITEXT = 174  # 0xAE
 
257
SYBXML = 163  # 0xA3
 
258
 
 
259
TDS_UT_TIMESTAMP = 80
 
260
 
 
261
# compute operator
 
262
SYBAOPCNT = 0x4b
 
263
SYBAOPCNTU = 0x4c
 
264
SYBAOPSUM = 0x4d
 
265
SYBAOPSUMU = 0x4e
 
266
SYBAOPAVG = 0x4f
 
267
SYBAOPAVGU = 0x50
 
268
SYBAOPMIN = 0x51
 
269
SYBAOPMAX = 0x52
 
270
 
 
271
# mssql2k compute operator
 
272
SYBAOPCNT_BIG = 0x09
 
273
SYBAOPSTDEV = 0x30
 
274
SYBAOPSTDEVP = 0x31
 
275
SYBAOPVAR = 0x32
 
276
SYBAOPVARP = 0x33
 
277
SYBAOPCHECKSUM_AGG = 0x72
 
278
 
 
279
# param flags
 
280
fByRefValue = 1
 
281
fDefaultValue = 2
 
282
 
 
283
TDS_IDLE = 0
 
284
TDS_QUERYING = 1
 
285
TDS_PENDING = 2
 
286
TDS_READING = 3
 
287
TDS_DEAD = 4
 
288
state_names = ['IDLE', 'QUERYING', 'PENDING', 'READING', 'DEAD']
 
289
 
 
290
TDS_ENCRYPTION_OFF = 0
 
291
TDS_ENCRYPTION_REQUEST = 1
 
292
TDS_ENCRYPTION_REQUIRE = 2
 
293
 
 
294
USE_CORK = hasattr(socket, 'TCP_CORK')
 
295
 
 
296
TDS_NO_COUNT = -1
 
297
 
 
298
_utc = tz.utc
 
299
 
 
300
_header = struct.Struct('>BBHHBx')
 
301
_byte = struct.Struct('B')
 
302
_smallint_le = struct.Struct('<h')
 
303
_smallint_be = struct.Struct('>h')
 
304
_usmallint_le = struct.Struct('<H')
 
305
_usmallint_be = struct.Struct('>H')
 
306
_int_le = struct.Struct('<l')
 
307
_int_be = struct.Struct('>l')
 
308
_uint_le = struct.Struct('<L')
 
309
_uint_be = struct.Struct('>L')
 
310
_int8_le = struct.Struct('<q')
 
311
_int8_be = struct.Struct('>q')
 
312
_uint8_le = struct.Struct('<Q')
 
313
_uint8_be = struct.Struct('>Q')
 
314
_flt8_struct = struct.Struct('d')
 
315
_flt4_struct = struct.Struct('f')
 
316
 
 
317
 
 
318
PLP_MARKER = 0xffff
 
319
PLP_NULL = 0xffffffffffffffff
 
320
PLP_UNKNOWN = 0xfffffffffffffffe
 
321
 
 
322
 
 
323
class PlpReader(object):
 
324
    """ Partially length prefixed reader
 
325
 
 
326
    Spec: http://msdn.microsoft.com/en-us/library/dd340469.aspx
 
327
    """
 
328
    def __init__(self, r):
 
329
        """
 
330
        :param r: An instance of :class:`_TdsReader`
 
331
        """
 
332
        self._rdr = r
 
333
        size = r.get_uint8()
 
334
        self._size = size
 
335
 
 
336
    def is_null(self):
 
337
        """
 
338
        :return: True if stored value is NULL
 
339
        """
 
340
        return self._size == PLP_NULL
 
341
 
 
342
    def is_unknown_len(self):
 
343
        """
 
344
        :return: True if total size is unknown upfront
 
345
        """
 
346
        return self._size == PLP_UNKNOWN
 
347
 
 
348
    def size(self):
 
349
        """
 
350
        :return: Total size in bytes if is_uknown_len and is_null are both False
 
351
        """
 
352
        return self._size
 
353
 
 
354
    def chunks(self):
 
355
        """ Generates chunks from stream, each chunk is an instace of bytes.
 
356
        """
 
357
        if self.is_null():
 
358
            return
 
359
        total = 0
 
360
        while True:
 
361
            chunk_len = self._rdr.get_uint()
 
362
            if chunk_len == 0:
 
363
                if not self.is_unknown_len() and total != self._size:
 
364
                    msg = "PLP actual length (%d) doesn't match reported length (%d)" % (total, self._size)
 
365
                    self._rdr.session.bad_stream(msg)
 
366
 
 
367
                return
 
368
 
 
369
            total += chunk_len
 
370
            left = chunk_len
 
371
            while left:
 
372
                buf = self._rdr.read(left)
 
373
                yield buf
 
374
                left -= len(buf)
 
375
 
 
376
 
 
377
def iterdecode(iterable, codec):
 
378
    """ Uses an incremental decoder to decode each chunk in iterable.
 
379
    This function is a generator.
 
380
 
 
381
    :param codec: An instance of codec
 
382
    """
 
383
    decoder = codec.incrementaldecoder()
 
384
    for chunk in iterable:
 
385
        yield decoder.decode(chunk)
 
386
    yield decoder.decode(b'', True)
 
387
 
 
388
 
 
389
class SimpleLoadBalancer(object):
 
390
    def __init__(self, hosts):
 
391
        self._hosts = hosts
 
392
 
 
393
    def choose(self):
 
394
        for host in self._hosts:
 
395
            yield host
 
396
 
 
397
 
 
398
def force_unicode(s):
 
399
    if isinstance(s, bytes):
 
400
        try:
 
401
            return s.decode('utf8')
 
402
        except UnicodeDecodeError as e:
 
403
            raise DatabaseError(e)
 
404
    else:
 
405
        return s
 
406
 
 
407
 
 
408
def tds_quote_id(id):
 
409
    """ Quote an identifier
 
410
 
 
411
    :param id: id to quote
 
412
    :returns: Quoted identifier
 
413
    """
 
414
    return '[{0}]'.format(id.replace(']', ']]'))
 
415
 
 
416
 
 
417
def tds7_crypt_pass(password):
 
418
    """ Mangle password according to tds rules
 
419
 
 
420
    :param password: Password str
 
421
    :returns: Byte-string with encoded password
 
422
    """
 
423
    encoded = bytearray(ucs2_codec.encode(password)[0])
 
424
    for i, ch in enumerate(encoded):
 
425
        encoded[i] = ((ch << 4) & 0xff | (ch >> 4)) ^ 0xA5
 
426
    return encoded
 
427
 
 
428
 
 
429
def total_seconds(td):
 
430
    """ Total number of seconds in timedelta object
 
431
 
 
432
    Python 2.6 doesn't have total_seconds method, this function
 
433
    provides a backport
 
434
    """
 
435
    return td.days * 24 * 60 * 60 + td.seconds
 
436
 
 
437
 
 
438
# store a tuple of programming error codes
 
439
prog_errors = (
 
440
    102,    # syntax error
 
441
    207,    # invalid column name
 
442
    208,    # invalid object name
 
443
    2812,   # unknown procedure
 
444
    4104    # multi-part identifier could not be bound
 
445
)
 
446
 
 
447
# store a tuple of integrity error codes
 
448
integrity_errors = (
 
449
    515,    # NULL insert
 
450
    547,    # FK related
 
451
    2601,   # violate unique index
 
452
    2627,   # violate UNIQUE KEY constraint
 
453
)
 
454
 
 
455
 
 
456
if sys.version_info[0] >= 3:
 
457
    exc_base_class = Exception
 
458
 
 
459
    def _ord(val):
 
460
        return val
 
461
 
 
462
else:
 
463
    exc_base_class = StandardError
 
464
 
 
465
    def _ord(val):
 
466
        return ord(val)
 
467
 
 
468
 
 
469
def _decode_num(buf):
 
470
    """ Decodes little-endian integer from buffer
 
471
 
 
472
    Buffer can be of any size
 
473
    """
 
474
    return reduce(lambda acc, val: acc * 256 + _ord(val), reversed(buf), 0)
 
475
 
 
476
 
 
477
# exception hierarchy
 
478
class Warning(exc_base_class):
 
479
    pass
 
480
 
 
481
 
 
482
class Error(exc_base_class):
 
483
    pass
 
484
 
 
485
 
 
486
TimeoutError = socket.timeout
 
487
 
 
488
 
 
489
class InterfaceError(Error):
 
490
    pass
 
491
 
 
492
 
 
493
class DatabaseError(Error):
 
494
    @property
 
495
    def message(self):
 
496
        if self.procname:
 
497
            return 'SQL Server message %d, severity %d, state %d, ' \
 
498
                'procedure %s, line %d:\n%s' % (self.number,
 
499
                self.severity, self.state, self.procname,
 
500
                self.line, self.text)
 
501
        else:
 
502
            return 'SQL Server message %d, severity %d, state %d, ' \
 
503
                'line %d:\n%s' % (self.number, self.severity,
 
504
                self.state, self.line, self.text)
 
505
 
 
506
 
 
507
class ClosedConnectionError(InterfaceError):
 
508
    def __init__(self):
 
509
        super(ClosedConnectionError, self).__init__('Server closed connection')
 
510
 
 
511
 
 
512
class DataError(Error):
 
513
    pass
 
514
 
 
515
 
 
516
class OperationalError(DatabaseError):
 
517
    pass
 
518
 
 
519
 
 
520
class LoginError(OperationalError):
 
521
    pass
 
522
 
 
523
 
 
524
class IntegrityError(DatabaseError):
 
525
    pass
 
526
 
 
527
 
 
528
class InternalError(DatabaseError):
 
529
    pass
 
530
 
 
531
 
 
532
class ProgrammingError(DatabaseError):
 
533
    pass
 
534
 
 
535
 
 
536
class NotSupportedError(DatabaseError):
 
537
    pass
 
538
 
 
539
 
 
540
#############################
 
541
## DB-API type definitions ##
 
542
#############################
 
543
class DBAPITypeObject:
 
544
    def __init__(self, *values):
 
545
        self.values = set(values)
 
546
 
 
547
    def __eq__(self, other):
 
548
        return other in self.values
 
549
 
 
550
    def __cmp__(self, other):
 
551
        if other in self.values:
 
552
            return 0
 
553
        if other < self.values:
 
554
            return 1
 
555
        else:
 
556
            return -1
 
557
 
 
558
# standard dbapi type objects
 
559
STRING = DBAPITypeObject(SYBVARCHAR, SYBCHAR, SYBTEXT,
 
560
                         XSYBNVARCHAR, XSYBNCHAR, SYBNTEXT,
 
561
                         XSYBVARCHAR, XSYBCHAR, SYBMSXML)
 
562
BINARY = DBAPITypeObject(SYBIMAGE, SYBBINARY, SYBVARBINARY, XSYBVARBINARY, XSYBBINARY)
 
563
NUMBER = DBAPITypeObject(SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN,
 
564
                         SYBREAL, SYBFLT8, SYBFLTN)
 
565
DATETIME = DBAPITypeObject(SYBDATETIME, SYBDATETIME4, SYBDATETIMN)
 
566
DECIMAL = DBAPITypeObject(SYBMONEY, SYBMONEY4, SYBMONEYN, SYBNUMERIC,
 
567
                          SYBDECIMAL)
 
568
ROWID = DBAPITypeObject()
 
569
 
 
570
# non-standard, but useful type objects
 
571
INTEGER = DBAPITypeObject(SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN)
 
572
REAL = DBAPITypeObject(SYBREAL, SYBFLT8, SYBFLTN)
 
573
XML = DBAPITypeObject(SYBMSXML)
 
574
 
 
575
 
 
576
# stored procedure output parameter
 
577
class output(object):
 
578
    @property
 
579
    def type(self):
 
580
        """
 
581
        This is either the sql type declaration or python type instance
 
582
        of the parameter.
 
583
        """
 
584
        return self._type
 
585
 
 
586
    @property
 
587
    def value(self):
 
588
        """
 
589
        This is the value of the parameter.
 
590
        """
 
591
        return self._value
 
592
 
 
593
    def __init__(self, value=None, param_type=None):
 
594
        """ Creates procedure output parameter.
 
595
        
 
596
        :param param_type: either sql type declaration or python type
 
597
        :param value: value to pass into procedure
 
598
        """
 
599
        if param_type is None:
 
600
            if value is None or value is default:
 
601
                raise ValueError('Output type cannot be autodetected')
 
602
        elif isinstance(param_type, type) and value is not None:
 
603
            if value is not default and not isinstance(value, param_type):
 
604
                raise ValueError('value should match param_type', value, param_type)
 
605
        self._type = param_type
 
606
        self._value = value
 
607
 
 
608
 
 
609
class Binary(bytes):
 
610
    def __repr__(self):
 
611
        return 'Binary({0})'.format(super(Binary, self).__repr__())
 
612
 
 
613
 
 
614
class _Default(object):
 
615
    pass
 
616
 
 
617
default = _Default()
 
618
 
 
619
 
 
620
class InternalProc(object):
 
621
    def __init__(self, proc_id, name):
 
622
        self.proc_id = proc_id
 
623
        self.name = name
 
624
 
 
625
    def __unicode__(self):
 
626
        return self.name
 
627
 
 
628
SP_EXECUTESQL = InternalProc(TDS_SP_EXECUTESQL, 'sp_executesql')
 
629
 
 
630
 
 
631
class _TdsEnv:
 
632
    pass
 
633
 
 
634
 
 
635
def skipall(stm, size):
 
636
    """ Skips exactly size bytes in stm
 
637
 
 
638
    If EOF is reached before size bytes are skipped
 
639
    will raise :class:`ClosedConnectionError`
 
640
 
 
641
    :param stm: Stream to skip bytes in, should have read method
 
642
                this read method can return less than requested
 
643
                number of bytes.
 
644
    :param size: Number of bytes to skip.
 
645
    """
 
646
    res = stm.read(size)
 
647
    if len(res) == size:
 
648
        return
 
649
    elif len(res) == 0:
 
650
        raise ClosedConnectionError()
 
651
    left = size - len(res)
 
652
    while left:
 
653
        buf = stm.read(left)
 
654
        if len(buf) == 0:
 
655
            raise ClosedConnectionError()
 
656
        left -= len(buf)
 
657
 
 
658
 
 
659
def read_chunks(stm, size):
 
660
    """ Reads exactly size bytes from stm and produces chunks
 
661
 
 
662
    May call stm.read multiple times until required
 
663
    number of bytes is read.
 
664
    If EOF is reached before size bytes are read
 
665
    will raise :class:`ClosedConnectionError`
 
666
 
 
667
    :param stm: Stream to read bytes from, should have read method,
 
668
                this read method can return less than requested
 
669
                number of bytes.
 
670
    :param size: Number of bytes to read.
 
671
    """
 
672
    if size == 0:
 
673
        yield b''
 
674
        return
 
675
 
 
676
    res = stm.read(size)
 
677
    if len(res) == 0:
 
678
        raise ClosedConnectionError()
 
679
    yield res
 
680
    left = size - len(res)
 
681
    while left:
 
682
        buf = stm.read(left)
 
683
        if len(buf) == 0:
 
684
            raise ClosedConnectionError()
 
685
        yield buf
 
686
        left -= len(buf)
 
687
 
 
688
 
 
689
def readall(stm, size):
 
690
    """ Reads exactly size bytes from stm
 
691
 
 
692
    May call stm.read multiple times until required
 
693
    number of bytes read.
 
694
    If EOF is reached before size bytes are read
 
695
    will raise :class:`ClosedConnectionError`
 
696
 
 
697
    :param stm: Stream to read bytes from, should have read method
 
698
                this read method can return less than requested
 
699
                number of bytes.
 
700
    :param size: Number of bytes to read.
 
701
    :returns: Bytes buffer of exactly given size.
 
702
    """
 
703
    return b''.join(read_chunks(stm, size))
 
704
 
 
705
 
 
706
def readall_fast(stm, size):
 
707
    """
 
708
    Slightly faster version of readall, it reads no more than two chunks.
 
709
    Meaning that it can only be used to read small data that doesn't span
 
710
    more that two packets.
 
711
 
 
712
    :param stm: Stream to read from, should have read method.
 
713
    :param size: Number of bytes to read.
 
714
    :return:
 
715
    """
 
716
    buf, offset = stm.read_fast(size)
 
717
    if len(buf) - offset < size:
 
718
        # slow case
 
719
        buf = buf[offset:]
 
720
        buf += stm.read(size - len(buf))
 
721
        return buf, 0
 
722
    return buf, offset
 
723
 
 
724
 
 
725
class _TdsReader(object):
 
726
    """ TDS stream reader
 
727
 
 
728
    Provides stream-like interface for TDS packeted stream.
 
729
    Also provides convinience methods to decode primitive data like
 
730
    different kinds of integers etc.
 
731
    """
 
732
    def __init__(self, session):
 
733
        self._buf = ''
 
734
        self._pos = 0  # position in the buffer
 
735
        self._have = 0  # number of bytes read from packet
 
736
        self._size = 0  # size of current packet
 
737
        self._session = session
 
738
        self._transport = session._transport
 
739
        self._type = None
 
740
        self._status = None
 
741
 
 
742
    @property
 
743
    def session(self):
 
744
        """ Link to :class:`_TdsSession` object
 
745
        """
 
746
        return self._session
 
747
 
 
748
    @property
 
749
    def packet_type(self):
 
750
        """ Type of current packet
 
751
 
 
752
        Possible values are TDS_QUERY, TDS_LOGIN, etc.
 
753
        """
 
754
        return self._type
 
755
 
 
756
    def read_fast(self, size):
 
757
        """ Faster version of read
 
758
 
 
759
        Instead of returning sliced buffer it returns reference to internal
 
760
        buffer and the offset to this buffer.
 
761
 
 
762
        :param size: Number of bytes to read
 
763
        :returns: Tuple of bytes buffer, and offset in this buffer
 
764
        """
 
765
        if self._pos >= len(self._buf):
 
766
            if self._have >= self._size:
 
767
                self._read_packet()
 
768
            else:
 
769
                self._buf = self._transport.read(self._size - self._have)
 
770
                self._pos = 0
 
771
                self._have += len(self._buf)
 
772
        offset = self._pos
 
773
        self._pos += size
 
774
        return self._buf, offset
 
775
 
 
776
    def unpack(self, struct):
 
777
        """ Unpacks given structure from stream
 
778
 
 
779
        :param struct: A struct.Struct instance
 
780
        :returns: Result of unpacking
 
781
        """
 
782
        buf, offset = readall_fast(self, struct.size)
 
783
        return struct.unpack_from(buf, offset)
 
784
 
 
785
    def get_byte(self):
 
786
        """ Reads one byte from stream """
 
787
        return self.unpack(_byte)[0]
 
788
 
 
789
    def get_smallint(self):
 
790
        """ Reads 16bit signed integer from the stream """
 
791
        return self.unpack(_smallint_le)[0]
 
792
 
 
793
    def get_usmallint(self):
 
794
        """ Reads 16bit unsigned integer from the stream """
 
795
        return self.unpack(_usmallint_le)[0]
 
796
 
 
797
    def get_int(self):
 
798
        """ Reads 32bit signed integer from the stream """
 
799
        return self.unpack(_int_le)[0]
 
800
 
 
801
    def get_uint(self):
 
802
        """ Reads 32bit unsigned integer from the stream """
 
803
        return self.unpack(_uint_le)[0]
 
804
 
 
805
    def get_uint_be(self):
 
806
        """ Reads 32bit unsigned big-endian integer from the stream """
 
807
        return self.unpack(_uint_be)[0]
 
808
 
 
809
    def get_uint8(self):
 
810
        """ Reads 64bit unsigned integer from the stream """
 
811
        return self.unpack(_uint8_le)[0]
 
812
 
 
813
    def get_int8(self):
 
814
        """ Reads 64bit signed integer from the stream """
 
815
        return self.unpack(_int8_le)[0]
 
816
 
 
817
    def read_ucs2(self, num_chars):
 
818
        """ Reads num_chars UCS2 string from the stream """
 
819
        buf = readall(self, num_chars * 2)
 
820
        return ucs2_codec.decode(buf)[0]
 
821
 
 
822
    def read_str(self, size, codec):
 
823
        """ Reads byte string from the stream and decodes it
 
824
 
 
825
        :param size: Size of string in bytes
 
826
        :param codec: Instance of codec to decode string
 
827
        :returns: Unicode string
 
828
        """
 
829
        return codec.decode(readall(self, size))[0]
 
830
 
 
831
    def get_collation(self):
 
832
        """ Reads :class:`Collation` object from stream """
 
833
        buf = readall(self, Collation.wire_size)
 
834
        return Collation.unpack(buf)
 
835
 
 
836
    def unget_byte(self):
 
837
        """ Returns one last read byte to stream
 
838
 
 
839
        Can only be called once per read byte.
 
840
        """
 
841
        # this is a one trick pony...don't call it twice
 
842
        assert self._pos > 0
 
843
        self._pos -= 1
 
844
 
 
845
    def peek(self):
 
846
        """ Returns next byte from stream without consuming it
 
847
        """
 
848
        res = self.get_byte()
 
849
        self.unget_byte()
 
850
        return res
 
851
 
 
852
    def read(self, size):
 
853
        """ Reads size bytes from buffer
 
854
 
 
855
        May return fewer bytes than requested
 
856
        :param size: Number of bytes to read
 
857
        :returns: Bytes buffer, possibly shorter than requested,
 
858
                  returns empty buffer in case of EOF
 
859
        """
 
860
        buf, offset = self.read_fast(size)
 
861
        return buf[offset:offset + size]
 
862
 
 
863
    def _read_packet(self):
 
864
        """ Reads next TDS packet from the underlying transport
 
865
 
 
866
        If timeout is happened during reading of packet's header will
 
867
        cancel current request.
 
868
        Can only be called when transport's read pointer is at the begining
 
869
        of the packet.
 
870
        """
 
871
        try:
 
872
            header = readall(self._transport, _header.size)
 
873
        except TimeoutError:
 
874
            self._session._put_cancel()
 
875
            raise
 
876
        self._pos = 0
 
877
        self._type, self._status, self._size, self._session._spid, _ = _header.unpack(header)
 
878
        self._have = _header.size
 
879
        assert self._size > self._have, 'Empty packet doesn make any sense'
 
880
        self._buf = self._transport.read(self._size - self._have)
 
881
        self._have += len(self._buf)
 
882
 
 
883
    def read_whole_packet(self):
 
884
        """ Reads single packet and returns bytes payload of the packet
 
885
 
 
886
        Can only be called when transport's read pointer is at the beginning
 
887
        of the packet.
 
888
        """
 
889
        self._read_packet()
 
890
        return readall(self, self._size - _header.size)
 
891
 
 
892
 
 
893
class _TdsWriter(object):
 
894
    """ TDS stream writer
 
895
 
 
896
    Handles splitting of incoming data into TDS packets according to TDS protocol.
 
897
    Provides convinience methods for writing primitive data types.
 
898
    """
 
899
    def __init__(self, session, bufsize):
 
900
        self._session = session
 
901
        self._tds = session
 
902
        self._transport = session
 
903
        self._pos = 0
 
904
        self._buf = bytearray(bufsize)
 
905
        self._packet_no = 0
 
906
 
 
907
    @property
 
908
    def session(self):
 
909
        """ Back reference to parent :class:`_TdsSession` object """
 
910
        return self._session
 
911
 
 
912
    @property
 
913
    def bufsize(self):
 
914
        """ Size of the buffer """
 
915
        return len(self._buf)
 
916
 
 
917
    @bufsize.setter
 
918
    def bufsize(self, bufsize):
 
919
        if len(self._buf) == bufsize:
 
920
            return
 
921
 
 
922
        if bufsize > len(self._buf):
 
923
            self._buf.extend(b'\0' * (bufsize - len(self._buf)))
 
924
        else:
 
925
            self._buf = self._buf[0:bufsize]
 
926
 
 
927
    def begin_packet(self, packet_type):
 
928
        """ Starts new packet stream
 
929
 
 
930
        :param packet_type: Type of TDS stream, e.g. TDS_PRELOGIN, TDS_QUERY etc.
 
931
        """
 
932
        self._type = packet_type
 
933
        self._pos = 8
 
934
 
 
935
    def pack(self, struct, *args):
 
936
        """ Packs and writes structure into stream """
 
937
        self.write(struct.pack(*args))
 
938
 
 
939
    def put_byte(self, value):
 
940
        """ Writes single byte into stream """
 
941
        self.pack(_byte, value)
 
942
 
 
943
    def put_smallint(self, value):
 
944
        """ Writes 16-bit signed integer into the stream """
 
945
        self.pack(_smallint_le, value)
 
946
 
 
947
    def put_usmallint(self, value):
 
948
        """ Writes 16-bit unsigned integer into the stream """
 
949
        self.pack(_usmallint_le, value)
 
950
 
 
951
    def put_smallint_be(self, value):
 
952
        """ Writes 16-bit signed big-endian integer into the stream """
 
953
        self.pack(_smallint_be, value)
 
954
 
 
955
    def put_usmallint_be(self, value):
 
956
        """ Writes 16-bit unsigned big-endian integer into the stream """
 
957
        self.pack(_usmallint_be, value)
 
958
 
 
959
    def put_int(self, value):
 
960
        """ Writes 32-bit signed integer into the stream """
 
961
        self.pack(_int_le, value)
 
962
 
 
963
    def put_uint(self, value):
 
964
        """ Writes 32-bit unsigned integer into the stream """
 
965
        self.pack(_uint_le, value)
 
966
 
 
967
    def put_int_be(self, value):
 
968
        """ Writes 32-bit signed big-endian integer into the stream """
 
969
        self.pack(_int_be, value)
 
970
 
 
971
    def put_uint_be(self, value):
 
972
        """ Writes 32-bit unsigned big-endian integer into the stream """
 
973
        self.pack(_uint_be, value)
 
974
 
 
975
    def put_int8(self, value):
 
976
        """ Writes 64-bit signed integer into the stream """
 
977
        self.pack(_int8_le, value)
 
978
 
 
979
    def put_uint8(self, value):
 
980
        """ Writes 64-bit unsigned integer into the stream """
 
981
        self.pack(_uint8_le, value)
 
982
 
 
983
    def put_collation(self, collation):
 
984
        """ Writes :class:`Collation` structure into the stream """
 
985
        self.write(collation.pack())
 
986
 
 
987
    def write(self, data):
 
988
        """ Writes given bytes buffer into the stream
 
989
 
 
990
        Function returns only when entire buffer is written
 
991
        """
 
992
        data_off = 0
 
993
        while data_off < len(data):
 
994
            left = len(self._buf) - self._pos
 
995
            if left <= 0:
 
996
                self._write_packet(final=False)
 
997
            else:
 
998
                to_write = min(left, len(data) - data_off)
 
999
                self._buf[self._pos:self._pos + to_write] = data[data_off:data_off + to_write]
 
1000
                self._pos += to_write
 
1001
                data_off += to_write
 
1002
 
 
1003
    def write_ucs2(self, s):
 
1004
        """ Write string encoding it in UCS2 into stream """
 
1005
        self.write_string(s, ucs2_codec)
 
1006
 
 
1007
    def write_string(self, s, codec):
 
1008
        """ Write string encoding it with codec into stream """
 
1009
        for i in xrange(0, len(s), self.bufsize):
 
1010
            chunk = s[i:i + self.bufsize]
 
1011
            buf, consumed = codec.encode(chunk)
 
1012
            assert consumed == len(chunk)
 
1013
            self.write(buf)
 
1014
 
 
1015
    def flush(self):
 
1016
        """ Closes current packet stream """
 
1017
        return self._write_packet(final=True)
 
1018
 
 
1019
    def _write_packet(self, final):
 
1020
        """ Writes single TDS packet into underlying transport.
 
1021
 
 
1022
        Data for the packet is taken from internal buffer.
 
1023
 
 
1024
        :param final: True means this is the final packet in substream.
 
1025
        """
 
1026
        status = 1 if final else 0
 
1027
        _header.pack_into(self._buf, 0, self._type, status, self._pos, 0, self._packet_no)
 
1028
        self._packet_no = (self._packet_no + 1) % 256
 
1029
        self._transport.send(self._buf[:self._pos], final)
 
1030
        self._pos = 8
 
1031
 
 
1032
 
 
1033
class MemoryChunkedHandler(object):
 
1034
    def begin(self, column, size):
 
1035
        self.size = size
 
1036
        self._chunks = []
 
1037
 
 
1038
    def new_chunk(self, val):
 
1039
        #logger.debug('MemoryChunkedHandler.new_chunk(sz=%d)', len(val))
 
1040
        self._chunks.append(val)
 
1041
 
 
1042
    def end(self):
 
1043
        return b''.join(self._chunks)
 
1044
 
 
1045
 
 
1046
class MemoryStrChunkedHandler(object):
 
1047
    def begin(self, column, size):
 
1048
        self.size = size
 
1049
        self._chunks = []
 
1050
 
 
1051
    def new_chunk(self, val):
 
1052
        #logger.debug('MemoryChunkedHandler.new_chunk(sz=%d)', len(val))
 
1053
        self._chunks.append(val)
 
1054
 
 
1055
    def end(self):
 
1056
        return ''.join(self._chunks)
 
1057
 
 
1058
 
 
1059
class BaseType(object):
 
1060
    """ Base type for TDS data types.
 
1061
 
 
1062
    All TDS types should derive from it.
 
1063
    In addition actual types should provide the following:
 
1064
 
 
1065
    - type - class variable storing type identifier
 
1066
    """
 
1067
    def get_typeid(self):
 
1068
        """ Returns type identifier of type. """
 
1069
        return self.type
 
1070
 
 
1071
    def get_declaration(self):
 
1072
        """ Returns SQL declaration for this type.
 
1073
        
 
1074
        Examples are: NVARCHAR(10), TEXT, TINYINT
 
1075
        Should be implemented in actual types.
 
1076
        """
 
1077
        raise NotImplementedError
 
1078
 
 
1079
    @classmethod
 
1080
    def from_declaration(cls, declaration, nullable, connection):
 
1081
        """ Class method that parses declaration and returns a type instance.
 
1082
 
 
1083
        :param declaration: type declaration string
 
1084
        :param nullable: true if type have to be nullable, false otherwise
 
1085
        :param connection: instance of :class:`_TdsSocket`
 
1086
        :return: If declaration is parsed, returns type instance,
 
1087
                 otherwise returns None.
 
1088
 
 
1089
        Should be implemented in actual types.
 
1090
        """
 
1091
        raise NotImplementedError
 
1092
 
 
1093
    @classmethod
 
1094
    def from_stream(cls, r):
 
1095
        """ Class method that reads and returns a type instance.
 
1096
 
 
1097
        :param r: An instance of :class:`_TdsReader` to read type from.
 
1098
 
 
1099
        Should be implemented in actual types.
 
1100
        """
 
1101
        raise NotImplementedError
 
1102
 
 
1103
    def write_info(self, w):
 
1104
        """ Writes type info into w stream.
 
1105
 
 
1106
        :param w: An instance of :class:`_TdsWriter` to write into.
 
1107
 
 
1108
        Should be symmetrical to from_stream method.
 
1109
        Should be implemented in actual types.
 
1110
        """
 
1111
        raise NotImplementedError
 
1112
 
 
1113
    def write(self, w, value):
 
1114
        """ Writes type's value into stream
 
1115
 
 
1116
        :param w: An instance of :class:`_TdsWriter` to write into.
 
1117
        :param value: A value to be stored, should be compatible with the type
 
1118
 
 
1119
        Should be implemented in actual types.
 
1120
        """
 
1121
        raise NotImplementedError
 
1122
 
 
1123
    def read(self, r):
 
1124
        """ Reads value from the stream.
 
1125
 
 
1126
        :param r: An instance of :class:`_TdsReader` to read value from.
 
1127
        :return: A read value.
 
1128
 
 
1129
        Should be implemented in actual types.
 
1130
        """
 
1131
        raise NotImplementedError
 
1132
 
 
1133
    
 
1134
class BasePrimitiveType(BaseType):
 
1135
    """ Base type for primitive TDS data types.
 
1136
 
 
1137
    Primitive type is a fixed size type with no type arguments.
 
1138
    All primitive TDS types should derive from it.
 
1139
    In addition actual types should provide the following:
 
1140
 
 
1141
    - type - class variable storing type identifier
 
1142
    - declaration - class variable storing name of sql type
 
1143
    - isntance - class variable storing instance of class
 
1144
    """
 
1145
 
 
1146
    def get_declaration(self):
 
1147
        return self.declaration
 
1148
 
 
1149
    @classmethod
 
1150
    def from_declaration(cls, declaration, nullable, connection):
 
1151
        if not nullable and declaration == cls.declaration:
 
1152
            return cls.instance
 
1153
 
 
1154
    @classmethod
 
1155
    def from_stream(cls, r):
 
1156
        return cls.instance
 
1157
 
 
1158
    def write_info(self, w):
 
1159
        pass
 
1160
 
 
1161
 
 
1162
class BaseTypeN(BaseType):
 
1163
    """ Base type for nullable TDS data types.
 
1164
 
 
1165
    All nullable TDS types should derive from it.
 
1166
    In addition actual types should provide the following:
 
1167
 
 
1168
    - type - class variable storing type identifier
 
1169
    - subtypes - class variable storing dict {subtype_size: subtype_instance}
 
1170
    """
 
1171
 
 
1172
    def __init__(self, size):
 
1173
        assert size in self.subtypes
 
1174
        self._size = size
 
1175
        self._current_subtype = self.subtypes[size]
 
1176
 
 
1177
    def get_typeid(self):
 
1178
        return self._current_subtype.get_typeid()
 
1179
 
 
1180
    def get_declaration(self):
 
1181
        return self._current_subtype.get_declaration()
 
1182
 
 
1183
    @classmethod
 
1184
    def from_declaration(cls, declaration, nullable, connection):
 
1185
        if nullable:
 
1186
            for size, subtype in cls.subtypes.items():
 
1187
                inst = subtype.from_declaration(declaration, False, connection)
 
1188
                if inst:
 
1189
                    return cls(size)
 
1190
    
 
1191
    @classmethod
 
1192
    def from_stream(cls, r):
 
1193
        size = r.get_byte()
 
1194
        if size not in cls.subtypes:
 
1195
            raise InterfaceError('Invalid %s size' % cls.type, size)
 
1196
        return cls(size)
 
1197
 
 
1198
    def write_info(self, w):
 
1199
        w.put_byte(self._size)
 
1200
 
 
1201
    def read(self, r):
 
1202
        size = r.get_byte()
 
1203
        if size == 0:
 
1204
            return None
 
1205
        if size not in self.subtypes:
 
1206
            raise r.session.bad_stream('Invalid %s size' % self.type, size)
 
1207
        return self.subtypes[size].read(r)
 
1208
 
 
1209
    def write(self, w, val):
 
1210
        if val is None:
 
1211
            w.put_byte(0)
 
1212
            return
 
1213
        w.put_byte(self._size)
 
1214
        self._current_subtype.write(w, val)
 
1215
 
 
1216
class Bit(BasePrimitiveType):
 
1217
    type = SYBBIT
 
1218
    declaration = 'BIT'
 
1219
 
 
1220
    def write(self, w, value):
 
1221
        w.put_byte(1 if value else 0)
 
1222
 
 
1223
    def read(self, r):
 
1224
        return bool(r.get_byte())
 
1225
 
 
1226
Bit.instance = Bit()
 
1227
 
 
1228
 
 
1229
class BitN(BaseTypeN):
 
1230
    type = SYBBITN
 
1231
    subtypes = {1 : Bit.instance}
 
1232
    
 
1233
BitN.instance = BitN(1)
 
1234
 
 
1235
 
 
1236
class TinyInt(BasePrimitiveType):
 
1237
    type = SYBINT1
 
1238
    declaration = 'TINYINT'
 
1239
 
 
1240
    def write(self, w, val):
 
1241
        w.put_byte(val)
 
1242
 
 
1243
    def read(self, r):
 
1244
        return r.get_byte()
 
1245
    
 
1246
TinyInt.instance = TinyInt()
 
1247
 
 
1248
 
 
1249
class SmallInt(BasePrimitiveType):
 
1250
    type = SYBINT2
 
1251
    declaration = 'SMALLINT'
 
1252
 
 
1253
    def write(self, w, val):
 
1254
        w.put_smallint(val)
 
1255
 
 
1256
    def read(self, r):
 
1257
        return r.get_smallint()
 
1258
    
 
1259
SmallInt.instance = SmallInt()
 
1260
 
 
1261
 
 
1262
class Int(BasePrimitiveType):
 
1263
    type = SYBINT4
 
1264
    declaration = 'INT'
 
1265
 
 
1266
    def write(self, w, val):
 
1267
        w.put_int(val)
 
1268
 
 
1269
    def read(self, r):
 
1270
        return r.get_int()
 
1271
    
 
1272
Int.instance = Int()
 
1273
 
 
1274
 
 
1275
class BigInt(BasePrimitiveType):
 
1276
    type = SYBINT8
 
1277
    declaration = 'BIGINT'
 
1278
 
 
1279
    def write(self, w, val):
 
1280
        w.put_int8(val)
 
1281
 
 
1282
    def read(self, r):
 
1283
        return r.get_int8()
 
1284
 
 
1285
BigInt.instance = BigInt()
 
1286
 
 
1287
 
 
1288
class IntN(BaseTypeN):
 
1289
    type = SYBINTN
 
1290
    
 
1291
    subtypes = {
 
1292
        1: TinyInt.instance,
 
1293
        2: SmallInt.instance,
 
1294
        4: Int.instance,
 
1295
        8: BigInt.instance,
 
1296
        }
 
1297
 
 
1298
    
 
1299
class Real(BasePrimitiveType):
 
1300
    type = SYBREAL
 
1301
    declaration = 'REAL'
 
1302
 
 
1303
    def write(self, w, val):
 
1304
        w.pack(_flt4_struct, val)
 
1305
 
 
1306
    def read(self, r):
 
1307
        return r.unpack(_flt4_struct)[0]
 
1308
    
 
1309
Real.instance = Real()
 
1310
 
 
1311
 
 
1312
class Float(BasePrimitiveType):
 
1313
    type = SYBFLT8
 
1314
    declaration = 'FLOAT'
 
1315
 
 
1316
    def write(self, w, val):
 
1317
        w.pack(_flt8_struct, val)
 
1318
 
 
1319
    def read(self, r):
 
1320
        return r.unpack(_flt8_struct)[0]
 
1321
    
 
1322
Float.instance = Float()
 
1323
 
 
1324
 
 
1325
class FloatN(BaseTypeN):
 
1326
    type = SYBFLTN
 
1327
    
 
1328
    subtypes = {
 
1329
        4: Real.instance,
 
1330
        8: Float.instance,
 
1331
        }
 
1332
 
 
1333
    
 
1334
class VarChar70(BaseType):
 
1335
    type = XSYBVARCHAR
 
1336
 
 
1337
    def __init__(self, size, codec):
 
1338
        #if size <= 0 or size > 8000:
 
1339
        #    raise DataError('Invalid size for VARCHAR field')
 
1340
        self._size = size
 
1341
        self._codec = codec
 
1342
 
 
1343
    @classmethod
 
1344
    def from_stream(cls, r):
 
1345
        size = r.get_smallint()
 
1346
        return cls(size, codec=r._session.conn.server_codec)
 
1347
 
 
1348
    @classmethod
 
1349
    def from_declaration(cls, declaration, nullable, connection):
 
1350
        m = re.match(r'VARCHAR\((\d+)\)', declaration)
 
1351
        if m:
 
1352
            return cls(int(m.group(1)), connection.server_codec)
 
1353
 
 
1354
    def get_declaration(self):
 
1355
        return 'VARCHAR({0})'.format(self._size)
 
1356
 
 
1357
    def write_info(self, w):
 
1358
        w.put_smallint(self._size)
 
1359
        #w.put_smallint(self._size)
 
1360
 
 
1361
    def write(self, w, val):
 
1362
        if val is None:
 
1363
            w.put_smallint(-1)
 
1364
        else:
 
1365
            val = force_unicode(val)
 
1366
            val, _ = self._codec.encode(val)
 
1367
            w.put_smallint(len(val))
 
1368
            #w.put_smallint(len(val))
 
1369
            w.write(val)
 
1370
 
 
1371
    def read(self, r):
 
1372
        size = r.get_smallint()
 
1373
        if size < 0:
 
1374
            return None
 
1375
        return r.read_str(size, self._codec)
 
1376
 
 
1377
 
 
1378
class VarChar71(VarChar70):
 
1379
    def __init__(self, size, collation):
 
1380
        super(VarChar71, self).__init__(size, codec=collation.get_codec())
 
1381
        self._collation = collation
 
1382
 
 
1383
    @classmethod
 
1384
    def from_stream(cls, r):
 
1385
        size = r.get_smallint()
 
1386
        collation = r.get_collation()
 
1387
        return cls(size, collation)
 
1388
 
 
1389
    @classmethod
 
1390
    def from_declaration(cls, declaration, nullable, connection):
 
1391
        m = re.match(r'VARCHAR\((\d+)\)', declaration)
 
1392
        if m:
 
1393
            return cls(int(m.group(1)), connection.collation)
 
1394
 
 
1395
    def write_info(self, w):
 
1396
        super(VarChar71, self).write_info(w)
 
1397
        w.put_collation(self._collation)
 
1398
 
 
1399
 
 
1400
class VarChar72(VarChar71):
 
1401
    @classmethod
 
1402
    def from_stream(cls, r):
 
1403
        size = r.get_usmallint()
 
1404
        collation = r.get_collation()
 
1405
        if size == 0xffff:
 
1406
            return VarCharMax(collation)
 
1407
        return cls(size, collation)
 
1408
 
 
1409
    @classmethod
 
1410
    def from_declaration(cls, declaration, nullable, connection):
 
1411
        if declaration == 'VARCHAR(MAX)':
 
1412
            return VarCharMax(connection.collation)
 
1413
        m = re.match(r'VARCHAR\((\d+)\)', declaration)
 
1414
        if m:
 
1415
            return cls(int(m.group(1)), connection.collation)
 
1416
 
 
1417
 
 
1418
class VarCharMax(VarChar72):
 
1419
    def __init__(self, collation):
 
1420
        super(VarChar72, self).__init__(0, collation)
 
1421
 
 
1422
    def get_declaration(self):
 
1423
        return 'VARCHAR(MAX)'
 
1424
 
 
1425
    def write_info(self, w):
 
1426
        w.put_usmallint(PLP_MARKER)
 
1427
        w.put_collation(self._collation)
 
1428
 
 
1429
    def write(self, w, val):
 
1430
        if val is None:
 
1431
            w.put_uint8(PLP_NULL)
 
1432
        else:
 
1433
            val = force_unicode(val)
 
1434
            val, _ = self._codec.encode(val)
 
1435
            w.put_int8(len(val))
 
1436
            if len(val) > 0:
 
1437
                w.put_int(len(val))
 
1438
                w.write(val)
 
1439
            w.put_int(0)
 
1440
 
 
1441
    def read(self, r):
 
1442
        r = PlpReader(r)
 
1443
        if r.is_null():
 
1444
            return None
 
1445
        return ''.join(iterdecode(r.chunks(), self._codec))
 
1446
 
 
1447
 
 
1448
class NVarChar70(BaseType):
 
1449
    type = XSYBNVARCHAR
 
1450
 
 
1451
    def __init__(self, size):
 
1452
        #if size <= 0 or size > 4000:
 
1453
        #    raise DataError('Invalid size for NVARCHAR field')
 
1454
        self._size = size
 
1455
 
 
1456
    @classmethod
 
1457
    def from_stream(cls, r):
 
1458
        size = r.get_usmallint()
 
1459
        return cls(size / 2)
 
1460
 
 
1461
    @classmethod
 
1462
    def from_declaration(cls, declaration, nullable, connection):
 
1463
        m = re.match(r'NVARCHAR\((\d+)\)', declaration)
 
1464
        if m:
 
1465
            return cls(int(m.group(1)))
 
1466
 
 
1467
    def get_declaration(self):
 
1468
        return 'NVARCHAR({0})'.format(self._size)
 
1469
 
 
1470
    def write_info(self, w):
 
1471
        w.put_usmallint(self._size * 2)
 
1472
        #w.put_smallint(self._size)
 
1473
 
 
1474
    def write(self, w, val):
 
1475
        if val is None:
 
1476
            w.put_usmallint(0xffff)
 
1477
        else:
 
1478
            if isinstance(val, bytes):
 
1479
                val = force_unicode(val)
 
1480
            buf, _ = ucs2_codec.encode(val)
 
1481
            l = len(buf)
 
1482
            w.put_usmallint(l)
 
1483
            w.write(buf)
 
1484
 
 
1485
    def read(self, r):
 
1486
        size = r.get_usmallint()
 
1487
        if size == 0xffff:
 
1488
            return None
 
1489
        return r.read_str(size, ucs2_codec)
 
1490
 
 
1491
 
 
1492
class NVarChar71(NVarChar70):
 
1493
    def __init__(self, size, collation=raw_collation):
 
1494
        super(NVarChar71, self).__init__(size)
 
1495
        self._collation = collation
 
1496
 
 
1497
    @classmethod
 
1498
    def from_stream(cls, r):
 
1499
        size = r.get_usmallint()
 
1500
        collation = r.get_collation()
 
1501
        return cls(size / 2, collation)
 
1502
 
 
1503
    @classmethod
 
1504
    def from_declaration(cls, declaration, nullable, connection):
 
1505
        m = re.match(r'NVARCHAR\((\d+)\)', declaration)
 
1506
        if m:
 
1507
            return cls(int(m.group(1)), connection.collation)
 
1508
 
 
1509
    def write_info(self, w):
 
1510
        super(NVarChar71, self).write_info(w)
 
1511
        w.put_collation(self._collation)
 
1512
 
 
1513
 
 
1514
class NVarChar72(NVarChar71):
 
1515
    @classmethod
 
1516
    def from_stream(cls, r):
 
1517
        size = r.get_usmallint()
 
1518
        collation = r.get_collation()
 
1519
        if size == 0xffff:
 
1520
            return NVarCharMax(size, collation)
 
1521
        return cls(size / 2, collation=collation)
 
1522
 
 
1523
    @classmethod
 
1524
    def from_declaration(cls, declaration, nullable, connection):
 
1525
        if declaration == 'NVARCHAR(MAX)':
 
1526
            return VarCharMax(connection.collation)
 
1527
        m = re.match(r'NVARCHAR\((\d+)\)', declaration)
 
1528
        if m:
 
1529
            return cls(int(m.group(1)), connection.collation)
 
1530
 
 
1531
 
 
1532
class NVarCharMax(NVarChar72):
 
1533
    def get_typeid(self):
 
1534
        return SYBNTEXT
 
1535
 
 
1536
    def get_declaration(self):
 
1537
        return 'NVARCHAR(MAX)'
 
1538
 
 
1539
    def write_info(self, w):
 
1540
        w.put_usmallint(PLP_MARKER)
 
1541
        w.put_collation(self._collation)
 
1542
 
 
1543
    def write(self, w, val):
 
1544
        if val is None:
 
1545
            w.put_uint8(PLP_NULL)
 
1546
        else:
 
1547
            if isinstance(val, bytes):
 
1548
                val = force_unicode(val)
 
1549
            val, _ = ucs2_codec.encode(val)
 
1550
            w.put_uint8(len(val))
 
1551
            if len(val) > 0:
 
1552
                w.put_uint(len(val))
 
1553
                w.write(val)
 
1554
            w.put_uint(0)
 
1555
 
 
1556
    def read(self, r):
 
1557
        r = PlpReader(r)
 
1558
        if r.is_null():
 
1559
            return None
 
1560
        res = ''.join(iterdecode(r.chunks(), ucs2_codec))
 
1561
        return res
 
1562
 
 
1563
 
 
1564
class Xml(NVarCharMax):
 
1565
    type = SYBMSXML
 
1566
    declaration = 'XML'
 
1567
 
 
1568
    def __init__(self, schema={}):
 
1569
        super(Xml, self).__init__(0)
 
1570
        self._schema = schema
 
1571
 
 
1572
    def get_typeid(self):
 
1573
        return self.type
 
1574
 
 
1575
    def get_declaration(self):
 
1576
        return self.declaration
 
1577
 
 
1578
    @classmethod
 
1579
    def from_stream(cls, r):
 
1580
        has_schema = r.get_byte()
 
1581
        schema = {}
 
1582
        if has_schema:
 
1583
            schema['dbname'] = r.read_ucs2(r.get_byte())
 
1584
            schema['owner'] = r.read_ucs2(r.get_byte())
 
1585
            schema['collection'] = r.read_ucs2(r.get_smallint())
 
1586
        return cls(schema)
 
1587
 
 
1588
    @classmethod
 
1589
    def from_declaration(cls, declaration, nullable, connection):
 
1590
        if declaration == cls.declaration:
 
1591
            return cls()
 
1592
 
 
1593
    def write_info(self, w):
 
1594
        if self._schema:
 
1595
            w.put_byte(1)
 
1596
            w.put_byte(len(self._schema['dbname']))
 
1597
            w.write_ucs2(self._schema['dbname'])
 
1598
            w.put_byte(len(self._schema['owner']))
 
1599
            w.write_ucs2(self._schema['owner'])
 
1600
            w.put_usmallint(len(self._schema['collection']))
 
1601
            w.write_ucs2(self._schema['collection'])
 
1602
        else:
 
1603
            w.put_byte(0)
 
1604
 
 
1605
 
 
1606
class Text70(BaseType):
 
1607
    type = SYBTEXT
 
1608
    declaration = 'TEXT'
 
1609
 
 
1610
    def __init__(self, size=0, table_name='', codec=None):
 
1611
        self._size = size
 
1612
        self._table_name = table_name
 
1613
        self._codec = codec
 
1614
 
 
1615
    @classmethod
 
1616
    def from_stream(cls, r):
 
1617
        size = r.get_int()
 
1618
        table_name = r.read_ucs2(r.get_smallint())
 
1619
        return cls(size, table_name, codec=r.session.conn.server_codec)
 
1620
 
 
1621
    @classmethod
 
1622
    def from_declaration(cls, declaration, nullable, connection):
 
1623
        if declaration == cls.declaration:
 
1624
            return cls()
 
1625
    
 
1626
    def get_declaration(self):
 
1627
        return self.declaration
 
1628
 
 
1629
    def write_info(self, w):
 
1630
        w.put_int(self._size)
 
1631
 
 
1632
    def write(self, w, val):
 
1633
        if val is None:
 
1634
            w.put_int(-1)
 
1635
        else:
 
1636
            val = force_unicode(val)
 
1637
            val, _ = self._codec.encode(val)
 
1638
            w.put_int(len(val))
 
1639
            w.write(val)
 
1640
 
 
1641
    def read(self, r):
 
1642
        size = r.get_byte()
 
1643
        if size == 0:
 
1644
            return None
 
1645
        readall(r, size)  # textptr
 
1646
        readall(r, 8)  # timestamp
 
1647
        colsize = r.get_int()
 
1648
        return r.read_str(colsize, self._codec)
 
1649
 
 
1650
 
 
1651
class Text71(Text70):
 
1652
    def __init__(self, size=0, table_name='', collation=raw_collation):
 
1653
        self._size = size
 
1654
        self._collation = collation
 
1655
        self._codec = collation.get_codec()
 
1656
        self._table_name = table_name
 
1657
 
 
1658
    @classmethod
 
1659
    def from_stream(cls, r):
 
1660
        size = r.get_int()
 
1661
        collation = r.get_collation()
 
1662
        table_name = r.read_ucs2(r.get_smallint())
 
1663
        return cls(size, table_name, collation)
 
1664
 
 
1665
    def write_info(self, w):
 
1666
        w.put_int(self._size)
 
1667
        w.put_collation(self._collation)
 
1668
 
 
1669
 
 
1670
class Text72(Text71):
 
1671
    def __init__(self, size=0, table_name_parts=[], collation=raw_collation):
 
1672
        super(Text72, self).__init__(size, '.'.join(table_name_parts), collation)
 
1673
        self._table_name_parts = table_name_parts
 
1674
 
 
1675
    @classmethod
 
1676
    def from_stream(cls, r):
 
1677
        size = r.get_int()
 
1678
        collation = r.get_collation()
 
1679
        num_parts = r.get_byte()
 
1680
        parts = []
 
1681
        for _ in range(num_parts):
 
1682
            parts.append(r.read_ucs2(r.get_smallint()))
 
1683
        return cls(size, parts, collation)
 
1684
 
 
1685
 
 
1686
class NText70(BaseType):
 
1687
    type = SYBNTEXT
 
1688
    declaration = 'NTEXT'
 
1689
 
 
1690
    def __init__(self, size=0, table_name=''):
 
1691
        self._size = size
 
1692
        self._table_name = table_name
 
1693
 
 
1694
    @classmethod
 
1695
    def from_stream(cls, r):
 
1696
        size = r.get_int()
 
1697
        table_name = r.read_ucs2(r.get_smallint())
 
1698
        return cls(size, table_name)
 
1699
 
 
1700
    @classmethod
 
1701
    def from_declaration(cls, declaration, nullable, connection):
 
1702
        if declaration == cls.declaration:
 
1703
            return cls()
 
1704
    
 
1705
    def get_declaration(self):
 
1706
        return self.declaration
 
1707
 
 
1708
    def read(self, r):
 
1709
        textptr_size = r.get_byte()
 
1710
        if textptr_size == 0:
 
1711
            return None
 
1712
        readall(r, textptr_size)  # textptr
 
1713
        readall(r, 8)  # timestamp
 
1714
        colsize = r.get_int()
 
1715
        return r.read_str(colsize, ucs2_codec)
 
1716
 
 
1717
    def write_info(self, w):
 
1718
        w.put_int(self._size * 2)
 
1719
 
 
1720
    def write(self, w, val):
 
1721
        if val is None:
 
1722
            w.put_int(-1)
 
1723
        else:
 
1724
            w.put_int(len(val) * 2)
 
1725
            w.write_ucs2(val)
 
1726
 
 
1727
 
 
1728
class NText71(NText70):
 
1729
    def __init__(self, size=0, table_name='', collation=raw_collation):
 
1730
        self._size = size
 
1731
        self._collation = collation
 
1732
        self._table_name = table_name
 
1733
 
 
1734
    @classmethod
 
1735
    def from_stream(cls, r):
 
1736
        size = r.get_int()
 
1737
        collation = r.get_collation()
 
1738
        table_name = r.read_ucs2(r.get_smallint())
 
1739
        return cls(size, table_name, collation)
 
1740
 
 
1741
    def write_info(self, w):
 
1742
        w.put_int(self._size)
 
1743
        w.put_collation(self._collation)
 
1744
 
 
1745
    def read(self, r):
 
1746
        textptr_size = r.get_byte()
 
1747
        if textptr_size == 0:
 
1748
            return None
 
1749
        readall(r, textptr_size)  # textptr
 
1750
        readall(r, 8)  # timestamp
 
1751
        colsize = r.get_int()
 
1752
        return r.read_str(colsize, ucs2_codec)
 
1753
 
 
1754
 
 
1755
class NText72(NText71):
 
1756
    def __init__(self, size=0, table_name_parts=[], collation=raw_collation):
 
1757
        self._size = size
 
1758
        self._collation = collation
 
1759
        self._table_name_parts = table_name_parts
 
1760
 
 
1761
    @classmethod
 
1762
    def from_stream(cls, r):
 
1763
        size = r.get_int()
 
1764
        collation = r.get_collation()
 
1765
        num_parts = r.get_byte()
 
1766
        parts = []
 
1767
        for _ in range(num_parts):
 
1768
            parts.append(r.read_ucs2(r.get_smallint()))
 
1769
        return cls(size, parts, collation)
 
1770
 
 
1771
 
 
1772
class VarBinary(BaseType):
 
1773
    type = XSYBVARBINARY
 
1774
 
 
1775
    def __init__(self, size):
 
1776
        self._size = size
 
1777
 
 
1778
    @classmethod
 
1779
    def from_stream(cls, r):
 
1780
        size = r.get_usmallint()
 
1781
        return cls(size)
 
1782
 
 
1783
    @classmethod
 
1784
    def from_declaration(cls, declaration, nullable, connection):
 
1785
        m = re.match(r'VARBINARY\((\d+)\)', declaration)
 
1786
        if m:
 
1787
            return cls(int(m.group(1)))
 
1788
    
 
1789
    def get_declaration(self):
 
1790
        return 'VARBINARY({0})'.format(self._size)
 
1791
 
 
1792
    def write_info(self, w):
 
1793
        w.put_usmallint(self._size)
 
1794
 
 
1795
    def write(self, w, val):
 
1796
        if val is None:
 
1797
            w.put_usmallint(0xffff)
 
1798
        else:
 
1799
            w.put_usmallint(len(val))
 
1800
            w.write(val)
 
1801
 
 
1802
    def read(self, r):
 
1803
        size = r.get_usmallint()
 
1804
        if size == 0xffff:
 
1805
            return None
 
1806
        return readall(r, size)
 
1807
 
 
1808
 
 
1809
class VarBinary72(VarBinary):
 
1810
    @classmethod
 
1811
    def from_stream(cls, r):
 
1812
        size = r.get_usmallint()
 
1813
        if size == 0xffff:
 
1814
            return VarBinaryMax()
 
1815
        return cls(size)
 
1816
 
 
1817
    @classmethod
 
1818
    def from_declaration(cls, declaration, nullable, connection):
 
1819
        if declaration == 'VARBINARY(MAX)':
 
1820
            return VarBinaryMax()
 
1821
        m = re.match(r'VARBINARY\((\d+)\)', declaration)
 
1822
        if m:
 
1823
            return cls(int(m.group(1)))
 
1824
 
 
1825
 
 
1826
class VarBinaryMax(VarBinary):
 
1827
    def __init__(self):
 
1828
        super(VarBinaryMax, self).__init__(0)
 
1829
 
 
1830
    def get_declaration(self):
 
1831
        return 'VARBINARY(MAX)'
 
1832
 
 
1833
    def write_info(self, w):
 
1834
        w.put_usmallint(PLP_MARKER)
 
1835
 
 
1836
    def write(self, w, val):
 
1837
        if val is None:
 
1838
            w.put_uint8(PLP_NULL)
 
1839
        else:
 
1840
            w.put_uint8(len(val))
 
1841
            if val:
 
1842
                w.put_uint(len(val))
 
1843
                w.write(val)
 
1844
            w.put_uint(0)
 
1845
 
 
1846
    def read(self, r):
 
1847
        r = PlpReader(r)
 
1848
        if r.is_null():
 
1849
            return None
 
1850
        return b''.join(r.chunks())
 
1851
 
 
1852
 
 
1853
class Image70(BaseType):
 
1854
    type = SYBIMAGE
 
1855
    declaration = 'IMAGE'
 
1856
 
 
1857
    def __init__(self, size=0, table_name=''):
 
1858
        self._table_name = table_name
 
1859
        self._size = size
 
1860
 
 
1861
    def get_declaration(self):
 
1862
        return self.declaration
 
1863
 
 
1864
    @classmethod
 
1865
    def from_stream(cls, r):
 
1866
        size = r.get_int()
 
1867
        table_name = r.read_ucs2(r.get_smallint())
 
1868
        return cls(size, table_name)
 
1869
 
 
1870
    @classmethod
 
1871
    def from_declaration(cls, declaration, nullable, connection):
 
1872
        if declaration == cls.declaration:
 
1873
            return cls()
 
1874
 
 
1875
    def read(self, r):
 
1876
        size = r.get_byte()
 
1877
        if size == 16:  # Jeff's hack
 
1878
            readall(r, 16)  # textptr
 
1879
            readall(r, 8)  # timestamp
 
1880
            colsize = r.get_int()
 
1881
            return readall(r, colsize)
 
1882
        else:
 
1883
            return None
 
1884
 
 
1885
    def write(self, w, val):
 
1886
        if val is None:
 
1887
            w.put_int(-1)
 
1888
            return
 
1889
        w.put_int(len(val))
 
1890
        w.write(val)
 
1891
 
 
1892
    def write_info(self, w):
 
1893
        w.put_int(self._size)
 
1894
 
 
1895
 
 
1896
class Image72(Image70):
 
1897
    def __init__(self, size=0, parts=[]):
 
1898
        self._parts = parts
 
1899
        self._size = size
 
1900
 
 
1901
    @classmethod
 
1902
    def from_stream(cls, r):
 
1903
        size = r.get_int()
 
1904
        num_parts = r.get_byte()
 
1905
        parts = []
 
1906
        for _ in range(num_parts):
 
1907
            parts.append(r.read_ucs2(r.get_usmallint()))
 
1908
        return Image72(size, parts)
 
1909
 
 
1910
 
 
1911
class BaseDateTime(BaseType):
 
1912
    _base_date = datetime(1900, 1, 1)
 
1913
    _min_date = datetime(1753, 1, 1, 0, 0, 0)
 
1914
    _max_date = datetime(9999, 12, 31, 23, 59, 59, 997000)
 
1915
 
 
1916
 
 
1917
class SmallDateTime(BasePrimitiveType, BaseDateTime):
 
1918
    type = SYBDATETIME4
 
1919
    declaration = 'SMALLDATETIME'
 
1920
 
 
1921
    _max_date = datetime(2079, 6, 6, 23, 59, 0)
 
1922
    _struct = struct.Struct('<HH')
 
1923
 
 
1924
    def write(self, w, val):
 
1925
        if val.tzinfo:
 
1926
            if not w.session.use_tz:
 
1927
                raise DataError('Timezone-aware datetime is used without specifying use_tz')
 
1928
            val = val.astimezone(w.session.use_tz).replace(tzinfo=None)
 
1929
        days = (val - self._base_date).days
 
1930
        minutes = val.hour * 60 + val.minute
 
1931
        w.pack(self._struct, days, minutes)
 
1932
 
 
1933
    def read(self, r):
 
1934
        days, minutes = r.unpack(self._struct)
 
1935
        tzinfo = None
 
1936
        if r.session.tzinfo_factory is not None:
 
1937
            tzinfo = r.session.tzinfo_factory(0)
 
1938
        return (self._base_date + timedelta(days=days, minutes=minutes)).replace(tzinfo=tzinfo)
 
1939
    
 
1940
SmallDateTime.instance = SmallDateTime()
 
1941
 
 
1942
 
 
1943
class DateTime(BasePrimitiveType, BaseDateTime):
 
1944
    type = SYBDATETIME
 
1945
    declaration = 'DATETIME'
 
1946
 
 
1947
    _struct = struct.Struct('<ll')
 
1948
    
 
1949
    def write(self, w, val):
 
1950
        if val.tzinfo:
 
1951
            if not w.session.use_tz:
 
1952
                raise DataError('Timezone-aware datetime is used without specifying use_tz')
 
1953
            val = val.astimezone(w.session.use_tz).replace(tzinfo=None)
 
1954
        w.write(self.encode(val))
 
1955
 
 
1956
    def read(self, r):
 
1957
        days, t = r.unpack(self._struct)
 
1958
        tzinfo = None
 
1959
        if r.session.tzinfo_factory is not None:
 
1960
            tzinfo = r.session.tzinfo_factory(0)
 
1961
        return _applytz(self.decode(days, t), tzinfo)
 
1962
 
 
1963
    @classmethod
 
1964
    def validate(cls, value):
 
1965
        if not (cls._min_date <= value <= cls._max_date):
 
1966
            raise DataError('Date is out of range')
 
1967
 
 
1968
    @classmethod
 
1969
    def encode(cls, value):
 
1970
        #cls.validate(value)
 
1971
        if type(value) == date:
 
1972
            value = datetime.combine(value, time(0, 0, 0))
 
1973
        days = (value - cls._base_date).days
 
1974
        ms = value.microsecond // 1000
 
1975
        tm = (value.hour * 60 * 60 + value.minute * 60 + value.second) * 300 + int(round(ms * 3 / 10.0))
 
1976
        return cls._struct.pack(days, tm)
 
1977
 
 
1978
    @classmethod
 
1979
    def decode(cls, days, time):
 
1980
        ms = int(round(time % 300 * 10 / 3.0))
 
1981
        secs = time // 300
 
1982
        return cls._base_date + timedelta(days=days, seconds=secs, milliseconds=ms)
 
1983
    
 
1984
DateTime.instance = DateTime()
 
1985
 
 
1986
 
 
1987
class DateTimeN(BaseTypeN, BaseDateTime):
 
1988
    type = SYBDATETIMN
 
1989
    subtypes = {
 
1990
        4: SmallDateTime.instance,
 
1991
        8: DateTime.instance,
 
1992
        }
 
1993
 
 
1994
 
 
1995
class BaseDateTime73(BaseType):
 
1996
    _precision_to_len = {
 
1997
        0: 3,
 
1998
        1: 3,
 
1999
        2: 3,
 
2000
        3: 4,
 
2001
        4: 4,
 
2002
        5: 5,
 
2003
        6: 5,
 
2004
        7: 5,
 
2005
        }
 
2006
 
 
2007
    _base_date = datetime(1, 1, 1)
 
2008
 
 
2009
    def _write_time(self, w, t, prec):
 
2010
        secs = t.hour * 60 * 60 + t.minute * 60 + t.second
 
2011
        val = (secs * 10 ** 7 + t.microsecond * 10) // (10 ** (7 - prec))
 
2012
        w.write(struct.pack('<Q', val)[:self._precision_to_len[prec]])
 
2013
 
 
2014
    def _read_time(self, r, size, prec, use_tz):
 
2015
        time_buf = readall(r, size)
 
2016
        val = _decode_num(time_buf)
 
2017
        val *= 10 ** (7 - prec)
 
2018
        nanoseconds = val * 100
 
2019
        hours = nanoseconds // 1000000000 // 60 // 60
 
2020
        nanoseconds -= hours * 60 * 60 * 1000000000
 
2021
        minutes = nanoseconds // 1000000000 // 60
 
2022
        nanoseconds -= minutes * 60 * 1000000000
 
2023
        seconds = nanoseconds // 1000000000
 
2024
        nanoseconds -= seconds * 1000000000
 
2025
        return time(hours, minutes, seconds, nanoseconds // 1000, tzinfo=use_tz)
 
2026
 
 
2027
    def _write_date(self, w, value):
 
2028
        if type(value) == date:
 
2029
            value = datetime.combine(value, time(0, 0, 0))
 
2030
        days = (value - self._base_date).days
 
2031
        buf = struct.pack('<l', days)[:3]
 
2032
        w.write(buf)
 
2033
 
 
2034
    def _read_date(self, r):
 
2035
        days = _decode_num(readall(r, 3))
 
2036
        return (self._base_date + timedelta(days=days)).date()
 
2037
 
 
2038
 
 
2039
class MsDate(BasePrimitiveType, BaseDateTime73):
 
2040
    type = SYBMSDATE
 
2041
    declaration = 'DATE'
 
2042
 
 
2043
    MIN = date(1, 1, 1)
 
2044
    MAX = date(9999, 12, 31)
 
2045
 
 
2046
    def write(self, w, value):
 
2047
        if value is None:
 
2048
            w.put_byte(0)
 
2049
        else:
 
2050
            w.put_byte(3)
 
2051
            self._write_date(w, value)
 
2052
 
 
2053
    def read_fixed(self, r):
 
2054
        return self._read_date(r)
 
2055
 
 
2056
    def read(self, r):
 
2057
        size = r.get_byte()
 
2058
        if size == 0:
 
2059
            return None
 
2060
        return self._read_date(r)
 
2061
    
 
2062
MsDate.instance = MsDate()
 
2063
 
 
2064
 
 
2065
class MsTime(BaseDateTime73):
 
2066
    type = SYBMSTIME
 
2067
 
 
2068
    def __init__(self, prec):
 
2069
        self._prec = prec
 
2070
        self._size = self._precision_to_len[prec]
 
2071
 
 
2072
    @classmethod
 
2073
    def from_stream(cls, r):
 
2074
        prec = r.get_byte()
 
2075
        return cls(prec)
 
2076
 
 
2077
    @classmethod
 
2078
    def from_declaration(cls, declaration, nullable, connection):
 
2079
        m = re.match(r'TIME\((\d+)\)', declaration)
 
2080
        if m:
 
2081
            return cls(int(m.group(1)))
 
2082
 
 
2083
    def get_declaration(self):
 
2084
        return 'TIME({0})'.format(self._prec)
 
2085
 
 
2086
    def write_info(self, w):
 
2087
        w.put_byte(self._prec)
 
2088
 
 
2089
    def write(self, w, value):
 
2090
        if value is None:
 
2091
            w.put_byte(0)
 
2092
        else:
 
2093
            if value.tzinfo:
 
2094
                if not w.session.use_tz:
 
2095
                    raise DataError('Timezone-aware datetime is used without specifying use_tz')
 
2096
                value = value.astimezone(w.session.use_tz).replace(tzinfo=None)
 
2097
            w.put_byte(self._size)
 
2098
            self._write_time(w, value, self._prec)
 
2099
 
 
2100
    def read_fixed(self, r, size):
 
2101
        tzinfo = None
 
2102
        if r.session.tzinfo_factory is not None:
 
2103
            tzinfo = r.session.tzinfo_factory(0)
 
2104
        return self._read_time(r, size, self._prec, tzinfo)
 
2105
 
 
2106
    def read(self, r):
 
2107
        size = r.get_byte()
 
2108
        if size == 0:
 
2109
            return None
 
2110
        return self.read_fixed(r, size)
 
2111
 
 
2112
 
 
2113
class DateTime2(BaseDateTime73):
 
2114
    type = SYBMSDATETIME2
 
2115
 
 
2116
    def __init__(self, prec=7):
 
2117
        self._prec = prec
 
2118
        self._size = self._precision_to_len[prec] + 3
 
2119
 
 
2120
    @classmethod
 
2121
    def from_stream(cls, r):
 
2122
        prec = r.get_byte()
 
2123
        return cls(prec)
 
2124
 
 
2125
    def get_declaration(self):
 
2126
        return 'DATETIME2({0})'.format(self._prec)
 
2127
 
 
2128
    @classmethod
 
2129
    def from_declaration(cls, declaration, nullable, connection):
 
2130
        if declaration == 'DATETIME2':
 
2131
            return cls()
 
2132
        m = re.match(r'DATETIME2\((\d+)\)', declaration)
 
2133
        if m:
 
2134
            return cls(int(m.group(1)))
 
2135
 
 
2136
    def write_info(self, w):
 
2137
        w.put_byte(self._prec)
 
2138
 
 
2139
    def write(self, w, value):
 
2140
        if value is None:
 
2141
            w.put_byte(0)
 
2142
        else:
 
2143
            if value.tzinfo:
 
2144
                if not w.session.use_tz:
 
2145
                    raise DataError('Timezone-aware datetime is used without specifying use_tz')
 
2146
                value = value.astimezone(w.session.use_tz).replace(tzinfo=None)
 
2147
            w.put_byte(self._size)
 
2148
            self._write_time(w, value, self._prec)
 
2149
            self._write_date(w, value)
 
2150
 
 
2151
    def read_fixed(self, r, size):
 
2152
        tzinfo = None
 
2153
        if r.session.tzinfo_factory is not None:
 
2154
            tzinfo = r.session.tzinfo_factory(0)
 
2155
        time = self._read_time(r, size - 3, self._prec, tzinfo)
 
2156
        date = self._read_date(r)
 
2157
        return datetime.combine(date, time)
 
2158
 
 
2159
    def read(self, r):
 
2160
        size = r.get_byte()
 
2161
        if size == 0:
 
2162
            return None
 
2163
        return self.read_fixed(r, size)
 
2164
 
 
2165
 
 
2166
class DateTimeOffset(BaseDateTime73):
 
2167
    type = SYBMSDATETIMEOFFSET
 
2168
 
 
2169
    def __init__(self, prec=7):
 
2170
        self._prec = prec
 
2171
        self._size = self._precision_to_len[prec] + 5
 
2172
 
 
2173
    @classmethod
 
2174
    def from_stream(cls, r):
 
2175
        prec = r.get_byte()
 
2176
        return cls(prec)
 
2177
 
 
2178
    @classmethod
 
2179
    def from_declaration(cls, declaration, nullable, connection):
 
2180
        if declaration == 'DATETIMEOFFSET':
 
2181
            return cls()
 
2182
        m = re.match(r'DATETIMEOFFSET\((\d+)\)', declaration)
 
2183
        if m:
 
2184
            return cls(int(m.group(1)))
 
2185
    
 
2186
    def get_declaration(self):
 
2187
        return 'DATETIMEOFFSET({0})'.format(self._prec)
 
2188
 
 
2189
    def write_info(self, w):
 
2190
        w.put_byte(self._prec)
 
2191
 
 
2192
    def write(self, w, value):
 
2193
        if value is None:
 
2194
            w.put_byte(0)
 
2195
        else:
 
2196
            utcoffset = value.utcoffset()
 
2197
            value = value.astimezone(_utc).replace(tzinfo=None)
 
2198
 
 
2199
            w.put_byte(self._size)
 
2200
            self._write_time(w, value, self._prec)
 
2201
            self._write_date(w, value)
 
2202
            w.put_smallint(int(total_seconds(utcoffset)) // 60)
 
2203
 
 
2204
    def read_fixed(self, r, size):
 
2205
        time = self._read_time(r, size - 5, self._prec, _utc)
 
2206
        date = self._read_date(r)
 
2207
        offset = r.get_smallint()
 
2208
        tzinfo_factory = r._session.tzinfo_factory
 
2209
        if tzinfo_factory is None:
 
2210
            from .tz import FixedOffsetTimezone
 
2211
            tzinfo_factory = FixedOffsetTimezone
 
2212
        tz = tzinfo_factory(offset)
 
2213
        return datetime.combine(date, time).astimezone(tz)
 
2214
 
 
2215
    def read(self, r):
 
2216
        size = r.get_byte()
 
2217
        if size == 0:
 
2218
            return None
 
2219
        return self.read_fixed(r, size)
 
2220
 
 
2221
 
 
2222
class MsDecimal(BaseType):
 
2223
    type = SYBDECIMAL
 
2224
 
 
2225
    _max_size = 17
 
2226
 
 
2227
    _bytes_per_prec = [
 
2228
        #
 
2229
        # precision can't be 0 but using a value > 0 assure no
 
2230
        # core if for some bug it's 0...
 
2231
        #
 
2232
        1,
 
2233
        5, 5, 5, 5, 5, 5, 5, 5, 5,
 
2234
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
 
2235
        13, 13, 13, 13, 13, 13, 13, 13, 13,
 
2236
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
 
2237
        ]
 
2238
 
 
2239
    _info_struct = struct.Struct('BBB')
 
2240
 
 
2241
    @property
 
2242
    def scale(self):
 
2243
        return self._scale
 
2244
 
 
2245
    @property
 
2246
    def precision(self):
 
2247
        return self._prec
 
2248
 
 
2249
    def __init__(self, scale=0, prec=18):
 
2250
        if prec > 38:
 
2251
            raise DataError('Precision of decimal value is out of range')
 
2252
        self._scale = scale
 
2253
        self._prec = prec
 
2254
        self._size = self._bytes_per_prec[prec]
 
2255
 
 
2256
    @classmethod
 
2257
    def from_value(cls, value):
 
2258
        if not (-10 ** 38 + 1 <= value <= 10 ** 38 - 1):
 
2259
            raise DataError('Decimal value is out of range')
 
2260
        value = value.normalize()
 
2261
        _, digits, exp = value.as_tuple()
 
2262
        if exp > 0:
 
2263
            scale = 0
 
2264
            prec = len(digits) + exp
 
2265
        else:
 
2266
            scale = -exp
 
2267
            prec = max(len(digits), scale)
 
2268
        return cls(scale=scale, prec=prec)
 
2269
 
 
2270
    @classmethod
 
2271
    def from_stream(cls, r):
 
2272
        size, prec, scale = r.unpack(cls._info_struct)
 
2273
        return cls(scale=scale, prec=prec)
 
2274
 
 
2275
    @classmethod
 
2276
    def from_declaration(cls, declaration, nullable, connection):
 
2277
        if declaration == 'DECIMAL':
 
2278
            return cls()
 
2279
        m = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', declaration)
 
2280
        if m:
 
2281
            return cls(int(m.group(2)), int(m.group(1)))
 
2282
 
 
2283
    def get_declaration(self):
 
2284
        return 'DECIMAL({0},{1})'.format(self._prec, self._scale)
 
2285
 
 
2286
    def write_info(self, w):
 
2287
        w.pack(self._info_struct, self._size, self._prec, self._scale)
 
2288
 
 
2289
    def write(self, w, value):
 
2290
        if value is None:
 
2291
            w.put_byte(0)
 
2292
            return
 
2293
        if not isinstance(value, Decimal):
 
2294
            value = Decimal(value)
 
2295
        value = value.normalize()
 
2296
        scale = self._scale
 
2297
        size = self._size
 
2298
        w.put_byte(size)
 
2299
        val = value
 
2300
        positive = 1 if val > 0 else 0
 
2301
        w.put_byte(positive)  # sign
 
2302
        with localcontext() as ctx:
 
2303
            ctx.prec = 38
 
2304
            if not positive:
 
2305
                val *= -1
 
2306
            size -= 1
 
2307
            val = val * (10 ** scale)
 
2308
        for i in range(size):
 
2309
            w.put_byte(int(val % 256))
 
2310
            val //= 256
 
2311
        assert val == 0
 
2312
 
 
2313
    def _decode(self, positive, buf):
 
2314
        val = _decode_num(buf)
 
2315
        val = Decimal(val)
 
2316
        with localcontext() as ctx:
 
2317
            ctx.prec = 38
 
2318
            if not positive:
 
2319
                val *= -1
 
2320
            val /= 10 ** self._scale
 
2321
        return val
 
2322
 
 
2323
    def read_fixed(self, r, size):
 
2324
        positive = r.get_byte()
 
2325
        buf = readall(r, size - 1)
 
2326
        return self._decode(positive, buf)
 
2327
 
 
2328
    def read(self, r):
 
2329
        size = r.get_byte()
 
2330
        if size <= 0:
 
2331
            return None
 
2332
        return self.read_fixed(r, size)
 
2333
 
 
2334
 
 
2335
class Money4(BasePrimitiveType):
 
2336
    type = SYBMONEY4
 
2337
    declaration = 'SMALLMONEY'
 
2338
 
 
2339
    def read(self, r):
 
2340
        return Decimal(r.get_int()) / 10000
 
2341
 
 
2342
    def write(self, w, val):
 
2343
        val = int(val * 10000)
 
2344
        w.put_int(val)
 
2345
 
 
2346
Money4.instance = Money4()
 
2347
 
 
2348
 
 
2349
class Money8(BasePrimitiveType):
 
2350
    type = SYBMONEY
 
2351
    declaration = 'MONEY'
 
2352
    
 
2353
    _struct = struct.Struct('<lL')
 
2354
 
 
2355
    def read(self, r):
 
2356
        hi, lo = r.unpack(self._struct)
 
2357
        val = hi * (2 ** 32) + lo
 
2358
        return Decimal(val) / 10000
 
2359
 
 
2360
    def write(self, w, val):
 
2361
        val = val * 10000
 
2362
        hi = int(val // (2 ** 32))
 
2363
        lo = int(val % (2 ** 32))
 
2364
        w.pack(self._struct, hi, lo)
 
2365
 
 
2366
Money8.instance = Money8()
 
2367
 
 
2368
 
 
2369
class MoneyN(BaseTypeN):
 
2370
    type = SYBMONEYN
 
2371
    
 
2372
    subtypes = {
 
2373
        4: Money4.instance,
 
2374
        8: Money8.instance,
 
2375
        }
 
2376
 
 
2377
class MsUnique(BaseType):
 
2378
    type = SYBUNIQUE
 
2379
    declaration = 'UNIQUEIDENTIFIER'
 
2380
 
 
2381
    @classmethod
 
2382
    def from_stream(cls, r):
 
2383
        size = r.get_byte()
 
2384
        if size != 16:
 
2385
            raise InterfaceError('Invalid size of UNIQUEIDENTIFIER field')
 
2386
        return cls.instance
 
2387
 
 
2388
    @classmethod
 
2389
    def from_declaration(cls, declaration, nullable, connection):
 
2390
        if declaration == cls.declaration:
 
2391
            return cls.instance
 
2392
 
 
2393
    def get_declaration(self):
 
2394
        return self.declaration
 
2395
 
 
2396
    def write_info(self, w):
 
2397
        w.put_byte(16)
 
2398
 
 
2399
    def write(self, w, value):
 
2400
        if value is None:
 
2401
            w.put_byte(0)
 
2402
        else:
 
2403
            w.put_byte(16)
 
2404
            w.write(value.bytes_le)
 
2405
 
 
2406
    def read_fixed(self, r, size):
 
2407
        return uuid.UUID(bytes_le=readall(r, size))
 
2408
 
 
2409
    def read(self, r):
 
2410
        size = r.get_byte()
 
2411
        if size == 0:
 
2412
            return None
 
2413
        if size != 16:
 
2414
            raise InterfaceError('Invalid size of UNIQUEIDENTIFIER field')
 
2415
        return self.read_fixed(r, size)
 
2416
MsUnique.instance = MsUnique()
 
2417
 
 
2418
 
 
2419
def _variant_read_str(r, size):
 
2420
    collation = r.get_collation()
 
2421
    r.get_usmallint()
 
2422
    return r.read_str(size, collation.get_codec())
 
2423
 
 
2424
 
 
2425
def _variant_read_nstr(r, size):
 
2426
    r.get_collation()
 
2427
    r.get_usmallint()
 
2428
    return r.read_str(size, ucs2_codec)
 
2429
 
 
2430
 
 
2431
def _variant_read_decimal(r, size):
 
2432
    prec, scale = r.unpack(Variant._decimal_info_struct)
 
2433
    return MsDecimal(prec=prec, scale=scale).read_fixed(r, size)
 
2434
 
 
2435
 
 
2436
def _variant_read_binary(r, size):
 
2437
    r.get_usmallint()
 
2438
    return readall(r, size)
 
2439
 
 
2440
 
 
2441
class Variant(BaseType):
 
2442
    type = SYBVARIANT
 
2443
    declaration = 'SQL_VARIANT'
 
2444
 
 
2445
    _decimal_info_struct = struct.Struct('BB')
 
2446
 
 
2447
    _type_map = {
 
2448
        GUIDTYPE: lambda r, size: MsUnique.instance.read_fixed(r, size),
 
2449
        BITTYPE: lambda r, size: Bit.instance.read(r),
 
2450
        INT1TYPE: lambda r, size: TinyInt.instance.read(r),
 
2451
        INT2TYPE: lambda r, size: SmallInt.instance.read(r),
 
2452
        INT4TYPE: lambda r, size: Int.instance.read(r),
 
2453
        INT8TYPE: lambda r, size: BigInt.instance.read(r),
 
2454
        DATETIMETYPE: lambda r, size: DateTime.instance.read(r),
 
2455
        DATETIM4TYPE: lambda r, size: SmallDateTime.instance.read(r),
 
2456
        FLT4TYPE: lambda r, size: Real.instance.read(r),
 
2457
        FLT8TYPE: lambda r, size: Float.instance.read(r),
 
2458
        MONEYTYPE: lambda r, size: Money8.instance.read(r),
 
2459
        MONEY4TYPE: lambda r, size: Money4.instance.read(r),
 
2460
        DATENTYPE: lambda r, size: MsDate.instance.read_fixed(r),
 
2461
 
 
2462
        TIMENTYPE: lambda r, size: MsTime(prec=r.get_byte()).read_fixed(r, size),
 
2463
        DATETIME2NTYPE: lambda r, size: DateTime2(prec=r.get_byte()).read_fixed(r, size),
 
2464
        DATETIMEOFFSETNTYPE: lambda r, size: DateTimeOffset(prec=r.get_byte()).read_fixed(r, size),
 
2465
 
 
2466
        BIGVARBINTYPE: _variant_read_binary,
 
2467
        BIGBINARYTYPE: _variant_read_binary,
 
2468
 
 
2469
        NUMERICNTYPE: _variant_read_decimal,
 
2470
        DECIMALNTYPE: _variant_read_decimal,
 
2471
 
 
2472
        BIGVARCHRTYPE: _variant_read_str,
 
2473
        BIGCHARTYPE: _variant_read_str,
 
2474
        NVARCHARTYPE: _variant_read_nstr,
 
2475
        NCHARTYPE: _variant_read_nstr,
 
2476
 
 
2477
        }
 
2478
 
 
2479
    def __init__(self, size):
 
2480
        self._size = size
 
2481
 
 
2482
    def get_declaration(self):
 
2483
        return self.declaration
 
2484
 
 
2485
    @classmethod
 
2486
    def from_stream(cls, r):
 
2487
        size = r.get_int()
 
2488
        return Variant(size)
 
2489
 
 
2490
    @classmethod
 
2491
    def from_declaration(cls, declaration, nullable, connection):
 
2492
        if declaration == cls.declaration:
 
2493
            return cls(0)
 
2494
 
 
2495
    def write_info(self, w):
 
2496
        w.put_int(self._size)
 
2497
 
 
2498
    def read(self, r):
 
2499
        size = r.get_int()
 
2500
        if size == 0:
 
2501
            return None
 
2502
 
 
2503
        type_id = r.get_byte()
 
2504
        prop_bytes = r.get_byte()
 
2505
        type_factory = self._type_map.get(type_id)
 
2506
        if not type_factory:
 
2507
            r.session.bad_stream('Variant type invalid', type_id)
 
2508
        return type_factory(r, size - prop_bytes - 2)
 
2509
 
 
2510
    def write(self, w, val):
 
2511
        if val is None:
 
2512
            w.put_int(0)
 
2513
            return
 
2514
        raise NotImplementedError
 
2515
 
 
2516
 
 
2517
_type_map = {
 
2518
    SYBINT1: TinyInt,
 
2519
    SYBINT2: SmallInt,
 
2520
    SYBINT4: Int,
 
2521
    SYBINT8: BigInt,
 
2522
    SYBINTN: IntN,
 
2523
    SYBBIT: Bit,
 
2524
    SYBBITN: BitN,
 
2525
    SYBREAL: Real,
 
2526
    SYBFLT8: Float,
 
2527
    SYBFLTN: FloatN,
 
2528
    SYBMONEY4: Money4,
 
2529
    SYBMONEY: Money8,
 
2530
    SYBMONEYN: MoneyN,
 
2531
    XSYBCHAR: VarChar70,
 
2532
    XSYBVARCHAR: VarChar70,
 
2533
    XSYBNCHAR: NVarChar70,
 
2534
    XSYBNVARCHAR: NVarChar70,
 
2535
    SYBTEXT: Text70,
 
2536
    SYBNTEXT: NText70,
 
2537
    SYBMSXML: Xml,
 
2538
    XSYBBINARY: VarBinary,
 
2539
    XSYBVARBINARY: VarBinary,
 
2540
    SYBIMAGE: Image70,
 
2541
    SYBNUMERIC: MsDecimal,
 
2542
    SYBDECIMAL: MsDecimal,
 
2543
    SYBVARIANT: Variant,
 
2544
    SYBMSDATE: MsDate,
 
2545
    SYBMSTIME: MsTime,
 
2546
    SYBMSDATETIME2: DateTime2,
 
2547
    SYBMSDATETIMEOFFSET: DateTimeOffset,
 
2548
    SYBDATETIME4: SmallDateTime,
 
2549
    SYBDATETIME: DateTime,
 
2550
    SYBDATETIMN: DateTimeN,
 
2551
    SYBUNIQUE: MsUnique,
 
2552
    }
 
2553
 
 
2554
_type_map71 = _type_map.copy()
 
2555
_type_map71.update({
 
2556
    XSYBCHAR: VarChar71,
 
2557
    XSYBNCHAR: NVarChar71,
 
2558
    XSYBVARCHAR: VarChar71,
 
2559
    XSYBNVARCHAR: NVarChar71,
 
2560
    SYBTEXT: Text71,
 
2561
    SYBNTEXT: NText71,
 
2562
    })
 
2563
 
 
2564
_type_map72 = _type_map.copy()
 
2565
_type_map72.update({
 
2566
    XSYBCHAR: VarChar72,
 
2567
    XSYBNCHAR: NVarChar72,
 
2568
    XSYBVARCHAR: VarChar72,
 
2569
    XSYBNVARCHAR: NVarChar72,
 
2570
    SYBTEXT: Text72,
 
2571
    SYBNTEXT: NText72,
 
2572
    XSYBBINARY: VarBinary72,
 
2573
    XSYBVARBINARY: VarBinary72,
 
2574
    SYBIMAGE: Image72,
 
2575
    })
 
2576
 
 
2577
 
 
2578
def _create_exception_by_message(msg, custom_error_msg = None):
 
2579
    msg_no = msg['msgno']
 
2580
    if custom_error_msg is not None:
 
2581
        error_msg = custom_error_msg
 
2582
    else:
 
2583
        error_msg = msg['message']
 
2584
    if msg_no in prog_errors:
 
2585
        ex = ProgrammingError(error_msg)
 
2586
    elif msg_no in integrity_errors:
 
2587
        ex = IntegrityError(error_msg)
 
2588
    else:
 
2589
        ex = OperationalError(error_msg)
 
2590
    ex.msg_no = msg['msgno']
 
2591
    ex.text = msg['message']
 
2592
    ex.srvname = msg['server']
 
2593
    ex.procname = msg['proc_name']
 
2594
    ex.number = msg['msgno']
 
2595
    ex.severity = msg['severity']
 
2596
    ex.state = msg['state']
 
2597
    ex.line = msg['line_number']
 
2598
    return ex
 
2599
 
 
2600
 
 
2601
class _TdsSession(object):
 
2602
    """ TDS session
 
2603
 
 
2604
    Represents a single TDS session within MARS connection, when MARS enabled there could be multiple TDS sessions
 
2605
    within one connection.
 
2606
    """
 
2607
    def __init__(self, tds, transport, tzinfo_factory):
 
2608
        self.out_pos = 8
 
2609
        self.res_info = None
 
2610
        self.in_cancel = False
 
2611
        self.wire_mtx = None
 
2612
        self.param_info = None
 
2613
        self.has_status = False
 
2614
        self.ret_status = None
 
2615
        self.skipped_to_status = False
 
2616
        self._transport = transport
 
2617
        self._reader = _TdsReader(self)
 
2618
        self._reader._transport = transport
 
2619
        self._writer = _TdsWriter(self, tds._bufsize)
 
2620
        self._writer._transport = transport
 
2621
        self.in_buf_max = 0
 
2622
        self.state = TDS_IDLE
 
2623
        self._tds = tds
 
2624
        self.messages = []
 
2625
        self.chunk_handler = tds.chunk_handler
 
2626
        self.rows_affected = -1
 
2627
        self.use_tz = tds.use_tz
 
2628
        self._spid = 0
 
2629
        self.tzinfo_factory = tzinfo_factory
 
2630
 
 
2631
    def __repr__(self):
 
2632
        fmt = "<_TdsSession state={} tds={} messages={} rows_affected={} use_tz={} spid={} in_cancel={}>"
 
2633
        res = fmt.format(repr(self.state), repr(self._tds), repr(self.messages),
 
2634
                         repr(self.rows_affected), repr(self.use_tz), repr(self._spid),
 
2635
                         self.in_cancel)
 
2636
        return res
 
2637
 
 
2638
    def raise_db_exception(self):
 
2639
        """ Raises exception from last server message
 
2640
 
 
2641
        This function will skip messages: The statement has been terminated
 
2642
        """
 
2643
        if not self.messages:
 
2644
            raise Error("Request failed, server didn't send error message")
 
2645
        while True:
 
2646
            msg = self.messages[-1]
 
2647
            if msg['msgno'] == 3621:  # the statement has been terminated
 
2648
                self.messages = self.messages[:-1]
 
2649
            else:
 
2650
                break
 
2651
 
 
2652
        error_msg = ' '.join(msg['message'] for msg in self.messages)
 
2653
        ex = _create_exception_by_message(msg, error_msg)
 
2654
        raise ex
 
2655
 
 
2656
    def get_type_info(self, curcol):
 
2657
        """ Reads TYPE_INFO structure (http://msdn.microsoft.com/en-us/library/dd358284.aspx)
 
2658
 
 
2659
        :param curcol: An instance of :class:`Column` that will receive read information
 
2660
        """
 
2661
        r = self._reader
 
2662
        # User defined data type of the column
 
2663
        curcol.column_usertype = r.get_uint() if IS_TDS72_PLUS(self) else r.get_usmallint()
 
2664
        curcol.flags = r.get_usmallint()  # Flags
 
2665
        curcol.column_nullable = curcol.flags & Column.fNullable
 
2666
        curcol.column_writeable = (curcol.flags & Column.fReadWrite) > 0
 
2667
        curcol.column_identity = (curcol.flags & Column.fIdentity) > 0
 
2668
        type_id = r.get_byte()
 
2669
        type_class = self._tds._type_map.get(type_id)
 
2670
        if not type_class:
 
2671
            raise InterfaceError('Invalid type id', type_id)
 
2672
        curcol.type = type_class.from_stream(r)
 
2673
 
 
2674
    def tds7_process_result(self):
 
2675
        """ Reads and processes COLMETADATA stream
 
2676
 
 
2677
        This stream contains a list of returned columns.
 
2678
        Stream format link: http://msdn.microsoft.com/en-us/library/dd357363.aspx
 
2679
        """
 
2680
        r = self._reader
 
2681
        #logger.debug("processing TDS7 result metadata.")
 
2682
 
 
2683
        # read number of columns and allocate the columns structure
 
2684
 
 
2685
        num_cols = r.get_smallint()
 
2686
 
 
2687
        # This can be a DUMMY results token from a cursor fetch
 
2688
 
 
2689
        if num_cols == -1:
 
2690
            #logger.debug("no meta data")
 
2691
            return
 
2692
 
 
2693
        self.param_info = None
 
2694
        self.has_status = False
 
2695
        self.ret_status = None
 
2696
        self.skipped_to_status = False
 
2697
        self.rows_affected = TDS_NO_COUNT
 
2698
        self.more_rows = True
 
2699
        self.row = [None] * num_cols
 
2700
        self.res_info = info = _Results()
 
2701
 
 
2702
        #
 
2703
        # loop through the columns populating COLINFO struct from
 
2704
        # server response
 
2705
        #
 
2706
        #logger.debug("setting up {0} columns".format(num_cols))
 
2707
        header_tuple = []
 
2708
        for col in range(num_cols):
 
2709
            curcol = Column()
 
2710
            info.columns.append(curcol)
 
2711
            self.get_type_info(curcol)
 
2712
 
 
2713
            #
 
2714
            # under 7.0 lengths are number of characters not
 
2715
            # number of bytes... read_ucs2 handles this
 
2716
            #
 
2717
            curcol.column_name = r.read_ucs2(r.get_byte())
 
2718
            precision = curcol.type.precision if hasattr(curcol.type, 'precision') else None
 
2719
            scale = curcol.type.scale if hasattr(curcol.type, 'scale') else None
 
2720
            size = curcol.type._size if hasattr(curcol.type, '_size') else None
 
2721
            header_tuple.append((curcol.column_name, curcol.type.get_typeid(), None, size, precision, scale, curcol.column_nullable))
 
2722
        info.description = tuple(header_tuple)
 
2723
        return info
 
2724
 
 
2725
    def process_param(self):
 
2726
        """ Reads and processes RETURNVALUE stream.
 
2727
 
 
2728
        This stream is used to send OUTPUT parameters from RPC to client.
 
2729
        Stream format url: http://msdn.microsoft.com/en-us/library/dd303881.aspx
 
2730
        """
 
2731
        r = self._reader
 
2732
        if IS_TDS72_PLUS(self):
 
2733
            ordinal = r.get_usmallint()
 
2734
        else:
 
2735
            r.get_usmallint()  # ignore size
 
2736
            ordinal = self._out_params_indexes[self.return_value_index]
 
2737
        name = r.read_ucs2(r.get_byte())
 
2738
        r.get_byte()  # 1 - OUTPUT of sp, 2 - result of udf
 
2739
        param = Column()
 
2740
        param.column_name = name
 
2741
        self.get_type_info(param)
 
2742
        param.value = param.type.read(r)
 
2743
        self.output_params[ordinal] = param
 
2744
        self.return_value_index += 1
 
2745
 
 
2746
    def process_cancel(self):
 
2747
        """
 
2748
        Process the incoming token stream until it finds
 
2749
        an end token DONE with the cancel flag set.
 
2750
        At that point the connection should be ready to handle a new query.
 
2751
 
 
2752
        In case when no cancel request is pending this function does nothing.
 
2753
        """
 
2754
        # silly cases, nothing to do
 
2755
        if not self.in_cancel:
 
2756
            return
 
2757
 
 
2758
        while True:
 
2759
            token_id = self.get_token_id()
 
2760
            self.process_token(token_id)
 
2761
            if not self.in_cancel:
 
2762
                return
 
2763
 
 
2764
    def process_msg(self, marker):
 
2765
        """ Reads and processes ERROR/INFO streams
 
2766
 
 
2767
        Stream formats:
 
2768
 
 
2769
        - ERROR: http://msdn.microsoft.com/en-us/library/dd304156.aspx
 
2770
        - INFO: http://msdn.microsoft.com/en-us/library/dd303398.aspx
 
2771
 
 
2772
        :param marker: TDS_ERROR_TOKEN or TDS_INFO_TOKEN
 
2773
        """
 
2774
        r = self._reader
 
2775
        r.get_smallint()  # size
 
2776
        msg = {}
 
2777
        msg['marker'] = marker
 
2778
        msg['msgno'] = r.get_int()
 
2779
        msg['state'] = r.get_byte()
 
2780
        msg['severity'] = r.get_byte()
 
2781
        msg['sql_state'] = None
 
2782
        has_eed = False
 
2783
        if marker == TDS_EED_TOKEN:
 
2784
            if msg['severity'] <= 10:
 
2785
                msg['priv_msg_type'] = 0
 
2786
            else:
 
2787
                msg['priv_msg_type'] = 1
 
2788
            len_sqlstate = r.get_byte()
 
2789
            msg['sql_state'] = readall(r, len_sqlstate)
 
2790
            has_eed = r.get_byte()
 
2791
            # junk status and transaction state
 
2792
            r.get_smallint()
 
2793
        elif marker == TDS_INFO_TOKEN:
 
2794
            msg['priv_msg_type'] = 0
 
2795
        elif marker == TDS_ERROR_TOKEN:
 
2796
            msg['priv_msg_type'] = 1
 
2797
        else:
 
2798
            logger.error('tds_process_msg() called with unknown marker "{0}"'.format(marker))
 
2799
        #logger.debug('tds_process_msg() reading message {0} from server'.format(msg['msgno']))
 
2800
        msg['message'] = r.read_ucs2(r.get_smallint())
 
2801
        # server name
 
2802
        msg['server'] = r.read_ucs2(r.get_byte())
 
2803
        # stored proc name if available
 
2804
        msg['proc_name'] = r.read_ucs2(r.get_byte())
 
2805
        msg['line_number'] = r.get_int() if IS_TDS72_PLUS(self) else r.get_smallint()
 
2806
        if not msg['sql_state']:
 
2807
            #msg['sql_state'] = tds_alloc_lookup_sqlstate(self, msg['msgno'])
 
2808
            pass
 
2809
        # in case extended error data is sent, we just try to discard it
 
2810
        if has_eed:
 
2811
            while True:
 
2812
                next_marker = r.get_byte()
 
2813
                if next_marker in (TDS5_PARAMFMT_TOKEN, TDS5_PARAMFMT2_TOKEN, TDS5_PARAMS_TOKEN):
 
2814
                    self.process_token(next_marker)
 
2815
                else:
 
2816
                    break
 
2817
            r.unget_byte()
 
2818
 
 
2819
        # special case
 
2820
        self.messages.append(msg)
 
2821
 
 
2822
    def process_row(self):
 
2823
        """ Reads and handles ROW stream.
 
2824
 
 
2825
        This stream contains list of values of one returned row.
 
2826
        Stream format url: http://msdn.microsoft.com/en-us/library/dd357254.aspx
 
2827
        """
 
2828
        r = self._reader
 
2829
        info = self.res_info
 
2830
        info.row_count += 1
 
2831
        for i, curcol in enumerate(info.columns):
 
2832
            curcol.value = self.row[i] = curcol.type.read(r)
 
2833
 
 
2834
    def process_nbcrow(self):
 
2835
        """ Reads and handles NBCROW stream.
 
2836
 
 
2837
        This stream contains list of values of one returned row in a compressed way,
 
2838
        introduced in TDS 7.3.B
 
2839
        Stream format url: http://msdn.microsoft.com/en-us/library/dd304783.aspx
 
2840
        """
 
2841
        r = self._reader
 
2842
        info = self.res_info
 
2843
        if not info:
 
2844
            self.bad_stream('got row without info')
 
2845
        assert len(info.columns) > 0
 
2846
        info.row_count += 1
 
2847
 
 
2848
        # reading bitarray for nulls, 1 represent null values for
 
2849
        # corresponding fields
 
2850
        nbc = readall(r, (len(info.columns) + 7) // 8)
 
2851
        for i, curcol in enumerate(info.columns):
 
2852
            if _ord(nbc[i // 8]) & (1 << (i % 8)):
 
2853
                value = None
 
2854
            else:
 
2855
                value = curcol.type.read(r)
 
2856
            self.row[i] = value
 
2857
 
 
2858
    def process_orderby(self):
 
2859
        """ Reads and processes ORDER stream
 
2860
 
 
2861
        Used to inform client by which column dataset is ordered.
 
2862
        Stream format url: http://msdn.microsoft.com/en-us/library/dd303317.aspx
 
2863
        """
 
2864
        r = self._reader
 
2865
        skipall(r, r.get_smallint())
 
2866
 
 
2867
    def process_orderby2(self):
 
2868
        r = self._reader
 
2869
        skipall(r, r.get_int())
 
2870
 
 
2871
    def process_end(self, marker):
 
2872
        """ Reads and processes DONE/DONEINPROC/DONEPROC streams
 
2873
 
 
2874
        Stream format urls:
 
2875
 
 
2876
        - DONE: http://msdn.microsoft.com/en-us/library/dd340421.aspx
 
2877
        - DONEINPROC: http://msdn.microsoft.com/en-us/library/dd340553.aspx
 
2878
        - DONEPROC: http://msdn.microsoft.com/en-us/library/dd340753.aspx
 
2879
 
 
2880
        :param marker: Can be TDS_DONE_TOKEN or TDS_DONEINPROC_TOKEN or TDS_DONEPROC_TOKEN
 
2881
        """
 
2882
        self.more_rows = False
 
2883
        r = self._reader
 
2884
        status = r.get_usmallint()
 
2885
        r.get_usmallint()  # cur_cmd
 
2886
        more_results = status & TDS_DONE_MORE_RESULTS != 0
 
2887
        was_cancelled = status & TDS_DONE_CANCELLED != 0
 
2888
        #error = status & TDS_DONE_ERROR != 0
 
2889
        done_count_valid = status & TDS_DONE_COUNT != 0
 
2890
        #logger.debug(
 
2891
        #    'process_end: more_results = {0}\n'
 
2892
        #    '\t\twas_cancelled = {1}\n'
 
2893
        #    '\t\terror = {2}\n'
 
2894
        #    '\t\tdone_count_valid = {3}'.format(more_results, was_cancelled, error, done_count_valid))
 
2895
        if self.res_info:
 
2896
            self.res_info.more_results = more_results
 
2897
        rows_affected = r.get_int8() if IS_TDS72_PLUS(self) else r.get_int()
 
2898
        #logger.debug('\t\trows_affected = {0}'.format(rows_affected))
 
2899
        if was_cancelled or (not more_results and not self.in_cancel):
 
2900
            #logger.debug('process_end() state set to TDS_IDLE')
 
2901
            self.in_cancel = False
 
2902
            self.set_state(TDS_IDLE)
 
2903
        if done_count_valid:
 
2904
            self.rows_affected = rows_affected
 
2905
        else:
 
2906
            self.rows_affected = -1
 
2907
        self.done_flags = status
 
2908
        if self.done_flags & TDS_DONE_ERROR and not was_cancelled and not self.in_cancel:
 
2909
            self.raise_db_exception()
 
2910
 
 
2911
    def process_env_chg(self):
 
2912
        """ Reads and processes ENVCHANGE stream.
 
2913
 
 
2914
        Stream info url: http://msdn.microsoft.com/en-us/library/dd303449.aspx
 
2915
        """
 
2916
        r = self._reader
 
2917
        size = r.get_smallint()
 
2918
        type = r.get_byte()
 
2919
        #logger.debug("process_env_chg: type: {0}".format(type))
 
2920
        if type == TDS_ENV_SQLCOLLATION:
 
2921
            size = r.get_byte()
 
2922
            #logger.debug("process_env_chg(): {0} bytes of collation data received".format(size))
 
2923
            #logger.debug("self.collation was {0}".format(self.conn.collation))
 
2924
            self.conn.collation = r.get_collation()
 
2925
            skipall(r, size - 5)
 
2926
            #tds7_srv_charset_changed(tds, tds.conn.collation)
 
2927
            #logger.debug("self.collation now {0}".format(self.conn.collation))
 
2928
            # discard old one
 
2929
            skipall(r, r.get_byte())
 
2930
        elif type == TDS_ENV_BEGINTRANS:
 
2931
            size = r.get_byte()
 
2932
            # TODO: parse transaction
 
2933
            self.conn.tds72_transaction = r.get_uint8()
 
2934
            skipall(r, r.get_byte())
 
2935
        elif type == TDS_ENV_COMMITTRANS or type == TDS_ENV_ROLLBACKTRANS:
 
2936
            self.conn.tds72_transaction = 0
 
2937
            skipall(r, r.get_byte())
 
2938
            skipall(r, r.get_byte())
 
2939
        elif type == TDS_ENV_PACKSIZE:
 
2940
            newval = r.read_ucs2(r.get_byte())
 
2941
            r.read_ucs2(r.get_byte())
 
2942
            new_block_size = int(newval)
 
2943
            if new_block_size >= 512:
 
2944
                #logger.info("changing block size from {0} to {1}".format(oldval, new_block_size))
 
2945
                #
 
2946
                # Is possible to have a shrink if server limits packet
 
2947
                # size more than what we specified
 
2948
                #
 
2949
                # Reallocate buffer if possible (strange values from server or out of memory) use older buffer */
 
2950
                self._writer.bufsize = new_block_size
 
2951
        elif type == TDS_ENV_DATABASE:
 
2952
            newval = r.read_ucs2(r.get_byte())
 
2953
            r.read_ucs2(r.get_byte())
 
2954
            self.conn.env.database = newval
 
2955
        elif type == TDS_ENV_LANG:
 
2956
            newval = r.read_ucs2(r.get_byte())
 
2957
            r.read_ucs2(r.get_byte())
 
2958
            self.conn.env.language = newval
 
2959
        elif type == TDS_ENV_CHARSET:
 
2960
            newval = r.read_ucs2(r.get_byte())
 
2961
            r.read_ucs2(r.get_byte())
 
2962
            #logger.debug("server indicated charset change to \"{0}\"\n".format(newval))
 
2963
            self.conn.env.charset = newval
 
2964
            remap = {'iso_1': 'iso8859-1'}
 
2965
            self.conn.server_codec = codecs.lookup(remap.get(newval, newval))
 
2966
            #tds_srv_charset_changed(self, newval)
 
2967
        elif type == TDS_ENV_DB_MIRRORING_PARTNER:
 
2968
            r.read_ucs2(r.get_byte())
 
2969
            r.read_ucs2(r.get_byte())
 
2970
        elif type == TDS_ENV_LCID:
 
2971
            lcid = int(r.read_ucs2(r.get_byte()))
 
2972
            self.conn.server_codec = codecs.lookup(lcid2charset(lcid))
 
2973
            r.read_ucs2(r.get_byte())
 
2974
        else:
 
2975
            logger.warning("unknown env type: {0}, skipping".format(type))
 
2976
            # discard byte values, not still supported
 
2977
            skipall(r, size - 1)
 
2978
 
 
2979
    def process_auth(self):
 
2980
        """ Reads and processes SSPI stream.
 
2981
 
 
2982
        Stream info: http://msdn.microsoft.com/en-us/library/dd302844.aspx
 
2983
        """
 
2984
        r = self._reader
 
2985
        w = self._writer
 
2986
        pdu_size = r.get_smallint()
 
2987
        if not self.authentication:
 
2988
            raise Error('Got unexpected token')
 
2989
        packet = self.authentication.handle_next(readall(r, pdu_size))
 
2990
        if packet:
 
2991
            w.write(packet)
 
2992
            w.flush()
 
2993
 
 
2994
    def is_connected(self):
 
2995
        """
 
2996
        :return: True if transport is connected
 
2997
        """
 
2998
        return self._transport.is_connected()
 
2999
 
 
3000
    def bad_stream(self, msg):
 
3001
        """ Called when input stream contains unexpected data.
 
3002
 
 
3003
        Will close stream and raise :class:`InterfaceError`
 
3004
        :param msg: Message for InterfaceError exception.
 
3005
        :return: Never returns, always raises exception.
 
3006
        """
 
3007
        self.close()
 
3008
        raise InterfaceError(msg)
 
3009
 
 
3010
    @property
 
3011
    def tds_version(self):
 
3012
        """ Returns integer encoded current TDS protocol version
 
3013
        """
 
3014
        return self._tds.tds_version
 
3015
 
 
3016
    @property
 
3017
    def conn(self):
 
3018
        """ Reference to owning :class:`_TdsSocket`
 
3019
        """
 
3020
        return self._tds
 
3021
 
 
3022
    def close(self):
 
3023
        self._transport.close()
 
3024
 
 
3025
    def set_state(self, state):
 
3026
        """ Switches state of the TDS session.
 
3027
 
 
3028
        It also does state transitions checks.
 
3029
        :param state: New state, one of TDS_PENDING/TDS_READING/TDS_IDLE/TDS_DEAD/TDS_QUERING
 
3030
        """
 
3031
        prior_state = self.state
 
3032
        if state == prior_state:
 
3033
            return state
 
3034
        if state == TDS_PENDING:
 
3035
            if prior_state in (TDS_READING, TDS_QUERYING):
 
3036
                self.state = TDS_PENDING
 
3037
            else:
 
3038
                raise InterfaceError('logic error: cannot chage query state from {0} to {1}'.
 
3039
                                     format(state_names[prior_state], state_names[state]))
 
3040
        elif state == TDS_READING:
 
3041
            # transition to READING are valid only from PENDING
 
3042
            if self.state != TDS_PENDING:
 
3043
                raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
 
3044
                                     format(state_names[prior_state], state_names[state]))
 
3045
            else:
 
3046
                self.state = state
 
3047
        elif state == TDS_IDLE:
 
3048
            if prior_state == TDS_DEAD:
 
3049
                raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
 
3050
                                     format(state_names[prior_state], state_names[state]))
 
3051
            self.state = state
 
3052
        elif state == TDS_DEAD:
 
3053
            self.state = state
 
3054
        elif state == TDS_QUERYING:
 
3055
            if self.state == TDS_DEAD:
 
3056
                raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
 
3057
                                     format(state_names[prior_state], state_names[state]))
 
3058
            elif self.state != TDS_IDLE:
 
3059
                raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
 
3060
                                     format(state_names[prior_state], state_names[state]))
 
3061
            else:
 
3062
                self.rows_affected = TDS_NO_COUNT
 
3063
                self.internal_sp_called = 0
 
3064
                self.state = state
 
3065
        else:
 
3066
            assert False
 
3067
        return self.state
 
3068
 
 
3069
    @contextmanager
 
3070
    def querying_context(self, packet_type):
 
3071
        """ Context manager for querying.
 
3072
 
 
3073
        Sets state to TDS_QUERYING, and reverts it to TDS_IDLE if exception happens inside managed block,
 
3074
        and to TDS_PENDING if managed block succeeds and flushes buffer.
 
3075
        """
 
3076
        if self.set_state(TDS_QUERYING) != TDS_QUERYING:
 
3077
            raise Error("Couldn't switch to state")
 
3078
        self._writer.begin_packet(packet_type)
 
3079
        try:
 
3080
            yield
 
3081
        except:
 
3082
            if self.state != TDS_DEAD:
 
3083
                self.set_state(TDS_IDLE)
 
3084
            raise
 
3085
        else:
 
3086
            self.set_state(TDS_PENDING)
 
3087
            self._writer.flush()
 
3088
 
 
3089
    def _autodetect_column_type(self, value, value_type):
 
3090
        """ Function guesses type of the parameter from the type of value.
 
3091
 
 
3092
        :param value: value to be passed to db, can be None
 
3093
        :param value_type: value type, if value is None, type is used instead of it
 
3094
        :return: An instance of subclass of :class:`BaseType`
 
3095
        """
 
3096
        if value is None and value_type is None:
 
3097
            return self.conn.NVarChar(1, collation=self.conn.collation)
 
3098
        assert value_type is not None
 
3099
        assert value is None or isinstance(value, value_type)
 
3100
        
 
3101
        if issubclass(value_type, bool):
 
3102
            return BitN.instance
 
3103
        elif issubclass(value_type, six.integer_types):
 
3104
            if value == None:
 
3105
                return IntN(8)
 
3106
            if -2 ** 31 <= value <= 2 ** 31 - 1:
 
3107
                return IntN(4)
 
3108
            elif -2 ** 63 <= value <= 2 ** 63 - 1:
 
3109
                return IntN(8)
 
3110
            elif -10 ** 38 + 1 <= value <= 10 ** 38 - 1:
 
3111
                return MsDecimal(0, 38)
 
3112
            else:
 
3113
                raise DataError('Numeric value out of range')
 
3114
        elif issubclass(value_type, float):
 
3115
            return FloatN(8)
 
3116
        elif issubclass(value_type, Binary):
 
3117
            return self.conn.long_binary_type()
 
3118
        elif issubclass(value_type, six.binary_type):
 
3119
            if self._tds.login.bytes_to_unicode:
 
3120
                return self.conn.long_string_type(collation=self.conn.collation)
 
3121
            else:
 
3122
                return self.conn.long_varchar_type(collation=self.conn.collation)
 
3123
        elif issubclass(value_type, six.string_types):
 
3124
            return self.conn.long_string_type(collation=self.conn.collation)
 
3125
        elif issubclass(value_type, datetime):
 
3126
            if IS_TDS73_PLUS(self):
 
3127
                if value != None and value.tzinfo and not self.use_tz:
 
3128
                    return DateTimeOffset()
 
3129
                else:
 
3130
                    return DateTime2()
 
3131
            else:
 
3132
                return DateTimeN(8)
 
3133
        elif issubclass(value_type, date):
 
3134
            if IS_TDS73_PLUS(self):
 
3135
                return MsDate.instance
 
3136
            else:
 
3137
                return DateTimeN(8)
 
3138
        elif issubclass(value_type, time):
 
3139
            if not IS_TDS73_PLUS(self):
 
3140
                raise DataError('Time type is not supported on MSSQL 2005 and lower')
 
3141
            return MsTime(6)
 
3142
        elif issubclass(value_type, Decimal):
 
3143
            if value != None:
 
3144
                return MsDecimal.from_value(value)
 
3145
            else:
 
3146
                return MsDecimal()
 
3147
        elif issubclass(value_type, uuid.UUID):
 
3148
            return MsUnique.instance
 
3149
        else:
 
3150
            raise DataError('Parameter type is not supported: {!r} {!r}'.format(value, value_type))
 
3151
 
 
3152
    def make_param(self, name, value):
 
3153
        """ Generates instance of :class:`Column` from value and name
 
3154
 
 
3155
        Value can also be of a special types:
 
3156
 
 
3157
        - An instance of :class:`Column`, in which case it is just returned.
 
3158
        - An instance of :class:`output`, in which case parameter will become
 
3159
          an output parameter.
 
3160
        - A singleton :var:`default`, in which case default value will be passed
 
3161
          into a stored proc.
 
3162
 
 
3163
        :param name: Name of the parameter, will populate column_name property of returned column.
 
3164
        :param value: Value of the parameter, also used to guess the type of parameter.
 
3165
        :return: An instance of :class:`Column`
 
3166
        """
 
3167
        if isinstance(value, Column):
 
3168
            value.column_name = name
 
3169
            return value
 
3170
        column = Column()
 
3171
        column.column_name = name
 
3172
        column.flags = 0
 
3173
        
 
3174
        if isinstance(value, output):
 
3175
            column.flags |= fByRefValue
 
3176
            if isinstance(value.type, six.string_types):
 
3177
                column.type = self._tds.type_by_declaration(value.type, True)
 
3178
            value_type = value.type or type(value.value)
 
3179
            value = value.value
 
3180
        else:
 
3181
            value_type = type(value)
 
3182
 
 
3183
        if value_type is type(None):
 
3184
            value_type = None
 
3185
            
 
3186
        if value is default:
 
3187
            column.flags |= fDefaultValue
 
3188
            value = None
 
3189
            if value_type is _Default:
 
3190
                value_type = None
 
3191
 
 
3192
        column.value = value
 
3193
        if column.type is None:
 
3194
            column.type = self._autodetect_column_type(value, value_type)
 
3195
        return column
 
3196
 
 
3197
    def _convert_params(self, parameters):
 
3198
        """ Converts a dict of list of parameters into a list of :class:`Column` instances.
 
3199
 
 
3200
        :param parameters: Can be a list of parameter values, or a dict of parameter names to values.
 
3201
        :return: A list of :class:`Column` instances.
 
3202
        """
 
3203
        if isinstance(parameters, dict):
 
3204
            return [self.make_param(name, value)
 
3205
                    for name, value in parameters.items()]
 
3206
        else:
 
3207
            params = []
 
3208
            for parameter in parameters:
 
3209
                params.append(self.make_param('', parameter))
 
3210
            return params
 
3211
 
 
3212
    def cancel_if_pending(self):
 
3213
        """ Cancels current pending request.
 
3214
 
 
3215
        Does nothing if no request is pending, otherwise sends cancel request,
 
3216
        and waits for response.
 
3217
        """
 
3218
        if self.state == TDS_IDLE:
 
3219
            return
 
3220
        if not self.in_cancel:
 
3221
            self._put_cancel()
 
3222
        self.process_cancel()
 
3223
 
 
3224
    def submit_rpc(self, rpc_name, params, flags):
 
3225
        """ Sends an RPC request.
 
3226
 
 
3227
        This call will transition session into pending state.
 
3228
        If some operation is currently pending on the session, it will be
 
3229
        cancelled before sending this request.
 
3230
 
 
3231
        Spec: http://msdn.microsoft.com/en-us/library/dd357576.aspx
 
3232
 
 
3233
        :param rpc_name: Name of the RPC to call, can be an instance of :class:`InternalProc`
 
3234
        :param params: Stored proc parameters, should be a list of :class:`Column` instances.
 
3235
        :param flags: See spec for possible flags.
 
3236
        """
 
3237
        self.messages = []
 
3238
        self.output_params = {}
 
3239
        self.cancel_if_pending()
 
3240
        self.res_info = None
 
3241
        w = self._writer
 
3242
        with self.querying_context(TDS_RPC):
 
3243
            self._START_QUERY()
 
3244
            if IS_TDS71_PLUS(self) and isinstance(rpc_name, InternalProc):
 
3245
                w.put_smallint(-1)
 
3246
                w.put_smallint(rpc_name.proc_id)
 
3247
            else:
 
3248
                if isinstance(rpc_name, InternalProc):
 
3249
                    rpc_name = rpc_name.name
 
3250
                w.put_smallint(len(rpc_name))
 
3251
                w.write_ucs2(rpc_name)
 
3252
            #
 
3253
            # TODO support flags
 
3254
            # bit 0 (1 as flag) in TDS7/TDS5 is "recompile"
 
3255
            # bit 1 (2 as flag) in TDS7+ is "no metadata" bit this will prevent sending of column infos
 
3256
            #
 
3257
            w.put_usmallint(flags)
 
3258
            self._out_params_indexes = []
 
3259
            for i, param in enumerate(params):
 
3260
                if param.flags & fByRefValue:
 
3261
                    self._out_params_indexes.append(i)
 
3262
                w.put_byte(len(param.column_name))
 
3263
                w.write_ucs2(param.column_name)
 
3264
                #
 
3265
                # TODO support other flags (use defaul null/no metadata)
 
3266
                # bit 1 (2 as flag) in TDS7+ is "default value" bit
 
3267
                # (what's the meaning of "default value" ?)
 
3268
                #
 
3269
                w.put_byte(param.flags)
 
3270
                # FIXME: column_type is wider than one byte.  Do something sensible, not just lop off the high byte.
 
3271
                w.put_byte(param.type.type)
 
3272
                param.type.write_info(w)
 
3273
                param.type.write(w, param.value)
 
3274
 
 
3275
    def submit_plain_query(self, operation):
 
3276
        """ Sends a plain query to server.
 
3277
 
 
3278
        This call will transition session into pending state.
 
3279
        If some operation is currently pending on the session, it will be
 
3280
        cancelled before sending this request.
 
3281
 
 
3282
        Spec: http://msdn.microsoft.com/en-us/library/dd358575.aspx
 
3283
 
 
3284
        :param operation: A string representing sql statement.
 
3285
        """
 
3286
        #logger.debug('submit_plain_query(%s)', operation)
 
3287
        self.messages = []
 
3288
        self.cancel_if_pending()
 
3289
        self.res_info = None
 
3290
        w = self._writer
 
3291
        with self.querying_context(TDS_QUERY):
 
3292
            self._START_QUERY()
 
3293
            w.write_ucs2(operation)
 
3294
 
 
3295
    def submit_bulk(self, metadata, rows):
 
3296
        """ Sends insert bulk command.
 
3297
 
 
3298
        Spec: http://msdn.microsoft.com/en-us/library/dd358082.aspx
 
3299
 
 
3300
        :param metadata: A list of :class:`Column` instances.
 
3301
        :param rows: A collection of rows, each row is a collection of values.
 
3302
        :return:
 
3303
        """
 
3304
        num_cols = len(metadata)
 
3305
        w = self._writer
 
3306
        with self.querying_context(TDS_BULK):
 
3307
            w.put_byte(TDS7_RESULT_TOKEN)
 
3308
            w.put_usmallint(num_cols)
 
3309
            for col in metadata:
 
3310
                if IS_TDS72_PLUS(self):
 
3311
                    w.put_uint(col.column_usertype)
 
3312
                else:
 
3313
                    w.put_usmallint(col.column_usertype)
 
3314
                w.put_usmallint(col.flags)
 
3315
                w.put_byte(col.type.type)
 
3316
                col.type.write_info(w)
 
3317
                w.put_byte(len(col.column_name))
 
3318
                w.write_ucs2(col.column_name)
 
3319
            for row in rows:
 
3320
                w.put_byte(TDS_ROW_TOKEN)
 
3321
                for i, col in enumerate(metadata):
 
3322
                    col.type.write(w, row[i])
 
3323
 
 
3324
            w.put_byte(TDS_DONE_TOKEN)
 
3325
            w.put_usmallint(TDS_DONE_FINAL)
 
3326
            w.put_usmallint(0)  # curcmd
 
3327
            if IS_TDS72_PLUS(self):
 
3328
                w.put_int8(0)
 
3329
            else:
 
3330
                w.put_int(0)
 
3331
 
 
3332
    def _put_cancel(self):
 
3333
        """ Sends a cancel request to the server.
 
3334
 
 
3335
        Switches connection to IN_CANCEL state.
 
3336
        """
 
3337
        self._writer.begin_packet(TDS_CANCEL)
 
3338
        self._writer.flush()
 
3339
        self.in_cancel = 1
 
3340
 
 
3341
    _begin_tran_struct_72 = struct.Struct('<HBB')
 
3342
 
 
3343
    def begin_tran(self, isolation_level=0):
 
3344
        self.submit_begin_tran(isolation_level=isolation_level)
 
3345
        self.process_simple_request()
 
3346
 
 
3347
    def submit_begin_tran(self, isolation_level=0):
 
3348
        #logger.debug('submit_begin_tran()')
 
3349
        if IS_TDS72_PLUS(self):
 
3350
            self.messages = []
 
3351
            self.cancel_if_pending()
 
3352
            w = self._writer
 
3353
            with self.querying_context(TDS7_TRANS):
 
3354
                self._start_query()
 
3355
                w.pack(self._begin_tran_struct_72,
 
3356
                    5,  # TM_BEGIN_XACT
 
3357
                    isolation_level,
 
3358
                    0,  # new transaction name
 
3359
                    )
 
3360
        else:
 
3361
            self.submit_plain_query("BEGIN TRANSACTION")
 
3362
            self.conn.tds72_transaction = 1
 
3363
 
 
3364
    _commit_rollback_tran_struct72_hdr = struct.Struct('<HBB')
 
3365
    _continue_tran_struct72 = struct.Struct('<BB')
 
3366
 
 
3367
    def rollback(self, cont, isolation_level=0):
 
3368
        self.submit_rollback(cont, isolation_level=isolation_level)
 
3369
        prev_timeout = self._tds._sock.gettimeout()
 
3370
        self._tds._sock.settimeout(None)
 
3371
        try:
 
3372
            self.process_simple_request()
 
3373
        finally:
 
3374
            self._tds._sock.settimeout(prev_timeout)
 
3375
 
 
3376
    def submit_rollback(self, cont, isolation_level=0):
 
3377
        #logger.debug('submit_rollback(%s, %s)', id(self), cont)
 
3378
        if IS_TDS72_PLUS(self):
 
3379
            self.messages = []
 
3380
            self.cancel_if_pending()
 
3381
            w = self._writer
 
3382
            with self.querying_context(TDS7_TRANS):
 
3383
                self._start_query()
 
3384
                flags = 0
 
3385
                if cont:
 
3386
                    flags |= 1
 
3387
                w.pack(self._commit_rollback_tran_struct72_hdr,
 
3388
                    8,  # TM_ROLLBACK_XACT
 
3389
                    0,  # transaction name
 
3390
                    flags,
 
3391
                    )
 
3392
                if cont:
 
3393
                    w.pack(self._continue_tran_struct72,
 
3394
                        isolation_level,
 
3395
                        0,  # new transaction name
 
3396
                        )
 
3397
        else:
 
3398
            self.submit_plain_query("IF @@TRANCOUNT > 0 ROLLBACK BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 ROLLBACK")
 
3399
            self.conn.tds72_transaction = 1 if cont else 0
 
3400
 
 
3401
    def commit(self, cont, isolation_level=0):
 
3402
        self.submit_commit(cont, isolation_level=isolation_level)
 
3403
        prev_timeout = self._tds._sock.gettimeout()
 
3404
        self._tds._sock.settimeout(None)
 
3405
        try:
 
3406
            self.process_simple_request()
 
3407
        finally:
 
3408
            self._tds._sock.settimeout(prev_timeout)
 
3409
 
 
3410
    def submit_commit(self, cont, isolation_level=0):
 
3411
        #logger.debug('submit_commit(%s)', cont)
 
3412
        if IS_TDS72_PLUS(self):
 
3413
            self.messages = []
 
3414
            self.cancel_if_pending()
 
3415
            w = self._writer
 
3416
            with self.querying_context(TDS7_TRANS):
 
3417
                self._start_query()
 
3418
                flags = 0
 
3419
                if cont:
 
3420
                    flags |= 1
 
3421
                w.pack(self._commit_rollback_tran_struct72_hdr,
 
3422
                    7,  # TM_COMMIT_XACT
 
3423
                    0,  # transaction name
 
3424
                    flags,
 
3425
                    )
 
3426
                if cont:
 
3427
                    w.pack(self._continue_tran_struct72,
 
3428
                        isolation_level,
 
3429
                        0,  # new transaction name
 
3430
                        )
 
3431
        else:
 
3432
            self.submit_plain_query("IF @@TRANCOUNT > 0 COMMIT BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 COMMIT")
 
3433
            self.conn.tds72_transaction = 1 if cont else 0
 
3434
 
 
3435
    def _START_QUERY(self):
 
3436
        if IS_TDS72_PLUS(self):
 
3437
            self._start_query()
 
3438
 
 
3439
    _tds72_query_start = struct.Struct('<IIHQI')
 
3440
 
 
3441
    def _start_query(self):
 
3442
        w = self._writer
 
3443
        w.pack(_TdsSession._tds72_query_start,
 
3444
               0x16,  # total length
 
3445
               0x12,  # length
 
3446
               2,  # type
 
3447
               self.conn.tds72_transaction,
 
3448
               1,  # request count
 
3449
               )
 
3450
 
 
3451
    VERSION = 0
 
3452
    ENCRYPTION = 1
 
3453
    INSTOPT = 2
 
3454
    THREADID = 3
 
3455
    MARS = 4
 
3456
    TRACEID = 5
 
3457
    TERMINATOR = 0xff
 
3458
 
 
3459
    def _send_prelogin(self, login):
 
3460
        instance_name = login.instance_name or 'MSSQLServer'
 
3461
        instance_name = instance_name.encode('ascii')
 
3462
        encryption_level = login.encryption_level
 
3463
        if len(instance_name) > 65490:
 
3464
            raise ValueError('Instance name is too long')
 
3465
        if encryption_level >= TDS_ENCRYPTION_REQUIRE:
 
3466
            raise NotSupportedError('Client requested encryption but it is not supported')
 
3467
        if IS_TDS72_PLUS(self):
 
3468
            START_POS = 26
 
3469
            buf = struct.pack(
 
3470
                b'>BHHBHHBHHBHHBHHB',
 
3471
                #netlib version
 
3472
                self.VERSION, START_POS, 6,
 
3473
                #encryption
 
3474
                self.ENCRYPTION, START_POS + 6, 1,
 
3475
                #instance
 
3476
                self.INSTOPT, START_POS + 6 + 1, len(instance_name) + 1,
 
3477
                # thread id
 
3478
                self.THREADID, START_POS + 6 + 1 + len(instance_name) + 1, 4,
 
3479
                # MARS enabled
 
3480
                self.MARS, START_POS + 6 + 1 + len(instance_name) + 1 + 4, 1,
 
3481
                # end
 
3482
                self.TERMINATOR
 
3483
                )
 
3484
        else:
 
3485
            START_POS = 21
 
3486
            buf = struct.pack(
 
3487
                b'>BHHBHHBHHBHHB',
 
3488
                #netlib version
 
3489
                self.VERSION, START_POS, 6,
 
3490
                #encryption
 
3491
                self.ENCRYPTION, START_POS + 6, 1,
 
3492
                #instance
 
3493
                self.INSTOPT, START_POS + 6 + 1, len(instance_name) + 1,
 
3494
                # thread id
 
3495
                self.THREADID, START_POS + 6 + 1 + len(instance_name) + 1, 4,
 
3496
                # end
 
3497
                self.TERMINATOR
 
3498
                )
 
3499
        assert START_POS == len(buf)
 
3500
        w = self._writer
 
3501
        w.begin_packet(TDS71_PRELOGIN)
 
3502
        w.write(buf)
 
3503
        from . import intversion
 
3504
        w.put_uint_be(intversion)
 
3505
        w.put_usmallint_be(0)  # build number
 
3506
        # encryption
 
3507
        if ENCRYPTION_ENABLED and encryption_supported:
 
3508
            w.put_byte(1 if encryption_level >= TDS_ENCRYPTION_REQUIRE else 0)
 
3509
        else:
 
3510
            # not supported
 
3511
            w.put_byte(2)
 
3512
        w.write(instance_name)
 
3513
        w.put_byte(0)  # zero terminate instance_name
 
3514
        w.put_int(0)  # TODO: change this to thread id
 
3515
        if IS_TDS72_PLUS(self):
 
3516
            # MARS (1 enabled)
 
3517
            w.put_byte(1 if login.use_mars else 0)
 
3518
        w.flush()
 
3519
 
 
3520
    def _process_prelogin(self, login):
 
3521
        p = self._reader.read_whole_packet()
 
3522
        size = len(p)
 
3523
        if size <= 0 or self._reader.packet_type != 4:
 
3524
            self.bad_stream('Invalid packet type: {0}, expected PRELOGIN(4)'.format(self._reader.packet_type))
 
3525
        # default 2, no certificate, no encryptption
 
3526
        crypt_flag = 2
 
3527
        i = 0
 
3528
        byte_struct = struct.Struct('B')
 
3529
        off_len_struct = struct.Struct('>HH')
 
3530
        prod_version_struct = struct.Struct('>LH')
 
3531
        while True:
 
3532
            if i >= size:
 
3533
                self.bad_stream('Invalid size of PRELOGIN structure')
 
3534
            type, = byte_struct.unpack_from(p, i)
 
3535
            if type == 0xff:
 
3536
                break
 
3537
            if i + 4 > size:
 
3538
                self.bad_stream('Invalid size of PRELOGIN structure')
 
3539
            off, l = off_len_struct.unpack_from(p, i + 1)
 
3540
            if off > size or off + l > size:
 
3541
                self.bad_stream('Invalid offset in PRELOGIN structure')
 
3542
            if type == self.VERSION:
 
3543
                self.conn.server_library_version = prod_version_struct.unpack_from(p, off)
 
3544
            elif type == self.ENCRYPTION and l >= 1:
 
3545
                crypt_flag, = byte_struct.unpack_from(p, off)
 
3546
            elif type == self.MARS:
 
3547
                self.conn._mars_enabled = bool(byte_struct.unpack_from(p, off)[0])
 
3548
            elif type == self.INSTOPT:
 
3549
                # ignore instance name mismatch
 
3550
                pass
 
3551
            i += 5
 
3552
        # if server do not has certificate do normal login
 
3553
        if crypt_flag == 2:
 
3554
            if login.encryption_level >= TDS_ENCRYPTION_REQUIRE:
 
3555
                raise Error('Server required encryption but it is not supported')
 
3556
            return
 
3557
        self._sock = ssl.wrap_socket(self._sock, ssl_version=ssl.PROTOCOL_SSLv3)
 
3558
 
 
3559
    def tds7_send_login(self, login):
 
3560
        option_flag2 = login.option_flag2
 
3561
        user_name = login.user_name
 
3562
        if len(user_name) > 128:
 
3563
            raise ValueError('User name should be no longer that 128 characters')
 
3564
        if len(login.password) > 128:
 
3565
            raise ValueError('Password should be not longer than 128 characters')
 
3566
        if len(login.change_password) > 128:
 
3567
            raise ValueError('Password should be not longer than 128 characters')
 
3568
        if len(login.client_host_name) > 128:
 
3569
            raise ValueError('Host name should be not longer than 128 characters')
 
3570
        if len(login.app_name) > 128:
 
3571
            raise ValueError('App name should be not longer than 128 characters')
 
3572
        if len(login.server_name) > 128:
 
3573
            raise ValueError('Server name should be not longer than 128 characters')
 
3574
        if len(login.database) > 128:
 
3575
            raise ValueError('Database name should be not longer than 128 characters')
 
3576
        if len(login.language) > 128:
 
3577
            raise ValueError('Language should be not longer than 128 characters')
 
3578
        if len(login.attach_db_file) > 260:
 
3579
            raise ValueError('File path should be not longer than 260 characters')
 
3580
        w = self._writer
 
3581
        w.begin_packet(TDS7_LOGIN)
 
3582
        self.authentication = None
 
3583
        current_pos = 86 + 8 if IS_TDS72_PLUS(self) else 86
 
3584
        client_host_name = login.client_host_name
 
3585
        login.client_host_name = client_host_name
 
3586
        packet_size = current_pos + (len(client_host_name) + len(login.app_name) + len(login.server_name) + len(login.library) + len(login.language) + len(login.database)) * 2
 
3587
        if login.auth:
 
3588
            self.authentication = login.auth
 
3589
            auth_packet = login.auth.create_packet()
 
3590
            packet_size += len(auth_packet)
 
3591
        else:
 
3592
            auth_packet = ''
 
3593
            packet_size += (len(user_name) + len(login.password)) * 2
 
3594
        w.put_int(packet_size)
 
3595
        w.put_uint(login.tds_version)
 
3596
        w.put_int(w.bufsize)
 
3597
        from . import intversion
 
3598
        w.put_uint(intversion)
 
3599
        w.put_int(login.pid)
 
3600
        w.put_uint(0)  # connection id
 
3601
        option_flag1 = TDS_SET_LANG_ON | TDS_USE_DB_NOTIFY | TDS_INIT_DB_FATAL
 
3602
        if not login.bulk_copy:
 
3603
            option_flag1 |= TDS_DUMPLOAD_OFF
 
3604
        w.put_byte(option_flag1)
 
3605
        if self.authentication:
 
3606
            option_flag2 |= TDS_INTEGRATED_SECURITY_ON
 
3607
        w.put_byte(option_flag2)
 
3608
        type_flags = 0
 
3609
        if login.readonly:
 
3610
            type_flags |= (2 << 5)
 
3611
        w.put_byte(type_flags)
 
3612
        option_flag3 = TDS_UNKNOWN_COLLATION_HANDLING
 
3613
        w.put_byte(option_flag3 if IS_TDS73_PLUS(self) else 0)
 
3614
        mins_fix = int(total_seconds(login.client_tz.utcoffset(datetime.now()))) // 60
 
3615
        w.put_int(mins_fix)
 
3616
        w.put_int(login.client_lcid)
 
3617
        w.put_smallint(current_pos)
 
3618
        w.put_smallint(len(client_host_name))
 
3619
        current_pos += len(client_host_name) * 2
 
3620
        if self.authentication:
 
3621
            w.put_smallint(0)
 
3622
            w.put_smallint(0)
 
3623
            w.put_smallint(0)
 
3624
            w.put_smallint(0)
 
3625
        else:
 
3626
            w.put_smallint(current_pos)
 
3627
            w.put_smallint(len(user_name))
 
3628
            current_pos += len(user_name) * 2
 
3629
            w.put_smallint(current_pos)
 
3630
            w.put_smallint(len(login.password))
 
3631
            current_pos += len(login.password) * 2
 
3632
        w.put_smallint(current_pos)
 
3633
        w.put_smallint(len(login.app_name))
 
3634
        current_pos += len(login.app_name) * 2
 
3635
        # server name
 
3636
        w.put_smallint(current_pos)
 
3637
        w.put_smallint(len(login.server_name))
 
3638
        current_pos += len(login.server_name) * 2
 
3639
        # reserved
 
3640
        w.put_smallint(0)
 
3641
        w.put_smallint(0)
 
3642
        # library name
 
3643
        w.put_smallint(current_pos)
 
3644
        w.put_smallint(len(login.library))
 
3645
        current_pos += len(login.library) * 2
 
3646
        # language
 
3647
        w.put_smallint(current_pos)
 
3648
        w.put_smallint(len(login.language))
 
3649
        current_pos += len(login.language) * 2
 
3650
        # database name
 
3651
        w.put_smallint(current_pos)
 
3652
        w.put_smallint(len(login.database))
 
3653
        current_pos += len(login.database) * 2
 
3654
        # ClientID
 
3655
        client_id = struct.pack('>Q', login.client_id)[2:]
 
3656
        w.write(client_id)
 
3657
        # authentication
 
3658
        w.put_smallint(current_pos)
 
3659
        w.put_smallint(len(auth_packet))
 
3660
        current_pos += len(auth_packet)
 
3661
        # db file
 
3662
        w.put_smallint(current_pos)
 
3663
        w.put_smallint(len(login.attach_db_file))
 
3664
        current_pos += len(login.attach_db_file) * 2
 
3665
        if IS_TDS72_PLUS(self):
 
3666
            # new password
 
3667
            w.put_smallint(current_pos)
 
3668
            w.put_smallint(len(login.change_password))
 
3669
            # sspi long
 
3670
            w.put_int(0)
 
3671
        w.write_ucs2(client_host_name)
 
3672
        if not self.authentication:
 
3673
            w.write_ucs2(user_name)
 
3674
            w.write(tds7_crypt_pass(login.password))
 
3675
        w.write_ucs2(login.app_name)
 
3676
        w.write_ucs2(login.server_name)
 
3677
        w.write_ucs2(login.library)
 
3678
        w.write_ucs2(login.language)
 
3679
        w.write_ucs2(login.database)
 
3680
        if self.authentication:
 
3681
            w.write(auth_packet)
 
3682
        w.write_ucs2(login.attach_db_file)
 
3683
        w.write_ucs2(login.change_password)
 
3684
        w.flush()
 
3685
 
 
3686
    _SERVER_TO_CLIENT_MAPPING = {
 
3687
        0x07000000: TDS70,
 
3688
        0x07010000: TDS71,
 
3689
        0x71000001: TDS71rev1,
 
3690
        TDS72: TDS72,
 
3691
        TDS73A: TDS73A,
 
3692
        TDS73B: TDS73B,
 
3693
        TDS74: TDS74,
 
3694
        }
 
3695
 
 
3696
    def process_login_tokens(self):
 
3697
        r = self._reader
 
3698
        succeed = False
 
3699
        #logger.debug('process_login_tokens()')
 
3700
        while True:
 
3701
            marker = r.get_byte()
 
3702
            #logger.debug('looking for login token, got  {0:x}({1})'.format(marker, tds_token_name(marker)))
 
3703
            if marker == TDS_LOGINACK_TOKEN:
 
3704
                succeed = True
 
3705
                size = r.get_smallint()
 
3706
                r.get_byte()  # interface
 
3707
                version = r.get_uint_be()
 
3708
                self.conn.tds_version = self._SERVER_TO_CLIENT_MAPPING.get(version, version)
 
3709
                #logger.debug('server reports TDS version {0:x}'.format(version))
 
3710
                if not IS_TDS7_PLUS(self):
 
3711
                    self.bad_stream('Only TDS 7.0 and higher are supported')
 
3712
                # get server product name
 
3713
                # ignore product name length, some servers seem to set it incorrectly
 
3714
                r.get_byte()
 
3715
                size -= 10
 
3716
                self.conn.product_name = r.read_ucs2(size // 2)
 
3717
                product_version = r.get_uint_be()
 
3718
                # MSSQL 6.5 and 7.0 seem to return strange values for this
 
3719
                # using TDS 4.2, something like 5F 06 32 FF for 6.50
 
3720
                self.conn.product_version = product_version
 
3721
                #logger.debug('Product version {0:x}'.format(product_version))
 
3722
                if self.conn.authentication:
 
3723
                    self.conn.authentication.close()
 
3724
                    self.conn.authentication = None
 
3725
            else:
 
3726
                self.process_token(marker)
 
3727
                if marker == TDS_DONE_TOKEN:
 
3728
                    break
 
3729
        return succeed
 
3730
 
 
3731
    def process_returnstatus(self):
 
3732
        self.ret_status = self._reader.get_int()
 
3733
        self.has_status = True
 
3734
 
 
3735
    def process_token(self, marker):
 
3736
        handler = _token_map.get(marker)
 
3737
        if not handler:
 
3738
            self.bad_stream('Invalid TDS marker: {0}({0:x})'.format(marker))
 
3739
        return handler(self)
 
3740
 
 
3741
    def get_token_id(self):
 
3742
        self.set_state(TDS_READING)
 
3743
        try:
 
3744
            marker = self._reader.get_byte()
 
3745
        except TimeoutError:
 
3746
            self.set_state(TDS_PENDING)
 
3747
            raise
 
3748
        except:
 
3749
            self._tds.close()
 
3750
            raise
 
3751
        return marker
 
3752
 
 
3753
    def process_simple_request(self):
 
3754
        while True:
 
3755
            marker = self.get_token_id()
 
3756
            if marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN, TDS_DONEINPROC_TOKEN):
 
3757
                self.process_end(marker)
 
3758
                if self.done_flags & TDS_DONE_MORE_RESULTS:
 
3759
                    # skip results that don't event have rowcount
 
3760
                    continue
 
3761
                return
 
3762
            else:
 
3763
                self.process_token(marker)
 
3764
 
 
3765
    def next_set(self):
 
3766
        while self.more_rows:
 
3767
            self.next_row()
 
3768
        if self.state == TDS_IDLE:
 
3769
            return False
 
3770
        if self.find_result_or_done():
 
3771
            return True
 
3772
 
 
3773
    def fetchone(self):
 
3774
        if self.res_info is None:
 
3775
            raise Error("Previous statement didn't produce any results")
 
3776
 
 
3777
        if self.skipped_to_status:
 
3778
            raise Error("Unable to fetch any rows after accessing return_status")
 
3779
 
 
3780
        if not self.next_row():
 
3781
            return None
 
3782
 
 
3783
        return self.row
 
3784
 
 
3785
    def next_row(self):
 
3786
        if not self.more_rows:
 
3787
            return False
 
3788
        while True:
 
3789
            marker = self.get_token_id()
 
3790
            if marker in (TDS_ROW_TOKEN, TDS_NBC_ROW_TOKEN):
 
3791
                self.process_token(marker)
 
3792
                return True
 
3793
            elif marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN, TDS_DONEINPROC_TOKEN):
 
3794
                self.process_end(marker)
 
3795
                return False
 
3796
            else:
 
3797
                self.process_token(marker)
 
3798
 
 
3799
    def find_result_or_done(self):
 
3800
        self.done_flags = 0
 
3801
        while True:
 
3802
            marker = self.get_token_id()
 
3803
            if marker == TDS7_RESULT_TOKEN:
 
3804
                self.process_token(marker)
 
3805
                return True
 
3806
            elif marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN, TDS_DONEINPROC_TOKEN):
 
3807
                self.process_end(marker)
 
3808
                if self.done_flags & TDS_DONE_MORE_RESULTS:
 
3809
                    if self.done_flags & TDS_DONE_COUNT:
 
3810
                        return True
 
3811
                    else:
 
3812
                        # skip results without rowcount
 
3813
                        continue
 
3814
                else:
 
3815
                    return False
 
3816
            else:
 
3817
                self.process_token(marker)
 
3818
 
 
3819
    def process_rpc(self):
 
3820
        self.done_flags = 0
 
3821
        self.return_value_index = 0
 
3822
        while True:
 
3823
            marker = self.get_token_id()
 
3824
            if marker == TDS7_RESULT_TOKEN:
 
3825
                self.process_token(marker)
 
3826
                return True
 
3827
            elif marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN):
 
3828
                self.process_end(marker)
 
3829
                if self.done_flags & TDS_DONE_MORE_RESULTS and not self.done_flags & TDS_DONE_COUNT:
 
3830
                    # skip results that don't event have rowcount
 
3831
                    continue
 
3832
                return False
 
3833
            else:
 
3834
                self.process_token(marker)
 
3835
 
 
3836
    def find_return_status(self):
 
3837
        self.skipped_to_status = True
 
3838
        while True:
 
3839
            marker = self.get_token_id()
 
3840
            self.process_token(marker)
 
3841
            if marker == TDS_RETURNSTATUS_TOKEN:
 
3842
                return
 
3843
 
 
3844
 
 
3845
_token_map = {
 
3846
    TDS_AUTH_TOKEN: _TdsSession.process_auth,
 
3847
    TDS_ENVCHANGE_TOKEN: _TdsSession.process_env_chg,
 
3848
    TDS_DONE_TOKEN: lambda self: self.process_end(TDS_DONE_TOKEN),
 
3849
    TDS_DONEPROC_TOKEN: lambda self: self.process_end(TDS_DONEPROC_TOKEN),
 
3850
    TDS_DONEINPROC_TOKEN: lambda self: self.process_end(TDS_DONEINPROC_TOKEN),
 
3851
    TDS_ERROR_TOKEN: lambda self: self.process_msg(TDS_ERROR_TOKEN),
 
3852
    TDS_INFO_TOKEN: lambda self: self.process_msg(TDS_INFO_TOKEN),
 
3853
    TDS_EED_TOKEN: lambda self: self.process_msg(TDS_EED_TOKEN),
 
3854
    TDS_CAPABILITY_TOKEN: lambda self: self.process_msg(TDS_CAPABILITY_TOKEN),
 
3855
    TDS_PARAM_TOKEN: lambda self: self.process_param(),
 
3856
    TDS7_RESULT_TOKEN: lambda self: self.tds7_process_result(),
 
3857
    TDS_ROW_TOKEN: lambda self: self.process_row(),
 
3858
    TDS_NBC_ROW_TOKEN: lambda self: self.process_nbcrow(),
 
3859
    TDS_ORDERBY2_TOKEN: lambda self: self.process_orderby2(),
 
3860
    TDS_ORDERBY_TOKEN: lambda self: self.process_orderby(),
 
3861
    TDS_RETURNSTATUS_TOKEN: lambda self: self.process_returnstatus(),
 
3862
    }
 
3863
 
 
3864
 
 
3865
class _TdsSocket(object):
 
3866
    def __init__(self, use_tz=None):
 
3867
        self._is_connected = False
 
3868
        self.env = _TdsEnv()
 
3869
        self.collation = None
 
3870
        self.tds72_transaction = 0
 
3871
        self.authentication = None
 
3872
        self._mars_enabled = False
 
3873
        self.chunk_handler = MemoryChunkedHandler()
 
3874
        self._sock = None
 
3875
        self._bufsize = 4096
 
3876
        self.tds_version = TDS74
 
3877
        self.use_tz = use_tz
 
3878
 
 
3879
    def __repr__(self):
 
3880
        fmt = "<_TdsSocket tran={} mars={} tds_version={} use_tz={}>"
 
3881
        return fmt.format(self.tds72_transaction, self._mars_enabled,
 
3882
                          self.tds_version, self.use_tz)
 
3883
 
 
3884
    def login(self, login, sock, tzinfo_factory):
 
3885
        self.login = login
 
3886
        self._bufsize = login.blocksize
 
3887
        self.query_timeout = login.query_timeout
 
3888
        self._main_session = _TdsSession(self, self, tzinfo_factory)
 
3889
        self._sock = sock
 
3890
        self.tds_version = login.tds_version
 
3891
        if IS_TDS71_PLUS(self):
 
3892
            self._main_session._send_prelogin(login)
 
3893
            self._main_session._process_prelogin(login)
 
3894
        if IS_TDS7_PLUS(self):
 
3895
            self._main_session.tds7_send_login(login)
 
3896
        else:
 
3897
            raise ValueError('This TDS version is not supported')
 
3898
        if not self._main_session.process_login_tokens():
 
3899
            self._main_session.raise_db_exception()
 
3900
        if IS_TDS72_PLUS(self):
 
3901
            self._type_map = _type_map72
 
3902
        elif IS_TDS71_PLUS(self):
 
3903
            self._type_map = _type_map71
 
3904
        else:
 
3905
            self._type_map = _type_map
 
3906
        text_size = login.text_size
 
3907
        if self._mars_enabled:
 
3908
            from .smp import SmpManager
 
3909
            self._smp_manager = SmpManager(self)
 
3910
            self._main_session = _TdsSession(
 
3911
                self,
 
3912
                self._smp_manager.create_session(),
 
3913
                tzinfo_factory)
 
3914
        self._is_connected = True
 
3915
        q = []
 
3916
        if text_size:
 
3917
            q.append('set textsize {0}'.format(int(text_size)))
 
3918
        if login.database and self.env.database != login.database:
 
3919
            q.append('use ' + tds_quote_id(login.database))
 
3920
        if q:
 
3921
            self._main_session.submit_plain_query(''.join(q))
 
3922
            self._main_session.process_simple_request()
 
3923
 
 
3924
    @property
 
3925
    def mars_enabled(self):
 
3926
        return self._mars_enabled
 
3927
 
 
3928
    @property
 
3929
    def main_session(self):
 
3930
        return self._main_session
 
3931
 
 
3932
    def create_session(self, tzinfo_factory):
 
3933
        return _TdsSession(
 
3934
            self, self._smp_manager.create_session(),
 
3935
            tzinfo_factory)
 
3936
 
 
3937
    def read(self, size):
 
3938
        buf = self._sock.recv(size)
 
3939
        if len(buf) == 0:
 
3940
            self.close()
 
3941
            raise ClosedConnectionError()
 
3942
        return buf
 
3943
 
 
3944
    def _write(self, data, final):
 
3945
        try:
 
3946
            flags = 0
 
3947
            if hasattr(socket, 'MSG_NOSIGNAL'):
 
3948
                flags |= socket.MSG_NOSIGNAL
 
3949
            if not final:
 
3950
                if hasattr(socket, 'MSG_MORE'):
 
3951
                    flags |= socket.MSG_MORE
 
3952
            self._sock.sendall(data, flags)
 
3953
            if final and USE_CORK:
 
3954
                self._sock.setsockopt(socket.SOL_TCP, socket.TCP_CORK, 0)
 
3955
                self._sock.setsockopt(socket.SOL_TCP, socket.TCP_CORK, 1)
 
3956
        except:
 
3957
            self.close()
 
3958
            raise
 
3959
 
 
3960
    send = _write
 
3961
 
 
3962
    def is_connected(self):
 
3963
        return self._is_connected
 
3964
 
 
3965
    def close(self):
 
3966
        self._is_connected = False
 
3967
        if self._sock is not None:
 
3968
            self._sock.close()
 
3969
        if hasattr(self, '_smp_manager'):
 
3970
            self._smp_manager._transport_closed()
 
3971
        self._main_session.state = TDS_DEAD
 
3972
        if self.authentication:
 
3973
            self.authentication.close()
 
3974
            self.authentication = None
 
3975
 
 
3976
    def NVarChar(self, size, collation=raw_collation):
 
3977
        if IS_TDS72_PLUS(self):
 
3978
            return NVarChar72(size, collation)
 
3979
        elif IS_TDS71_PLUS(self):
 
3980
            return NVarChar71(size, collation)
 
3981
        else:
 
3982
            return NVarChar70(size)
 
3983
 
 
3984
    def VarChar(self, size, collation=raw_collation):
 
3985
        if IS_TDS72_PLUS(self):
 
3986
            return VarChar72(size, collation)
 
3987
        elif IS_TDS71_PLUS(self):
 
3988
            return VarChar71(size, collation)
 
3989
        else:
 
3990
            return VarChar70(size, codec=self.server_codec)
 
3991
 
 
3992
    def Text(self, size=0, collation=raw_collation):
 
3993
        if IS_TDS72_PLUS(self):
 
3994
            return Text72(size, collation=collation)
 
3995
        elif IS_TDS71_PLUS(self):
 
3996
            return Text71(size, collation=collation)
 
3997
        else:
 
3998
            return Text70(size, codec=self.server_codec)
 
3999
 
 
4000
    def NText(self, size=0, collation=raw_collation):
 
4001
        if IS_TDS72_PLUS(self):
 
4002
            return NText72(size, collation=collation)
 
4003
        elif IS_TDS71_PLUS(self):
 
4004
            return NText71(size, collation=collation)
 
4005
        else:
 
4006
            return NText70(size)
 
4007
 
 
4008
    def VarBinary(self, size):
 
4009
        if IS_TDS72_PLUS(self):
 
4010
            return VarBinary72(size)
 
4011
        else:
 
4012
            return VarBinary(size)
 
4013
 
 
4014
    def Image(self, size=0):
 
4015
        if IS_TDS72_PLUS(self):
 
4016
            return Image72(size)
 
4017
        else:
 
4018
            return Image70(size)
 
4019
 
 
4020
    Bit = Bit.instance
 
4021
    BitN = BitN.instance
 
4022
    TinyInt = TinyInt.instance
 
4023
    SmallInt = SmallInt.instance
 
4024
    Int = Int.instance
 
4025
    BigInt = BigInt.instance
 
4026
    IntN = IntN
 
4027
    Real = Real.instance
 
4028
    Float = Float.instance
 
4029
    FloatN = FloatN
 
4030
    SmallDateTime = SmallDateTime.instance
 
4031
    DateTime = DateTime.instance
 
4032
    DateTimeN = DateTimeN
 
4033
    Date = MsDate.instance
 
4034
    Time = MsTime
 
4035
    DateTime2 = DateTime2
 
4036
    DateTimeOffset = DateTimeOffset
 
4037
    Decimal = MsDecimal
 
4038
    SmallMoney = Money4.instance
 
4039
    Money = Money8.instance
 
4040
    MoneyN = MoneyN
 
4041
    UniqueIdentifier = MsUnique.instance
 
4042
    SqlVariant = Variant
 
4043
    Xml = Xml
 
4044
 
 
4045
    def long_binary_type(self):
 
4046
        if IS_TDS72_PLUS(self):
 
4047
            return VarBinaryMax()
 
4048
        else:
 
4049
            return Image70()
 
4050
 
 
4051
    def long_varchar_type(self, collation=raw_collation):
 
4052
        if IS_TDS72_PLUS(self):
 
4053
            return VarCharMax(collation)
 
4054
        elif IS_TDS71_PLUS(self):
 
4055
            return Text71(-1, '', collation)
 
4056
        else:
 
4057
            return Text70(codec=self.server_codec)
 
4058
 
 
4059
    def long_string_type(self, collation=raw_collation):
 
4060
        if IS_TDS72_PLUS(self):
 
4061
            return NVarCharMax(0, collation)
 
4062
        elif IS_TDS71_PLUS(self):
 
4063
            return NText71(-1, '', collation)
 
4064
        else:
 
4065
            return NText70()
 
4066
 
 
4067
    def type_by_declaration(self, declaration, nullable):
 
4068
        declaration = declaration.strip().upper()
 
4069
        for type_class in self._type_map.values():
 
4070
            type_inst = type_class.from_declaration(declaration, nullable, self)
 
4071
            if type_inst:
 
4072
                return type_inst 
 
4073
        raise ValueError('Unable to parse type declaration', declaration)
 
4074
 
 
4075
 
 
4076
class Column(object):
 
4077
    fNullable = 1
 
4078
    fCaseSen = 2
 
4079
    fReadWrite = 8
 
4080
    fIdentity = 0x10
 
4081
    fComputed = 0x20
 
4082
 
 
4083
    def __init__(self, name='', type=None, flags=0, value=None):
 
4084
        self.char_codec = None
 
4085
        self.column_name = name
 
4086
        self.column_usertype = 0
 
4087
        self.flags = flags
 
4088
        self.type = type
 
4089
        self.value = value
 
4090
 
 
4091
    def __repr__(self):
 
4092
        return '<Column(name={0}, value={1}, type={2})>'.format(repr(self.column_name), repr(self.value), repr(self.type))
 
4093
 
 
4094
 
 
4095
class _Results(object):
 
4096
    def __init__(self):
 
4097
        self.columns = []
 
4098
        self.row_count = 0
 
4099
 
 
4100
 
 
4101
def _parse_instances(msg):
 
4102
    name = None
 
4103
    if len(msg) > 3 and _ord(msg[0]) == 5:
 
4104
        tokens = msg[3:].decode('ascii').split(';')
 
4105
        results = {}
 
4106
        instdict = {}
 
4107
        got_name = False
 
4108
        for token in tokens:
 
4109
            if got_name:
 
4110
                instdict[name] = token
 
4111
                got_name = False
 
4112
            else:
 
4113
                name = token
 
4114
                if not name:
 
4115
                    if not instdict:
 
4116
                        break
 
4117
                    results[instdict['InstanceName'].upper()] = instdict
 
4118
                    instdict = {}
 
4119
                    continue
 
4120
                got_name = True
 
4121
        return results
 
4122
 
 
4123
 
 
4124
#
 
4125
# Get port of all instances
 
4126
# @return default port number or 0 if error
 
4127
# @remark experimental, cf. MC-SQLR.pdf.
 
4128
#
 
4129
def tds7_get_instances(ip_addr, timeout=5):
 
4130
    s = socket.socket(type=socket.SOCK_DGRAM)
 
4131
    s.settimeout(timeout)
 
4132
    try:
 
4133
        # send the request
 
4134
        s.sendto(b'\x03', (ip_addr, 1434))
 
4135
        msg = s.recv(16 * 1024 - 1)
 
4136
        # got data, read and parse
 
4137
        return _parse_instances(msg)
 
4138
    finally:
 
4139
        s.close()
 
4140
 
 
4141
 
 
4142
def _applytz(dt, tz):
 
4143
    if not tz:
 
4144
        return dt
 
4145
    dt = dt.replace(tzinfo=tz)
 
4146
    return dt