3
from contextlib import contextmanager
7
from datetime import datetime, date, time, timedelta
8
from decimal import Decimal, localcontext
14
from six.moves import reduce
15
from six.moves import xrange
19
encryption_supported = False
21
encryption_supported = True
22
from .collate import ucs2_codec, Collation, lcid2charset, raw_collation
24
logger = logging.getLogger()
26
ENCRYPTION_ENABLED = False
29
# tds protocol versions
32
TDS71rev1 = 0x71000001
39
IS_TDS7_PLUS = lambda x: x.tds_version >= TDS70
40
IS_TDS71_PLUS = lambda x: x.tds_version >= TDS71
41
IS_TDS72_PLUS = lambda x: x.tds_version >= TDS72
42
IS_TDS73_PLUS = lambda x: x.tds_version >= TDS73A
51
TDS7_TRANS = 14 # transaction management
57
# mssql login options flags
59
TDS_BYTE_ORDER_X86 = 0
62
TDS_FLOAT_IEEE_754 = 0
66
TDS_BYTE_ORDER_68000 = 0x01
67
TDS_CHARSET_EBDDIC = 0x02
69
TDS_FLOAT_ND5000 = 0x08
70
TDS_DUMPLOAD_OFF = 0x10 # prevent BCP
71
TDS_USE_DB_NOTIFY = 0x20
72
TDS_INIT_DB_FATAL = 0x40
73
TDS_SET_LANG_ON = 0x80
75
#enum option_flag2_values {
76
TDS_INIT_LANG_WARN = 0
77
TDS_INTEGRATED_SECURTY_OFF = 0
79
TDS_USER_NORMAL = 0 # SQL Server login
80
TDS_INIT_LANG_REQUIRED = 0x01
82
TDS_TRANSACTION_BOUNDARY71 = 0x04 # removed in TDS 7.2
83
TDS_CACHE_CONNECT71 = 0x08 # removed in TDS 7.2
84
TDS_USER_SERVER = 0x10 # reserved
85
TDS_USER_REMUSER = 0x20 # DQ login
86
TDS_USER_SQLREPL = 0x40 # replication login
87
TDS_INTEGRATED_SECURITY_ON = 0x80
89
#enum option_flag3_values TDS 7.3+
90
TDS_RESTRICTED_COLLATION = 0
91
TDS_CHANGE_PASSWORD = 0x01
92
TDS_SEND_YUKON_BINARY_XML = 0x02
93
TDS_REQUEST_USER_INSTANCE = 0x04
94
TDS_UNKNOWN_COLLATION_HANDLING = 0x08
95
TDS_ANY_COLLATION = 0x10
97
TDS5_PARAMFMT2_TOKEN = 32 # 0x20
98
TDS_LANGUAGE_TOKEN = 33 # 0x21 TDS 5.0 only
99
TDS_ORDERBY2_TOKEN = 34 # 0x22
100
TDS_ROWFMT2_TOKEN = 97 # 0x61 TDS 5.0 only
101
TDS_LOGOUT_TOKEN = 113 # 0x71 TDS 5.0 only?
102
TDS_RETURNSTATUS_TOKEN = 121 # 0x79
103
TDS_PROCID_TOKEN = 124 # 0x7C TDS 4.2 only
104
TDS7_RESULT_TOKEN = 129 # 0x81 TDS 7.0 only
105
TDS7_COMPUTE_RESULT_TOKEN = 136 # 0x88 TDS 7.0 only
106
TDS_COLNAME_TOKEN = 160 # 0xA0 TDS 4.2 only
107
TDS_COLFMT_TOKEN = 161 # 0xA1 TDS 4.2 only
108
TDS_DYNAMIC2_TOKEN = 163 # 0xA3
109
TDS_TABNAME_TOKEN = 164 # 0xA4
110
TDS_COLINFO_TOKEN = 165 # 0xA5
111
TDS_OPTIONCMD_TOKEN = 166 # 0xA6
112
TDS_COMPUTE_NAMES_TOKEN = 167 # 0xA7
113
TDS_COMPUTE_RESULT_TOKEN = 168 # 0xA8
114
TDS_ORDERBY_TOKEN = 169 # 0xA9
115
TDS_ERROR_TOKEN = 170 # 0xAA
116
TDS_INFO_TOKEN = 171 # 0xAB
117
TDS_PARAM_TOKEN = 172 # 0xAC
118
TDS_LOGINACK_TOKEN = 173 # 0xAD
119
TDS_CONTROL_TOKEN = 174 # 0xAE
120
TDS_ROW_TOKEN = 209 # 0xD1
121
TDS_NBC_ROW_TOKEN = 210 # 0xD2 as of TDS 7.3.B
122
TDS_CMP_ROW_TOKEN = 211 # 0xD3
123
TDS5_PARAMS_TOKEN = 215 # 0xD7 TDS 5.0 only
124
TDS_CAPABILITY_TOKEN = 226 # 0xE2
125
TDS_ENVCHANGE_TOKEN = 227 # 0xE3
126
TDS_EED_TOKEN = 229 # 0xE5
127
TDS_DBRPC_TOKEN = 230 # 0xE6
128
TDS5_DYNAMIC_TOKEN = 231 # 0xE7 TDS 5.0 only
129
TDS5_PARAMFMT_TOKEN = 236 # 0xEC TDS 5.0 only
130
TDS_AUTH_TOKEN = 237 # 0xED TDS 7.0 only
131
TDS_RESULT_TOKEN = 238 # 0xEE
132
TDS_DONE_TOKEN = 253 # 0xFD
133
TDS_DONEPROC_TOKEN = 254 # 0xFE
134
TDS_DONEINPROC_TOKEN = 255 # 0xFF
136
# CURSOR support: TDS 5.0 only
137
TDS_CURCLOSE_TOKEN = 128 # 0x80 TDS 5.0 only
138
TDS_CURDELETE_TOKEN = 129 # 0x81 TDS 5.0 only
139
TDS_CURFETCH_TOKEN = 130 # 0x82 TDS 5.0 only
140
TDS_CURINFO_TOKEN = 131 # 0x83 TDS 5.0 only
141
TDS_CUROPEN_TOKEN = 132 # 0x84 TDS 5.0 only
142
TDS_CURDECLARE_TOKEN = 134 # 0x86 TDS 5.0 only
144
# environment type field
150
TDS_ENV_SQLCOLLATION = 7
151
TDS_ENV_BEGINTRANS = 8
152
TDS_ENV_COMMITTRANS = 9
153
TDS_ENV_ROLLBACKTRANS = 10
154
TDS_ENV_ENLIST_DTC_TRANS = 11
155
TDS_ENV_DEFECT_TRANS = 12
156
TDS_ENV_DB_MIRRORING_PARTNER = 13
157
TDS_ENV_PROMOTE_TRANS = 15
158
TDS_ENV_TRANS_MANAGER_ADDR = 16
159
TDS_ENV_TRANS_ENDED = 17
160
TDS_ENV_RESET_COMPLETION_ACK = 18
161
TDS_ENV_INSTANCE_INFO = 19
164
# Microsoft internal stored procedure id's
166
TDS_SP_CURSOROPEN = 2
167
TDS_SP_CURSORPREPARE = 3
168
TDS_SP_CURSOREXECUTE = 4
169
TDS_SP_CURSORPREPEXEC = 5
170
TDS_SP_CURSORUNPREPARE = 6
171
TDS_SP_CURSORFETCH = 7
172
TDS_SP_CURSOROPTION = 8
173
TDS_SP_CURSORCLOSE = 9
174
TDS_SP_EXECUTESQL = 10
178
TDS_SP_PREPEXECRPC = 14
179
TDS_SP_UNPREPARE = 15
181
# Flags returned in TDS_DONE token
183
TDS_DONE_MORE_RESULTS = 0x01 # more results follow
184
TDS_DONE_ERROR = 0x02 # error occurred
185
TDS_DONE_INXACT = 0x04 # transaction in progress
186
TDS_DONE_PROC = 0x08 # results are from a stored procedure
187
TDS_DONE_COUNT = 0x10 # count field in packet is valid
188
TDS_DONE_CANCELLED = 0x20 # acknowledging an attention command (usually a cancel)
189
TDS_DONE_EVENT = 0x40 # part of an event notification.
190
TDS_DONE_SRVERROR = 0x100 # SQL server server error
194
IMAGETYPE = SYBIMAGE = 34 # 0x22
195
TEXTTYPE = SYBTEXT = 35 # 0x23
196
SYBVARBINARY = 37 # 0x25
197
INTNTYPE = SYBINTN = 38 # 0x26
198
SYBVARCHAR = 39 # 0x27
199
BINARYTYPE = SYBBINARY = 45 # 0x2D
201
INT1TYPE = SYBINT1 = 48 # 0x30
202
BITTYPE = SYBBIT = 50 # 0x32
203
INT2TYPE = SYBINT2 = 52 # 0x34
204
INT4TYPE = SYBINT4 = 56 # 0x38
205
DATETIM4TYPE = SYBDATETIME4 = 58 # 0x3A
206
FLT4TYPE = SYBREAL = 59 # 0x3B
207
MONEYTYPE = SYBMONEY = 60 # 0x3C
208
DATETIMETYPE = SYBDATETIME = 61 # 0x3D
209
FLT8TYPE = SYBFLT8 = 62 # 0x3E
210
NTEXTTYPE = SYBNTEXT = 99 # 0x63
211
SYBNVARCHAR = 103 # 0x67
212
BITNTYPE = SYBBITN = 104 # 0x68
213
NUMERICNTYPE = SYBNUMERIC = 108 # 0x6C
214
DECIMALNTYPE = SYBDECIMAL = 106 # 0x6A
215
FLTNTYPE = SYBFLTN = 109 # 0x6D
216
MONEYNTYPE = SYBMONEYN = 110 # 0x6E
217
DATETIMNTYPE = SYBDATETIMN = 111 # 0x6F
218
MONEY4TYPE = SYBMONEY4 = 122 # 0x7A
220
INT8TYPE = SYBINT8 = 127 # 0x7F
221
BIGCHARTYPE = XSYBCHAR = 175 # 0xAF
222
BIGVARCHRTYPE = XSYBVARCHAR = 167 # 0xA7
223
NVARCHARTYPE = XSYBNVARCHAR = 231 # 0xE7
224
NCHARTYPE = XSYBNCHAR = 239 # 0xEF
225
BIGVARBINTYPE = XSYBVARBINARY = 165 # 0xA5
226
BIGBINARYTYPE = XSYBBINARY = 173 # 0xAD
227
GUIDTYPE = SYBUNIQUE = 36 # 0x24
228
SSVARIANTTYPE = SYBVARIANT = 98 # 0x62
229
UDTTYPE = SYBMSUDT = 240 # 0xF0
230
XMLTYPE = SYBMSXML = 241 # 0xF1
231
DATENTYPE = SYBMSDATE = 40 # 0x28
232
TIMENTYPE = SYBMSTIME = 41 # 0x29
233
DATETIME2NTYPE = SYBMSDATETIME2 = 42 # 0x2a
234
DATETIMEOFFSETNTYPE = SYBMSDATETIMEOFFSET = 43 # 0x2b
239
SYBLONGBINARY = 225 # 0xE1
245
SYBBOUNDARY = 104 # 0x68
247
SYBDATEN = 123 # 0x7B
248
SYB5INT8 = 191 # 0xBF
249
SYBINTERVAL = 46 # 0x2E
250
SYBLONGCHAR = 175 # 0xAF
251
SYBSENSITIVITY = 103 # 0x67
252
SYBSINT1 = 176 # 0xB0
254
SYBTIMEN = 147 # 0x93
256
SYBUNITEXT = 174 # 0xAE
259
TDS_UT_TIMESTAMP = 80
271
# mssql2k compute operator
277
SYBAOPCHECKSUM_AGG = 0x72
288
state_names = ['IDLE', 'QUERYING', 'PENDING', 'READING', 'DEAD']
290
TDS_ENCRYPTION_OFF = 0
291
TDS_ENCRYPTION_REQUEST = 1
292
TDS_ENCRYPTION_REQUIRE = 2
294
USE_CORK = hasattr(socket, 'TCP_CORK')
300
_header = struct.Struct('>BBHHBx')
301
_byte = struct.Struct('B')
302
_smallint_le = struct.Struct('<h')
303
_smallint_be = struct.Struct('>h')
304
_usmallint_le = struct.Struct('<H')
305
_usmallint_be = struct.Struct('>H')
306
_int_le = struct.Struct('<l')
307
_int_be = struct.Struct('>l')
308
_uint_le = struct.Struct('<L')
309
_uint_be = struct.Struct('>L')
310
_int8_le = struct.Struct('<q')
311
_int8_be = struct.Struct('>q')
312
_uint8_le = struct.Struct('<Q')
313
_uint8_be = struct.Struct('>Q')
314
_flt8_struct = struct.Struct('d')
315
_flt4_struct = struct.Struct('f')
319
PLP_NULL = 0xffffffffffffffff
320
PLP_UNKNOWN = 0xfffffffffffffffe
323
class PlpReader(object):
324
""" Partially length prefixed reader
326
Spec: http://msdn.microsoft.com/en-us/library/dd340469.aspx
328
def __init__(self, r):
330
:param r: An instance of :class:`_TdsReader`
338
:return: True if stored value is NULL
340
return self._size == PLP_NULL
342
def is_unknown_len(self):
344
:return: True if total size is unknown upfront
346
return self._size == PLP_UNKNOWN
350
:return: Total size in bytes if is_uknown_len and is_null are both False
355
""" Generates chunks from stream, each chunk is an instace of bytes.
361
chunk_len = self._rdr.get_uint()
363
if not self.is_unknown_len() and total != self._size:
364
msg = "PLP actual length (%d) doesn't match reported length (%d)" % (total, self._size)
365
self._rdr.session.bad_stream(msg)
372
buf = self._rdr.read(left)
377
def iterdecode(iterable, codec):
    """Incrementally decode a stream of byte chunks.

    Generator: yields one decoded text chunk per input chunk, then a
    final chunk produced by flushing the incremental decoder (so
    multi-byte sequences split across chunk boundaries decode correctly).

    :param iterable: Iterable producing ``bytes`` chunks
    :param codec: Codec info object providing ``incrementaldecoder()``
    """
    dec = codec.incrementaldecoder()
    for raw in iterable:
        yield dec.decode(raw)
    # final=True flushes any buffered partial sequence
    yield dec.decode(b'', True)
389
class SimpleLoadBalancer(object):
390
def __init__(self, hosts):
394
for host in self._hosts:
398
def force_unicode(s):
399
if isinstance(s, bytes):
401
return s.decode('utf8')
402
except UnicodeDecodeError as e:
403
raise DatabaseError(e)
408
def tds_quote_id(id):
    """Quote an identifier for use in SQL statements.

    :param id: id to quote
    :returns: Quoted identifier, wrapped in brackets with any closing
        bracket doubled per T-SQL quoting rules
    """
    escaped = id.replace(']', ']]')
    return '[' + escaped + ']'
417
def tds7_crypt_pass(password):
    """ Mangle password according to tds rules

    :param password: Password str
    :returns: Byte-string with encoded password
    """
    # TDS7 password obfuscation: encode as UCS-2, then for every byte
    # swap its nibbles and XOR with 0xA5.
    encoded = bytearray(ucs2_codec.encode(password)[0])
    for i, ch in enumerate(encoded):
        encoded[i] = ((ch << 4) & 0xff | (ch >> 4)) ^ 0xA5
    # The `return` was missing in this (corrupted) source; the docstring
    # documents a byte-string result, so return the mangled buffer.
    return encoded
429
def total_seconds(td):
    """ Total number of seconds in timedelta object

    Python 2.6 doesn't have total_seconds method, this function
    provides a replacement.  NOTE: like the original helper it ignores
    the microseconds component.
    """
    seconds_per_day = 24 * 60 * 60
    return td.seconds + td.days * seconds_per_day
438
# store a tuple of programming error codes
441
207, # invalid column name
442
208, # invalid object name
443
2812, # unknown procedure
444
4104 # multi-part identifier could not be bound
447
# store a tuple of integrity error codes
451
2601, # violate unique index
452
2627, # violate UNIQUE KEY constraint
456
if sys.version_info[0] >= 3:
457
exc_base_class = Exception
463
exc_base_class = StandardError
469
def _decode_num(buf):
470
""" Decodes little-endian integer from buffer
472
Buffer can be of any size
474
return reduce(lambda acc, val: acc * 256 + _ord(val), reversed(buf), 0)
477
# exception hierarchy
478
class Warning(exc_base_class):
482
class Error(exc_base_class):
486
TimeoutError = socket.timeout
489
class InterfaceError(Error):
493
class DatabaseError(Error):
497
return 'SQL Server message %d, severity %d, state %d, ' \
498
'procedure %s, line %d:\n%s' % (self.number,
499
self.severity, self.state, self.procname,
500
self.line, self.text)
502
return 'SQL Server message %d, severity %d, state %d, ' \
503
'line %d:\n%s' % (self.number, self.severity,
504
self.state, self.line, self.text)
507
class ClosedConnectionError(InterfaceError):
509
super(ClosedConnectionError, self).__init__('Server closed connection')
512
class DataError(Error):
516
class OperationalError(DatabaseError):
520
class LoginError(OperationalError):
524
class IntegrityError(DatabaseError):
528
class InternalError(DatabaseError):
532
class ProgrammingError(DatabaseError):
536
class NotSupportedError(DatabaseError):
540
#############################
541
## DB-API type definitions ##
542
#############################
543
class DBAPITypeObject:
544
def __init__(self, *values):
545
self.values = set(values)
547
def __eq__(self, other):
548
return other in self.values
550
def __cmp__(self, other):
551
if other in self.values:
553
if other < self.values:
558
# standard dbapi type objects
559
STRING = DBAPITypeObject(SYBVARCHAR, SYBCHAR, SYBTEXT,
560
XSYBNVARCHAR, XSYBNCHAR, SYBNTEXT,
561
XSYBVARCHAR, XSYBCHAR, SYBMSXML)
562
BINARY = DBAPITypeObject(SYBIMAGE, SYBBINARY, SYBVARBINARY, XSYBVARBINARY, XSYBBINARY)
563
NUMBER = DBAPITypeObject(SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN,
564
SYBREAL, SYBFLT8, SYBFLTN)
565
DATETIME = DBAPITypeObject(SYBDATETIME, SYBDATETIME4, SYBDATETIMN)
566
DECIMAL = DBAPITypeObject(SYBMONEY, SYBMONEY4, SYBMONEYN, SYBNUMERIC,
568
ROWID = DBAPITypeObject()
570
# non-standard, but useful type objects
571
INTEGER = DBAPITypeObject(SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN)
572
REAL = DBAPITypeObject(SYBREAL, SYBFLT8, SYBFLTN)
573
XML = DBAPITypeObject(SYBMSXML)
576
# stored procedure output parameter
577
class output(object):
581
This is either the sql type declaration or python type instance
589
This is the value of the parameter.
593
def __init__(self, value=None, param_type=None):
594
""" Creates procedure output parameter.
596
:param param_type: either sql type declaration or python type
597
:param value: value to pass into procedure
599
if param_type is None:
600
if value is None or value is default:
601
raise ValueError('Output type cannot be autodetected')
602
elif isinstance(param_type, type) and value is not None:
603
if value is not default and not isinstance(value, param_type):
604
raise ValueError('value should match param_type', value, param_type)
605
self._type = param_type
611
return 'Binary({0})'.format(super(Binary, self).__repr__())
614
class _Default(object):
620
class InternalProc(object):
621
def __init__(self, proc_id, name):
622
self.proc_id = proc_id
625
def __unicode__(self):
628
SP_EXECUTESQL = InternalProc(TDS_SP_EXECUTESQL, 'sp_executesql')
635
def skipall(stm, size):
636
""" Skips exactly size bytes in stm
638
If EOF is reached before size bytes are skipped
639
will raise :class:`ClosedConnectionError`
641
:param stm: Stream to skip bytes in, should have read method
642
this read method can return less than requested
644
:param size: Number of bytes to skip.
650
raise ClosedConnectionError()
651
left = size - len(res)
655
raise ClosedConnectionError()
659
def read_chunks(stm, size):
660
""" Reads exactly size bytes from stm and produces chunks
662
May call stm.read multiple times until required
663
number of bytes is read.
664
If EOF is reached before size bytes are read
665
will raise :class:`ClosedConnectionError`
667
:param stm: Stream to read bytes from, should have read method,
668
this read method can return less than requested
670
:param size: Number of bytes to read.
678
raise ClosedConnectionError()
680
left = size - len(res)
684
raise ClosedConnectionError()
689
def readall(stm, size):
    """Read exactly *size* bytes from *stm*.

    May call ``stm.read`` multiple times until the required number of
    bytes is read.  If EOF is reached before *size* bytes are read,
    raises :class:`ClosedConnectionError`.

    :param stm: Stream to read bytes from; its ``read`` method may
        return less than requested
    :param size: Number of bytes to read
    :returns: Bytes buffer of exactly the given size
    """
    return b"".join(read_chunks(stm, size))
706
def readall_fast(stm, size):
708
Slightly faster version of readall, it reads no more than two chunks.
709
Meaning that it can only be used to read small data that doesn't span
710
more that two packets.
712
:param stm: Stream to read from, should have read method.
713
:param size: Number of bytes to read.
716
buf, offset = stm.read_fast(size)
717
if len(buf) - offset < size:
720
buf += stm.read(size - len(buf))
725
class _TdsReader(object):
726
""" TDS stream reader
728
Provides stream-like interface for TDS packeted stream.
729
Also provides convinience methods to decode primitive data like
730
different kinds of integers etc.
732
def __init__(self, session):
734
self._pos = 0 # position in the buffer
735
self._have = 0 # number of bytes read from packet
736
self._size = 0 # size of current packet
737
self._session = session
738
self._transport = session._transport
744
""" Link to :class:`_TdsSession` object
749
def packet_type(self):
750
""" Type of current packet
752
Possible values are TDS_QUERY, TDS_LOGIN, etc.
756
def read_fast(self, size):
757
""" Faster version of read
759
Instead of returning sliced buffer it returns reference to internal
760
buffer and the offset to this buffer.
762
:param size: Number of bytes to read
763
:returns: Tuple of bytes buffer, and offset in this buffer
765
if self._pos >= len(self._buf):
766
if self._have >= self._size:
769
self._buf = self._transport.read(self._size - self._have)
771
self._have += len(self._buf)
774
return self._buf, offset
776
def unpack(self, struct):
777
""" Unpacks given structure from stream
779
:param struct: A struct.Struct instance
780
:returns: Result of unpacking
782
buf, offset = readall_fast(self, struct.size)
783
return struct.unpack_from(buf, offset)
786
""" Reads one byte from stream """
787
return self.unpack(_byte)[0]
789
def get_smallint(self):
790
""" Reads 16bit signed integer from the stream """
791
return self.unpack(_smallint_le)[0]
793
def get_usmallint(self):
794
""" Reads 16bit unsigned integer from the stream """
795
return self.unpack(_usmallint_le)[0]
798
""" Reads 32bit signed integer from the stream """
799
return self.unpack(_int_le)[0]
802
""" Reads 32bit unsigned integer from the stream """
803
return self.unpack(_uint_le)[0]
805
def get_uint_be(self):
806
""" Reads 32bit unsigned big-endian integer from the stream """
807
return self.unpack(_uint_be)[0]
810
""" Reads 64bit unsigned integer from the stream """
811
return self.unpack(_uint8_le)[0]
814
""" Reads 64bit signed integer from the stream """
815
return self.unpack(_int8_le)[0]
817
def read_ucs2(self, num_chars):
818
""" Reads num_chars UCS2 string from the stream """
819
buf = readall(self, num_chars * 2)
820
return ucs2_codec.decode(buf)[0]
822
def read_str(self, size, codec):
823
""" Reads byte string from the stream and decodes it
825
:param size: Size of string in bytes
826
:param codec: Instance of codec to decode string
827
:returns: Unicode string
829
return codec.decode(readall(self, size))[0]
831
def get_collation(self):
832
""" Reads :class:`Collation` object from stream """
833
buf = readall(self, Collation.wire_size)
834
return Collation.unpack(buf)
836
def unget_byte(self):
837
""" Returns one last read byte to stream
839
Can only be called once per read byte.
841
# this is a one trick pony...don't call it twice
846
""" Returns next byte from stream without consuming it
848
res = self.get_byte()
852
def read(self, size):
853
""" Reads size bytes from buffer
855
May return fewer bytes than requested
856
:param size: Number of bytes to read
857
:returns: Bytes buffer, possibly shorter than requested,
858
returns empty buffer in case of EOF
860
buf, offset = self.read_fast(size)
861
return buf[offset:offset + size]
863
def _read_packet(self):
864
""" Reads next TDS packet from the underlying transport
866
If timeout is happened during reading of packet's header will
867
cancel current request.
868
Can only be called when transport's read pointer is at the begining
872
header = readall(self._transport, _header.size)
874
self._session._put_cancel()
877
self._type, self._status, self._size, self._session._spid, _ = _header.unpack(header)
878
self._have = _header.size
879
assert self._size > self._have, 'Empty packet doesn make any sense'
880
self._buf = self._transport.read(self._size - self._have)
881
self._have += len(self._buf)
883
def read_whole_packet(self):
884
""" Reads single packet and returns bytes payload of the packet
886
Can only be called when transport's read pointer is at the beginning
890
return readall(self, self._size - _header.size)
893
class _TdsWriter(object):
894
""" TDS stream writer
896
Handles splitting of incoming data into TDS packets according to TDS protocol.
897
Provides convinience methods for writing primitive data types.
899
def __init__(self, session, bufsize):
900
self._session = session
902
self._transport = session
904
self._buf = bytearray(bufsize)
909
""" Back reference to parent :class:`_TdsSession` object """
914
""" Size of the buffer """
915
return len(self._buf)
918
def bufsize(self, bufsize):
919
if len(self._buf) == bufsize:
922
if bufsize > len(self._buf):
923
self._buf.extend(b'\0' * (bufsize - len(self._buf)))
925
self._buf = self._buf[0:bufsize]
927
def begin_packet(self, packet_type):
928
""" Starts new packet stream
930
:param packet_type: Type of TDS stream, e.g. TDS_PRELOGIN, TDS_QUERY etc.
932
self._type = packet_type
935
def pack(self, struct, *args):
936
""" Packs and writes structure into stream """
937
self.write(struct.pack(*args))
939
def put_byte(self, value):
940
""" Writes single byte into stream """
941
self.pack(_byte, value)
943
def put_smallint(self, value):
944
""" Writes 16-bit signed integer into the stream """
945
self.pack(_smallint_le, value)
947
def put_usmallint(self, value):
948
""" Writes 16-bit unsigned integer into the stream """
949
self.pack(_usmallint_le, value)
951
def put_smallint_be(self, value):
952
""" Writes 16-bit signed big-endian integer into the stream """
953
self.pack(_smallint_be, value)
955
def put_usmallint_be(self, value):
956
""" Writes 16-bit unsigned big-endian integer into the stream """
957
self.pack(_usmallint_be, value)
959
def put_int(self, value):
960
""" Writes 32-bit signed integer into the stream """
961
self.pack(_int_le, value)
963
def put_uint(self, value):
964
""" Writes 32-bit unsigned integer into the stream """
965
self.pack(_uint_le, value)
967
def put_int_be(self, value):
968
""" Writes 32-bit signed big-endian integer into the stream """
969
self.pack(_int_be, value)
971
def put_uint_be(self, value):
972
""" Writes 32-bit unsigned big-endian integer into the stream """
973
self.pack(_uint_be, value)
975
def put_int8(self, value):
976
""" Writes 64-bit signed integer into the stream """
977
self.pack(_int8_le, value)
979
def put_uint8(self, value):
980
""" Writes 64-bit unsigned integer into the stream """
981
self.pack(_uint8_le, value)
983
def put_collation(self, collation):
984
""" Writes :class:`Collation` structure into the stream """
985
self.write(collation.pack())
987
def write(self, data):
988
""" Writes given bytes buffer into the stream
990
Function returns only when entire buffer is written
993
while data_off < len(data):
994
left = len(self._buf) - self._pos
996
self._write_packet(final=False)
998
to_write = min(left, len(data) - data_off)
999
self._buf[self._pos:self._pos + to_write] = data[data_off:data_off + to_write]
1000
self._pos += to_write
1001
data_off += to_write
1003
def write_ucs2(self, s):
1004
""" Write string encoding it in UCS2 into stream """
1005
self.write_string(s, ucs2_codec)
1007
def write_string(self, s, codec):
1008
""" Write string encoding it with codec into stream """
1009
for i in xrange(0, len(s), self.bufsize):
1010
chunk = s[i:i + self.bufsize]
1011
buf, consumed = codec.encode(chunk)
1012
assert consumed == len(chunk)
1016
""" Closes current packet stream """
1017
return self._write_packet(final=True)
1019
def _write_packet(self, final):
1020
""" Writes single TDS packet into underlying transport.
1022
Data for the packet is taken from internal buffer.
1024
:param final: True means this is the final packet in substream.
1026
status = 1 if final else 0
1027
_header.pack_into(self._buf, 0, self._type, status, self._pos, 0, self._packet_no)
1028
self._packet_no = (self._packet_no + 1) % 256
1029
self._transport.send(self._buf[:self._pos], final)
1033
class MemoryChunkedHandler(object):
1034
def begin(self, column, size):
1038
def new_chunk(self, val):
1039
#logger.debug('MemoryChunkedHandler.new_chunk(sz=%d)', len(val))
1040
self._chunks.append(val)
1043
return b''.join(self._chunks)
1046
class MemoryStrChunkedHandler(object):
1047
def begin(self, column, size):
1051
def new_chunk(self, val):
1052
#logger.debug('MemoryChunkedHandler.new_chunk(sz=%d)', len(val))
1053
self._chunks.append(val)
1056
return ''.join(self._chunks)
1059
class BaseType(object):
1060
""" Base type for TDS data types.
1062
All TDS types should derive from it.
1063
In addition actual types should provide the following:
1065
- type - class variable storing type identifier
1067
def get_typeid(self):
1068
""" Returns type identifier of type. """
1071
def get_declaration(self):
1072
""" Returns SQL declaration for this type.
1074
Examples are: NVARCHAR(10), TEXT, TINYINT
1075
Should be implemented in actual types.
1077
raise NotImplementedError
1080
def from_declaration(cls, declaration, nullable, connection):
1081
""" Class method that parses declaration and returns a type instance.
1083
:param declaration: type declaration string
1084
:param nullable: true if type have to be nullable, false otherwise
1085
:param connection: instance of :class:`_TdsSocket`
1086
:return: If declaration is parsed, returns type instance,
1087
otherwise returns None.
1089
Should be implemented in actual types.
1091
raise NotImplementedError
1094
def from_stream(cls, r):
1095
""" Class method that reads and returns a type instance.
1097
:param r: An instance of :class:`_TdsReader` to read type from.
1099
Should be implemented in actual types.
1101
raise NotImplementedError
1103
def write_info(self, w):
1104
""" Writes type info into w stream.
1106
:param w: An instance of :class:`_TdsWriter` to write into.
1108
Should be symmetrical to from_stream method.
1109
Should be implemented in actual types.
1111
raise NotImplementedError
1113
def write(self, w, value):
1114
""" Writes type's value into stream
1116
:param w: An instance of :class:`_TdsWriter` to write into.
1117
:param value: A value to be stored, should be compatible with the type
1119
Should be implemented in actual types.
1121
raise NotImplementedError
1124
""" Reads value from the stream.
1126
:param r: An instance of :class:`_TdsReader` to read value from.
1127
:return: A read value.
1129
Should be implemented in actual types.
1131
raise NotImplementedError
1134
class BasePrimitiveType(BaseType):
1135
""" Base type for primitive TDS data types.
1137
Primitive type is a fixed size type with no type arguments.
1138
All primitive TDS types should derive from it.
1139
In addition actual types should provide the following:
1141
- type - class variable storing type identifier
1142
- declaration - class variable storing name of sql type
1143
- isntance - class variable storing instance of class
1146
def get_declaration(self):
1147
return self.declaration
1150
def from_declaration(cls, declaration, nullable, connection):
1151
if not nullable and declaration == cls.declaration:
1155
def from_stream(cls, r):
1158
def write_info(self, w):
1162
class BaseTypeN(BaseType):
1163
""" Base type for nullable TDS data types.
1165
All nullable TDS types should derive from it.
1166
In addition actual types should provide the following:
1168
- type - class variable storing type identifier
1169
- subtypes - class variable storing dict {subtype_size: subtype_instance}
1172
def __init__(self, size):
1173
assert size in self.subtypes
1175
self._current_subtype = self.subtypes[size]
1177
def get_typeid(self):
1178
return self._current_subtype.get_typeid()
1180
def get_declaration(self):
1181
return self._current_subtype.get_declaration()
1184
def from_declaration(cls, declaration, nullable, connection):
1186
for size, subtype in cls.subtypes.items():
1187
inst = subtype.from_declaration(declaration, False, connection)
1192
def from_stream(cls, r):
1194
if size not in cls.subtypes:
1195
raise InterfaceError('Invalid %s size' % cls.type, size)
1198
def write_info(self, w):
1199
w.put_byte(self._size)
1205
if size not in self.subtypes:
1206
raise r.session.bad_stream('Invalid %s size' % self.type, size)
1207
return self.subtypes[size].read(r)
1209
def write(self, w, val):
1213
w.put_byte(self._size)
1214
self._current_subtype.write(w, val)
1216
class Bit(BasePrimitiveType):
1220
def write(self, w, value):
1221
w.put_byte(1 if value else 0)
1224
return bool(r.get_byte())
1226
Bit.instance = Bit()
1229
class BitN(BaseTypeN):
1231
subtypes = {1 : Bit.instance}
1233
BitN.instance = BitN(1)
1236
class TinyInt(BasePrimitiveType):
1238
declaration = 'TINYINT'
1240
def write(self, w, val):
1246
TinyInt.instance = TinyInt()
1249
class SmallInt(BasePrimitiveType):
1251
declaration = 'SMALLINT'
1253
def write(self, w, val):
1257
return r.get_smallint()
1259
SmallInt.instance = SmallInt()
1262
class Int(BasePrimitiveType):
1266
def write(self, w, val):
1272
Int.instance = Int()
1275
class BigInt(BasePrimitiveType):
1277
declaration = 'BIGINT'
1279
def write(self, w, val):
1285
BigInt.instance = BigInt()
1288
class IntN(BaseTypeN):
1292
1: TinyInt.instance,
1293
2: SmallInt.instance,
1299
class Real(BasePrimitiveType):
1301
declaration = 'REAL'
1303
def write(self, w, val):
1304
w.pack(_flt4_struct, val)
1307
return r.unpack(_flt4_struct)[0]
1309
Real.instance = Real()
1312
class Float(BasePrimitiveType):
1314
declaration = 'FLOAT'
1316
def write(self, w, val):
1317
w.pack(_flt8_struct, val)
1320
return r.unpack(_flt8_struct)[0]
1322
Float.instance = Float()
1325
class FloatN(BaseTypeN):
1334
class VarChar70(BaseType):
1337
def __init__(self, size, codec):
1338
#if size <= 0 or size > 8000:
1339
# raise DataError('Invalid size for VARCHAR field')
1344
def from_stream(cls, r):
1345
size = r.get_smallint()
1346
return cls(size, codec=r._session.conn.server_codec)
1349
def from_declaration(cls, declaration, nullable, connection):
1350
m = re.match(r'VARCHAR\((\d+)\)', declaration)
1352
return cls(int(m.group(1)), connection.server_codec)
1354
def get_declaration(self):
1355
return 'VARCHAR({0})'.format(self._size)
1357
def write_info(self, w):
1358
w.put_smallint(self._size)
1359
#w.put_smallint(self._size)
1361
def write(self, w, val):
1365
val = force_unicode(val)
1366
val, _ = self._codec.encode(val)
1367
w.put_smallint(len(val))
1368
#w.put_smallint(len(val))
1372
size = r.get_smallint()
1375
return r.read_str(size, self._codec)
1378
class VarChar71(VarChar70):
1379
def __init__(self, size, collation):
1380
super(VarChar71, self).__init__(size, codec=collation.get_codec())
1381
self._collation = collation
1384
def from_stream(cls, r):
1385
size = r.get_smallint()
1386
collation = r.get_collation()
1387
return cls(size, collation)
1390
def from_declaration(cls, declaration, nullable, connection):
1391
m = re.match(r'VARCHAR\((\d+)\)', declaration)
1393
return cls(int(m.group(1)), connection.collation)
1395
def write_info(self, w):
1396
super(VarChar71, self).write_info(w)
1397
w.put_collation(self._collation)
1400
class VarChar72(VarChar71):
1402
def from_stream(cls, r):
1403
size = r.get_usmallint()
1404
collation = r.get_collation()
1406
return VarCharMax(collation)
1407
return cls(size, collation)
1410
def from_declaration(cls, declaration, nullable, connection):
1411
if declaration == 'VARCHAR(MAX)':
1412
return VarCharMax(connection.collation)
1413
m = re.match(r'VARCHAR\((\d+)\)', declaration)
1415
return cls(int(m.group(1)), connection.collation)
1418
class VarCharMax(VarChar72):
1419
def __init__(self, collation):
1420
super(VarChar72, self).__init__(0, collation)
1422
def get_declaration(self):
1423
return 'VARCHAR(MAX)'
1425
def write_info(self, w):
1426
w.put_usmallint(PLP_MARKER)
1427
w.put_collation(self._collation)
1429
def write(self, w, val):
1431
w.put_uint8(PLP_NULL)
1433
val = force_unicode(val)
1434
val, _ = self._codec.encode(val)
1435
w.put_int8(len(val))
1445
return ''.join(iterdecode(r.chunks(), self._codec))
1448
class NVarChar70(BaseType):
1451
def __init__(self, size):
1452
#if size <= 0 or size > 4000:
1453
# raise DataError('Invalid size for NVARCHAR field')
1457
def from_stream(cls, r):
1458
size = r.get_usmallint()
1459
return cls(size / 2)
1462
def from_declaration(cls, declaration, nullable, connection):
1463
m = re.match(r'NVARCHAR\((\d+)\)', declaration)
1465
return cls(int(m.group(1)))
1467
def get_declaration(self):
1468
return 'NVARCHAR({0})'.format(self._size)
1470
def write_info(self, w):
1471
w.put_usmallint(self._size * 2)
1472
#w.put_smallint(self._size)
1474
def write(self, w, val):
1476
w.put_usmallint(0xffff)
1478
if isinstance(val, bytes):
1479
val = force_unicode(val)
1480
buf, _ = ucs2_codec.encode(val)
1486
size = r.get_usmallint()
1489
return r.read_str(size, ucs2_codec)
1492
class NVarChar71(NVarChar70):
    """NVARCHAR for TDS 7.1+, which adds collation info to TYPE_INFO."""

    def __init__(self, size, collation=raw_collation):
        super(NVarChar71, self).__init__(size)
        self._collation = collation

    @classmethod
    def from_stream(cls, r):
        """Parse NVARCHAR TYPE_INFO: byte size, then collation.

        The wire size is in bytes; NVARCHAR is UCS-2, so the character
        count is half of it.  BUG FIX: use floor division so ``_size``
        stays an int on Python 3 (``size / 2`` produced a float there).
        """
        size = r.get_usmallint()
        collation = r.get_collation()
        return cls(size // 2, collation)

    @classmethod
    def from_declaration(cls, declaration, nullable, connection):
        """Build a type instance from a declaration like ``NVARCHAR(10)``."""
        m = re.match(r'NVARCHAR\((\d+)\)', declaration)
        if m:
            return cls(int(m.group(1)), connection.collation)

    def write_info(self, w):
        # base class writes the byte size; 7.1 appends the collation
        super(NVarChar71, self).write_info(w)
        w.put_collation(self._collation)
1514
class NVarChar72(NVarChar71):
    """NVARCHAR for TDS 7.2+, adding the NVARCHAR(MAX)/PLP wire form."""

    @classmethod
    def from_stream(cls, r):
        """Parse NVARCHAR TYPE_INFO; 0xffff size is the PLP (MAX) marker."""
        size = r.get_usmallint()
        collation = r.get_collation()
        if size == 0xffff:
            return NVarCharMax(size, collation)
        # bytes -> characters; floor division keeps _size an int on Python 3
        return cls(size // 2, collation=collation)

    @classmethod
    def from_declaration(cls, declaration, nullable, connection):
        """Build a type instance from a declaration like ``NVARCHAR(10)``."""
        if declaration == 'NVARCHAR(MAX)':
            # BUG FIX: previously returned VarCharMax (the non-Unicode MAX
            # type) for an NVARCHAR(MAX) declaration.
            return NVarCharMax(0, connection.collation)
        m = re.match(r'NVARCHAR\((\d+)\)', declaration)
        if m:
            return cls(int(m.group(1)), connection.collation)
1532
class NVarCharMax(NVarChar72):
1533
def get_typeid(self):
1536
def get_declaration(self):
1537
return 'NVARCHAR(MAX)'
1539
def write_info(self, w):
1540
w.put_usmallint(PLP_MARKER)
1541
w.put_collation(self._collation)
1543
def write(self, w, val):
1545
w.put_uint8(PLP_NULL)
1547
if isinstance(val, bytes):
1548
val = force_unicode(val)
1549
val, _ = ucs2_codec.encode(val)
1550
w.put_uint8(len(val))
1552
w.put_uint(len(val))
1560
res = ''.join(iterdecode(r.chunks(), ucs2_codec))
1564
class Xml(NVarCharMax):
1568
def __init__(self, schema={}):
1569
super(Xml, self).__init__(0)
1570
self._schema = schema
1572
def get_typeid(self):
1575
def get_declaration(self):
1576
return self.declaration
1579
def from_stream(cls, r):
1580
has_schema = r.get_byte()
1583
schema['dbname'] = r.read_ucs2(r.get_byte())
1584
schema['owner'] = r.read_ucs2(r.get_byte())
1585
schema['collection'] = r.read_ucs2(r.get_smallint())
1589
def from_declaration(cls, declaration, nullable, connection):
1590
if declaration == cls.declaration:
1593
def write_info(self, w):
1596
w.put_byte(len(self._schema['dbname']))
1597
w.write_ucs2(self._schema['dbname'])
1598
w.put_byte(len(self._schema['owner']))
1599
w.write_ucs2(self._schema['owner'])
1600
w.put_usmallint(len(self._schema['collection']))
1601
w.write_ucs2(self._schema['collection'])
1606
class Text70(BaseType):
1608
declaration = 'TEXT'
1610
def __init__(self, size=0, table_name='', codec=None):
1612
self._table_name = table_name
1616
def from_stream(cls, r):
1618
table_name = r.read_ucs2(r.get_smallint())
1619
return cls(size, table_name, codec=r.session.conn.server_codec)
1622
def from_declaration(cls, declaration, nullable, connection):
1623
if declaration == cls.declaration:
1626
def get_declaration(self):
1627
return self.declaration
1629
def write_info(self, w):
1630
w.put_int(self._size)
1632
def write(self, w, val):
1636
val = force_unicode(val)
1637
val, _ = self._codec.encode(val)
1645
readall(r, size) # textptr
1646
readall(r, 8) # timestamp
1647
colsize = r.get_int()
1648
return r.read_str(colsize, self._codec)
1651
class Text71(Text70):
    """TEXT for TDS 7.1+, which adds collation info to TYPE_INFO.

    NOTE(review): reconstructed from a corrupted listing; the
    ``self._size = size`` line and the ``size = r.get_int()`` read were
    missing and are inferred from their uses below — confirm upstream.
    """

    def __init__(self, size=0, table_name='', collation=raw_collation):
        self._size = size
        self._collation = collation
        # codec derived from collation, used by Text70.read/write
        self._codec = collation.get_codec()
        self._table_name = table_name

    @classmethod
    def from_stream(cls, r):
        """Parse TEXT TYPE_INFO: 4-byte size, collation, then table name."""
        size = r.get_int()
        collation = r.get_collation()
        table_name = r.read_ucs2(r.get_smallint())
        return cls(size, table_name, collation)

    def write_info(self, w):
        w.put_int(self._size)
        w.put_collation(self._collation)
1670
class Text72(Text71):
    """TEXT for TDS 7.2+, where the table name arrives as multiple parts."""

    def __init__(self, size=0, table_name_parts=(), collation=raw_collation):
        # BUG FIX: default was a shared mutable list ([]); an empty tuple
        # avoids the mutable-default pitfall while accepting any sequence.
        super(Text72, self).__init__(size, '.'.join(table_name_parts), collation)
        self._table_name_parts = table_name_parts

    @classmethod
    def from_stream(cls, r):
        """Parse TYPE_INFO: size, collation, then a counted list of
        UCS-2 table-name parts."""
        size = r.get_int()
        collation = r.get_collation()
        num_parts = r.get_byte()
        parts = []
        for _ in range(num_parts):
            parts.append(r.read_ucs2(r.get_smallint()))
        return cls(size, parts, collation)
1686
class NText70(BaseType):
1688
declaration = 'NTEXT'
1690
def __init__(self, size=0, table_name=''):
1692
self._table_name = table_name
1695
def from_stream(cls, r):
1697
table_name = r.read_ucs2(r.get_smallint())
1698
return cls(size, table_name)
1701
def from_declaration(cls, declaration, nullable, connection):
1702
if declaration == cls.declaration:
1705
def get_declaration(self):
1706
return self.declaration
1709
textptr_size = r.get_byte()
1710
if textptr_size == 0:
1712
readall(r, textptr_size) # textptr
1713
readall(r, 8) # timestamp
1714
colsize = r.get_int()
1715
return r.read_str(colsize, ucs2_codec)
1717
def write_info(self, w):
1718
w.put_int(self._size * 2)
1720
def write(self, w, val):
1724
w.put_int(len(val) * 2)
1728
class NText71(NText70):
1729
def __init__(self, size=0, table_name='', collation=raw_collation):
1731
self._collation = collation
1732
self._table_name = table_name
1735
def from_stream(cls, r):
1737
collation = r.get_collation()
1738
table_name = r.read_ucs2(r.get_smallint())
1739
return cls(size, table_name, collation)
1741
def write_info(self, w):
1742
w.put_int(self._size)
1743
w.put_collation(self._collation)
1746
textptr_size = r.get_byte()
1747
if textptr_size == 0:
1749
readall(r, textptr_size) # textptr
1750
readall(r, 8) # timestamp
1751
colsize = r.get_int()
1752
return r.read_str(colsize, ucs2_codec)
1755
class NText72(NText71):
    """NTEXT for TDS 7.2+, where the table name arrives as multiple parts."""

    def __init__(self, size=0, table_name_parts=(), collation=raw_collation):
        # BUG FIX: default was a shared mutable list ([]); an empty tuple
        # avoids the mutable-default pitfall while accepting any sequence.
        self._size = size
        self._collation = collation
        self._table_name_parts = table_name_parts

    @classmethod
    def from_stream(cls, r):
        """Parse TYPE_INFO: size, collation, then a counted list of
        UCS-2 table-name parts."""
        size = r.get_int()
        collation = r.get_collation()
        num_parts = r.get_byte()
        parts = []
        for _ in range(num_parts):
            parts.append(r.read_ucs2(r.get_smallint()))
        return cls(size, parts, collation)
1772
class VarBinary(BaseType):
    """VARBINARY(n) — length-prefixed binary, 2-byte length on the wire.

    NOTE(review): reconstructed from a corrupted listing; the NULL
    branches (0xffff marker) and ``w.write(val)`` were missing and are
    inferred from the visible length handling — confirm upstream.
    """
    type = XSYBVARBINARY

    def __init__(self, size):
        self._size = size

    @classmethod
    def from_stream(cls, r):
        size = r.get_usmallint()
        return cls(size)

    @classmethod
    def from_declaration(cls, declaration, nullable, connection):
        """Build a type instance from a declaration like ``VARBINARY(10)``."""
        m = re.match(r'VARBINARY\((\d+)\)', declaration)
        if m:
            return cls(int(m.group(1)))

    def get_declaration(self):
        return 'VARBINARY({0})'.format(self._size)

    def write_info(self, w):
        w.put_usmallint(self._size)

    def write(self, w, val):
        if val is None:
            w.put_usmallint(0xffff)  # 0xffff length marker encodes NULL
        else:
            w.put_usmallint(len(val))
            w.write(val)

    def read(self, r):
        size = r.get_usmallint()
        if size == 0xffff:
            return None
        return readall(r, size)
1809
class VarBinary72(VarBinary):
    """VARBINARY for TDS 7.2+, adding the VARBINARY(MAX)/PLP wire form."""

    @classmethod
    def from_stream(cls, r):
        """Parse TYPE_INFO; 0xffff size is the PLP marker for MAX."""
        size = r.get_usmallint()
        if size == 0xffff:
            return VarBinaryMax()
        return cls(size)

    @classmethod
    def from_declaration(cls, declaration, nullable, connection):
        """Build a type instance from ``VARBINARY(MAX)`` or ``VARBINARY(n)``."""
        if declaration == 'VARBINARY(MAX)':
            return VarBinaryMax()
        m = re.match(r'VARBINARY\((\d+)\)', declaration)
        if m:
            return cls(int(m.group(1)))
1826
class VarBinaryMax(VarBinary):
1828
super(VarBinaryMax, self).__init__(0)
1830
def get_declaration(self):
1831
return 'VARBINARY(MAX)'
1833
def write_info(self, w):
1834
w.put_usmallint(PLP_MARKER)
1836
def write(self, w, val):
1838
w.put_uint8(PLP_NULL)
1840
w.put_uint8(len(val))
1842
w.put_uint(len(val))
1850
return b''.join(r.chunks())
1853
class Image70(BaseType):
1855
declaration = 'IMAGE'
1857
def __init__(self, size=0, table_name=''):
1858
self._table_name = table_name
1861
def get_declaration(self):
1862
return self.declaration
1865
def from_stream(cls, r):
1867
table_name = r.read_ucs2(r.get_smallint())
1868
return cls(size, table_name)
1871
def from_declaration(cls, declaration, nullable, connection):
1872
if declaration == cls.declaration:
1877
if size == 16: # Jeff's hack
1878
readall(r, 16) # textptr
1879
readall(r, 8) # timestamp
1880
colsize = r.get_int()
1881
return readall(r, colsize)
1885
def write(self, w, val):
1892
def write_info(self, w):
1893
w.put_int(self._size)
1896
class Image72(Image70):
1897
def __init__(self, size=0, parts=[]):
1902
def from_stream(cls, r):
1904
num_parts = r.get_byte()
1906
for _ in range(num_parts):
1907
parts.append(r.read_ucs2(r.get_usmallint()))
1908
return Image72(size, parts)
1911
class BaseDateTime(BaseType):
    """Shared constants for the legacy DATETIME/SMALLDATETIME family.

    Wire values are integer day offsets relative to ``_base_date``.
    """
    # epoch the integer day counts are relative to
    _base_date = datetime(1900, 1, 1)
    # DATETIME range accepted by SQL Server (max has 997 ms — the last
    # representable 1/300-second tick)
    _min_date = datetime(1753, 1, 1, 0, 0, 0)
    _max_date = datetime(9999, 12, 31, 23, 59, 59, 997000)
1917
class SmallDateTime(BasePrimitiveType, BaseDateTime):
    """SMALLDATETIME: days since 1900-01-01 plus minutes since midnight.

    NOTE(review): reconstructed from a corrupted listing; the ``type``
    attribute and the ``val.tzinfo`` guard were missing.  ``type`` is
    grounded by the ``SYBDATETIME4: SmallDateTime`` entry in the type map;
    the ``read`` method name by ``SmallDateTime.instance.read(r)`` in the
    Variant dispatch table.
    """
    type = SYBDATETIME4
    declaration = 'SMALLDATETIME'

    _max_date = datetime(2079, 6, 6, 23, 59, 0)

    # <unsigned 16-bit days, unsigned 16-bit minutes>
    _struct = struct.Struct('<HH')

    def write(self, w, val):
        if val.tzinfo:
            # aware datetimes must be normalized to the session timezone
            if not w.session.use_tz:
                raise DataError('Timezone-aware datetime is used without specifying use_tz')
            val = val.astimezone(w.session.use_tz).replace(tzinfo=None)
        days = (val - self._base_date).days
        minutes = val.hour * 60 + val.minute
        w.pack(self._struct, days, minutes)

    def read(self, r):
        days, minutes = r.unpack(self._struct)
        tzinfo = None
        if r.session.tzinfo_factory is not None:
            tzinfo = r.session.tzinfo_factory(0)
        return (self._base_date + timedelta(days=days, minutes=minutes)).replace(tzinfo=tzinfo)


SmallDateTime.instance = SmallDateTime()
1943
class DateTime(BasePrimitiveType, BaseDateTime):
    """DATETIME: days since 1900-01-01 plus 1/300-second ticks since midnight.

    NOTE(review): reconstructed from a corrupted listing; the ``type``
    attribute, ``val.tzinfo`` guard, ``read`` header, ``@classmethod``
    decorators and ``secs = time // 300`` were missing.  ``type`` is
    grounded by the ``SYBDATETIME: DateTime`` type-map entry; ``read`` by
    ``DateTime.instance.read(r)`` in the Variant dispatch table; ``secs``
    by its use in the returned timedelta.
    """
    type = SYBDATETIME
    declaration = 'DATETIME'

    # <signed 32-bit days, signed 32-bit 1/300-second ticks>
    _struct = struct.Struct('<ll')

    def write(self, w, val):
        if val.tzinfo:
            if not w.session.use_tz:
                raise DataError('Timezone-aware datetime is used without specifying use_tz')
            val = val.astimezone(w.session.use_tz).replace(tzinfo=None)
        w.write(self.encode(val))

    def read(self, r):
        days, t = r.unpack(self._struct)
        tzinfo = None
        if r.session.tzinfo_factory is not None:
            tzinfo = r.session.tzinfo_factory(0)
        return _applytz(self.decode(days, t), tzinfo)

    @classmethod
    def validate(cls, value):
        """Raise DataError when value falls outside the DATETIME range."""
        if not (cls._min_date <= value <= cls._max_date):
            raise DataError('Date is out of range')

    @classmethod
    def encode(cls, value):
        """Pack a date/datetime into the 8-byte DATETIME wire format."""
        #cls.validate(value)
        if type(value) == date:
            value = datetime.combine(value, time(0, 0, 0))
        days = (value - cls._base_date).days
        ms = value.microsecond // 1000
        # time of day in 1/300-second units (hence the *300 and ms*3/10)
        tm = (value.hour * 60 * 60 + value.minute * 60 + value.second) * 300 + int(round(ms * 3 / 10.0))
        return cls._struct.pack(days, tm)

    @classmethod
    def decode(cls, days, time):
        """Unpack wire (days, 1/300-second ticks) into a naive datetime."""
        ms = int(round(time % 300 * 10 / 3.0))
        secs = time // 300
        return cls._base_date + timedelta(days=days, seconds=secs, milliseconds=ms)


DateTime.instance = DateTime()
1987
class DateTimeN(BaseTypeN, BaseDateTime):
1990
4: SmallDateTime.instance,
1991
8: DateTime.instance,
1995
class BaseDateTime73(BaseType):
1996
_precision_to_len = {
2007
_base_date = datetime(1, 1, 1)
2009
def _write_time(self, w, t, prec):
2010
secs = t.hour * 60 * 60 + t.minute * 60 + t.second
2011
val = (secs * 10 ** 7 + t.microsecond * 10) // (10 ** (7 - prec))
2012
w.write(struct.pack('<Q', val)[:self._precision_to_len[prec]])
2014
def _read_time(self, r, size, prec, use_tz):
2015
time_buf = readall(r, size)
2016
val = _decode_num(time_buf)
2017
val *= 10 ** (7 - prec)
2018
nanoseconds = val * 100
2019
hours = nanoseconds // 1000000000 // 60 // 60
2020
nanoseconds -= hours * 60 * 60 * 1000000000
2021
minutes = nanoseconds // 1000000000 // 60
2022
nanoseconds -= minutes * 60 * 1000000000
2023
seconds = nanoseconds // 1000000000
2024
nanoseconds -= seconds * 1000000000
2025
return time(hours, minutes, seconds, nanoseconds // 1000, tzinfo=use_tz)
2027
def _write_date(self, w, value):
2028
if type(value) == date:
2029
value = datetime.combine(value, time(0, 0, 0))
2030
days = (value - self._base_date).days
2031
buf = struct.pack('<l', days)[:3]
2034
def _read_date(self, r):
2035
days = _decode_num(readall(r, 3))
2036
return (self._base_date + timedelta(days=days)).date()
2039
class MsDate(BasePrimitiveType, BaseDateTime73):
2041
declaration = 'DATE'
2044
MAX = date(9999, 12, 31)
2046
def write(self, w, value):
2051
self._write_date(w, value)
2053
def read_fixed(self, r):
2054
return self._read_date(r)
2060
return self._read_date(r)
2062
MsDate.instance = MsDate()
2065
class MsTime(BaseDateTime73):
2068
def __init__(self, prec):
2070
self._size = self._precision_to_len[prec]
2073
def from_stream(cls, r):
2078
def from_declaration(cls, declaration, nullable, connection):
2079
m = re.match(r'TIME\((\d+)\)', declaration)
2081
return cls(int(m.group(1)))
2083
def get_declaration(self):
2084
return 'TIME({0})'.format(self._prec)
2086
def write_info(self, w):
2087
w.put_byte(self._prec)
2089
def write(self, w, value):
2094
if not w.session.use_tz:
2095
raise DataError('Timezone-aware datetime is used without specifying use_tz')
2096
value = value.astimezone(w.session.use_tz).replace(tzinfo=None)
2097
w.put_byte(self._size)
2098
self._write_time(w, value, self._prec)
2100
def read_fixed(self, r, size):
2102
if r.session.tzinfo_factory is not None:
2103
tzinfo = r.session.tzinfo_factory(0)
2104
return self._read_time(r, size, self._prec, tzinfo)
2110
return self.read_fixed(r, size)
2113
class DateTime2(BaseDateTime73):
2114
type = SYBMSDATETIME2
2116
def __init__(self, prec=7):
2118
self._size = self._precision_to_len[prec] + 3
2121
def from_stream(cls, r):
2125
def get_declaration(self):
2126
return 'DATETIME2({0})'.format(self._prec)
2129
def from_declaration(cls, declaration, nullable, connection):
2130
if declaration == 'DATETIME2':
2132
m = re.match(r'DATETIME2\((\d+)\)', declaration)
2134
return cls(int(m.group(1)))
2136
def write_info(self, w):
2137
w.put_byte(self._prec)
2139
def write(self, w, value):
2144
if not w.session.use_tz:
2145
raise DataError('Timezone-aware datetime is used without specifying use_tz')
2146
value = value.astimezone(w.session.use_tz).replace(tzinfo=None)
2147
w.put_byte(self._size)
2148
self._write_time(w, value, self._prec)
2149
self._write_date(w, value)
2151
def read_fixed(self, r, size):
2153
if r.session.tzinfo_factory is not None:
2154
tzinfo = r.session.tzinfo_factory(0)
2155
time = self._read_time(r, size - 3, self._prec, tzinfo)
2156
date = self._read_date(r)
2157
return datetime.combine(date, time)
2163
return self.read_fixed(r, size)
2166
class DateTimeOffset(BaseDateTime73):
2167
type = SYBMSDATETIMEOFFSET
2169
def __init__(self, prec=7):
2171
self._size = self._precision_to_len[prec] + 5
2174
def from_stream(cls, r):
2179
def from_declaration(cls, declaration, nullable, connection):
2180
if declaration == 'DATETIMEOFFSET':
2182
m = re.match(r'DATETIMEOFFSET\((\d+)\)', declaration)
2184
return cls(int(m.group(1)))
2186
def get_declaration(self):
2187
return 'DATETIMEOFFSET({0})'.format(self._prec)
2189
def write_info(self, w):
2190
w.put_byte(self._prec)
2192
def write(self, w, value):
2196
utcoffset = value.utcoffset()
2197
value = value.astimezone(_utc).replace(tzinfo=None)
2199
w.put_byte(self._size)
2200
self._write_time(w, value, self._prec)
2201
self._write_date(w, value)
2202
w.put_smallint(int(total_seconds(utcoffset)) // 60)
2204
def read_fixed(self, r, size):
2205
time = self._read_time(r, size - 5, self._prec, _utc)
2206
date = self._read_date(r)
2207
offset = r.get_smallint()
2208
tzinfo_factory = r._session.tzinfo_factory
2209
if tzinfo_factory is None:
2210
from .tz import FixedOffsetTimezone
2211
tzinfo_factory = FixedOffsetTimezone
2212
tz = tzinfo_factory(offset)
2213
return datetime.combine(date, time).astimezone(tz)
2219
return self.read_fixed(r, size)
2222
class MsDecimal(BaseType):
2229
# precision can't be 0 but using a value > 0 assure no
2230
# core if for some bug it's 0...
2233
5, 5, 5, 5, 5, 5, 5, 5, 5,
2234
9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
2235
13, 13, 13, 13, 13, 13, 13, 13, 13,
2236
17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
2239
_info_struct = struct.Struct('BBB')
2246
def precision(self):
2249
def __init__(self, scale=0, prec=18):
2251
raise DataError('Precision of decimal value is out of range')
2254
self._size = self._bytes_per_prec[prec]
2257
def from_value(cls, value):
2258
if not (-10 ** 38 + 1 <= value <= 10 ** 38 - 1):
2259
raise DataError('Decimal value is out of range')
2260
value = value.normalize()
2261
_, digits, exp = value.as_tuple()
2264
prec = len(digits) + exp
2267
prec = max(len(digits), scale)
2268
return cls(scale=scale, prec=prec)
2271
def from_stream(cls, r):
2272
size, prec, scale = r.unpack(cls._info_struct)
2273
return cls(scale=scale, prec=prec)
2276
def from_declaration(cls, declaration, nullable, connection):
2277
if declaration == 'DECIMAL':
2279
m = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', declaration)
2281
return cls(int(m.group(2)), int(m.group(1)))
2283
def get_declaration(self):
2284
return 'DECIMAL({0},{1})'.format(self._prec, self._scale)
2286
def write_info(self, w):
2287
w.pack(self._info_struct, self._size, self._prec, self._scale)
2289
def write(self, w, value):
2293
if not isinstance(value, Decimal):
2294
value = Decimal(value)
2295
value = value.normalize()
2300
positive = 1 if val > 0 else 0
2301
w.put_byte(positive) # sign
2302
with localcontext() as ctx:
2307
val = val * (10 ** scale)
2308
for i in range(size):
2309
w.put_byte(int(val % 256))
2313
def _decode(self, positive, buf):
2314
val = _decode_num(buf)
2316
with localcontext() as ctx:
2320
val /= 10 ** self._scale
2323
def read_fixed(self, r, size):
2324
positive = r.get_byte()
2325
buf = readall(r, size - 1)
2326
return self._decode(positive, buf)
2332
return self.read_fixed(r, size)
2335
class Money4(BasePrimitiveType):
    """SMALLMONEY: a 32-bit integer scaled by 10000 (4 fractional digits).

    NOTE(review): reconstructed from a corrupted listing; the ``type``
    attribute, ``read`` header and ``w.put_int(val)`` were missing and
    are inferred from the visible ``r.get_int()``/scaling symmetry —
    confirm upstream.
    """
    type = SYBMONEY4
    declaration = 'SMALLMONEY'

    def read(self, r):
        return Decimal(r.get_int()) / 10000

    def write(self, w, val):
        val = int(val * 10000)
        w.put_int(val)


Money4.instance = Money4()
2349
class Money8(BasePrimitiveType):
    """MONEY: a 64-bit integer scaled by 10000, sent as two 32-bit halves.

    NOTE(review): reconstructed from a corrupted listing; the ``type``
    attribute, ``read`` header and ``val = int(val * 10000)`` were
    missing (the scaling line parallels Money4.write) — confirm upstream.
    """
    type = SYBMONEY
    declaration = 'MONEY'

    # high signed 32 bits first, then low unsigned 32 bits
    _struct = struct.Struct('<lL')

    def read(self, r):
        hi, lo = r.unpack(self._struct)
        val = hi * (2 ** 32) + lo
        return Decimal(val) / 10000

    def write(self, w, val):
        val = int(val * 10000)
        hi = int(val // (2 ** 32))
        lo = int(val % (2 ** 32))
        w.pack(self._struct, hi, lo)


Money8.instance = Money8()
2369
class MoneyN(BaseTypeN):
2377
class MsUnique(BaseType):
2379
declaration = 'UNIQUEIDENTIFIER'
2382
def from_stream(cls, r):
2385
raise InterfaceError('Invalid size of UNIQUEIDENTIFIER field')
2389
def from_declaration(cls, declaration, nullable, connection):
2390
if declaration == cls.declaration:
2393
def get_declaration(self):
2394
return self.declaration
2396
def write_info(self, w):
2399
def write(self, w, value):
2404
w.write(value.bytes_le)
2406
def read_fixed(self, r, size):
2407
return uuid.UUID(bytes_le=readall(r, size))
2414
raise InterfaceError('Invalid size of UNIQUEIDENTIFIER field')
2415
return self.read_fixed(r, size)
2416
MsUnique.instance = MsUnique()
2419
def _variant_read_str(r, size):
2420
collation = r.get_collation()
2422
return r.read_str(size, collation.get_codec())
2425
def _variant_read_nstr(r, size):
2428
return r.read_str(size, ucs2_codec)
2431
def _variant_read_decimal(r, size):
    """Read a DECIMAL/NUMERIC value stored inside a SQL_VARIANT.

    The variant payload carries its own precision/scale prologue,
    after which the value is laid out exactly like a plain decimal.
    """
    precision, scale = r.unpack(Variant._decimal_info_struct)
    decimal_type = MsDecimal(prec=precision, scale=scale)
    return decimal_type.read_fixed(r, size)
2436
def _variant_read_binary(r, size):
2438
return readall(r, size)
2441
class Variant(BaseType):
2443
declaration = 'SQL_VARIANT'
2445
_decimal_info_struct = struct.Struct('BB')
2448
GUIDTYPE: lambda r, size: MsUnique.instance.read_fixed(r, size),
2449
BITTYPE: lambda r, size: Bit.instance.read(r),
2450
INT1TYPE: lambda r, size: TinyInt.instance.read(r),
2451
INT2TYPE: lambda r, size: SmallInt.instance.read(r),
2452
INT4TYPE: lambda r, size: Int.instance.read(r),
2453
INT8TYPE: lambda r, size: BigInt.instance.read(r),
2454
DATETIMETYPE: lambda r, size: DateTime.instance.read(r),
2455
DATETIM4TYPE: lambda r, size: SmallDateTime.instance.read(r),
2456
FLT4TYPE: lambda r, size: Real.instance.read(r),
2457
FLT8TYPE: lambda r, size: Float.instance.read(r),
2458
MONEYTYPE: lambda r, size: Money8.instance.read(r),
2459
MONEY4TYPE: lambda r, size: Money4.instance.read(r),
2460
DATENTYPE: lambda r, size: MsDate.instance.read_fixed(r),
2462
TIMENTYPE: lambda r, size: MsTime(prec=r.get_byte()).read_fixed(r, size),
2463
DATETIME2NTYPE: lambda r, size: DateTime2(prec=r.get_byte()).read_fixed(r, size),
2464
DATETIMEOFFSETNTYPE: lambda r, size: DateTimeOffset(prec=r.get_byte()).read_fixed(r, size),
2466
BIGVARBINTYPE: _variant_read_binary,
2467
BIGBINARYTYPE: _variant_read_binary,
2469
NUMERICNTYPE: _variant_read_decimal,
2470
DECIMALNTYPE: _variant_read_decimal,
2472
BIGVARCHRTYPE: _variant_read_str,
2473
BIGCHARTYPE: _variant_read_str,
2474
NVARCHARTYPE: _variant_read_nstr,
2475
NCHARTYPE: _variant_read_nstr,
2479
def __init__(self, size):
2482
def get_declaration(self):
2483
return self.declaration
2486
def from_stream(cls, r):
2488
return Variant(size)
2491
def from_declaration(cls, declaration, nullable, connection):
2492
if declaration == cls.declaration:
2495
def write_info(self, w):
2496
w.put_int(self._size)
2503
type_id = r.get_byte()
2504
prop_bytes = r.get_byte()
2505
type_factory = self._type_map.get(type_id)
2506
if not type_factory:
2507
r.session.bad_stream('Variant type invalid', type_id)
2508
return type_factory(r, size - prop_bytes - 2)
2510
def write(self, w, val):
2514
raise NotImplementedError
2531
XSYBCHAR: VarChar70,
2532
XSYBVARCHAR: VarChar70,
2533
XSYBNCHAR: NVarChar70,
2534
XSYBNVARCHAR: NVarChar70,
2538
XSYBBINARY: VarBinary,
2539
XSYBVARBINARY: VarBinary,
2541
SYBNUMERIC: MsDecimal,
2542
SYBDECIMAL: MsDecimal,
2543
SYBVARIANT: Variant,
2546
SYBMSDATETIME2: DateTime2,
2547
SYBMSDATETIMEOFFSET: DateTimeOffset,
2548
SYBDATETIME4: SmallDateTime,
2549
SYBDATETIME: DateTime,
2550
SYBDATETIMN: DateTimeN,
2551
SYBUNIQUE: MsUnique,
2554
_type_map71 = _type_map.copy()
2555
_type_map71.update({
2556
XSYBCHAR: VarChar71,
2557
XSYBNCHAR: NVarChar71,
2558
XSYBVARCHAR: VarChar71,
2559
XSYBNVARCHAR: NVarChar71,
2564
_type_map72 = _type_map.copy()
2565
_type_map72.update({
2566
XSYBCHAR: VarChar72,
2567
XSYBNCHAR: NVarChar72,
2568
XSYBVARCHAR: VarChar72,
2569
XSYBNVARCHAR: NVarChar72,
2572
XSYBBINARY: VarBinary72,
2573
XSYBVARBINARY: VarBinary72,
2578
def _create_exception_by_message(msg, custom_error_msg=None):
    """Map a server message dict onto a DB-API exception instance.

    The exception class is chosen by the server message number
    (ProgrammingError / IntegrityError / OperationalError); every message
    field is also attached to the exception as an attribute so callers
    can inspect the raw server diagnostics.

    NOTE(review): reconstructed from a corrupted listing; the two
    ``else:`` lines and the final ``return ex`` were missing (the return
    is grounded by the caller binding this function's result).

    :param msg: dict of server message fields (msgno, message, server, ...)
    :param custom_error_msg: optional text overriding the server message
    :returns: the constructed exception instance (not raised here)
    """
    msg_no = msg['msgno']
    if custom_error_msg is not None:
        error_msg = custom_error_msg
    else:
        error_msg = msg['message']
    if msg_no in prog_errors:
        ex = ProgrammingError(error_msg)
    elif msg_no in integrity_errors:
        ex = IntegrityError(error_msg)
    else:
        ex = OperationalError(error_msg)
    ex.msg_no = msg['msgno']
    ex.text = msg['message']
    ex.srvname = msg['server']
    ex.procname = msg['proc_name']
    ex.number = msg['msgno']
    ex.severity = msg['severity']
    ex.state = msg['state']
    ex.line = msg['line_number']
    return ex
2601
class _TdsSession(object):
2604
Represents a single TDS session within MARS connection, when MARS enabled there could be multiple TDS sessions
2605
within one connection.
2607
def __init__(self, tds, transport, tzinfo_factory):
2609
self.res_info = None
2610
self.in_cancel = False
2611
self.wire_mtx = None
2612
self.param_info = None
2613
self.has_status = False
2614
self.ret_status = None
2615
self.skipped_to_status = False
2616
self._transport = transport
2617
self._reader = _TdsReader(self)
2618
self._reader._transport = transport
2619
self._writer = _TdsWriter(self, tds._bufsize)
2620
self._writer._transport = transport
2622
self.state = TDS_IDLE
2625
self.chunk_handler = tds.chunk_handler
2626
self.rows_affected = -1
2627
self.use_tz = tds.use_tz
2629
self.tzinfo_factory = tzinfo_factory
2632
fmt = "<_TdsSession state={} tds={} messages={} rows_affected={} use_tz={} spid={} in_cancel={}>"
2633
res = fmt.format(repr(self.state), repr(self._tds), repr(self.messages),
2634
repr(self.rows_affected), repr(self.use_tz), repr(self._spid),
2638
def raise_db_exception(self):
2639
""" Raises exception from last server message
2641
This function will skip messages: The statement has been terminated
2643
if not self.messages:
2644
raise Error("Request failed, server didn't send error message")
2646
msg = self.messages[-1]
2647
if msg['msgno'] == 3621: # the statement has been terminated
2648
self.messages = self.messages[:-1]
2652
error_msg = ' '.join(msg['message'] for msg in self.messages)
2653
ex = _create_exception_by_message(msg, error_msg)
2656
def get_type_info(self, curcol):
2657
""" Reads TYPE_INFO structure (http://msdn.microsoft.com/en-us/library/dd358284.aspx)
2659
:param curcol: An instance of :class:`Column` that will receive read information
2662
# User defined data type of the column
2663
curcol.column_usertype = r.get_uint() if IS_TDS72_PLUS(self) else r.get_usmallint()
2664
curcol.flags = r.get_usmallint() # Flags
2665
curcol.column_nullable = curcol.flags & Column.fNullable
2666
curcol.column_writeable = (curcol.flags & Column.fReadWrite) > 0
2667
curcol.column_identity = (curcol.flags & Column.fIdentity) > 0
2668
type_id = r.get_byte()
2669
type_class = self._tds._type_map.get(type_id)
2671
raise InterfaceError('Invalid type id', type_id)
2672
curcol.type = type_class.from_stream(r)
2674
def tds7_process_result(self):
2675
""" Reads and processes COLMETADATA stream
2677
This stream contains a list of returned columns.
2678
Stream format link: http://msdn.microsoft.com/en-us/library/dd357363.aspx
2681
#logger.debug("processing TDS7 result metadata.")
2683
# read number of columns and allocate the columns structure
2685
num_cols = r.get_smallint()
2687
# This can be a DUMMY results token from a cursor fetch
2690
#logger.debug("no meta data")
2693
self.param_info = None
2694
self.has_status = False
2695
self.ret_status = None
2696
self.skipped_to_status = False
2697
self.rows_affected = TDS_NO_COUNT
2698
self.more_rows = True
2699
self.row = [None] * num_cols
2700
self.res_info = info = _Results()
2703
# loop through the columns populating COLINFO struct from
2706
#logger.debug("setting up {0} columns".format(num_cols))
2708
for col in range(num_cols):
2710
info.columns.append(curcol)
2711
self.get_type_info(curcol)
2714
# under 7.0 lengths are number of characters not
2715
# number of bytes... read_ucs2 handles this
2717
curcol.column_name = r.read_ucs2(r.get_byte())
2718
precision = curcol.type.precision if hasattr(curcol.type, 'precision') else None
2719
scale = curcol.type.scale if hasattr(curcol.type, 'scale') else None
2720
size = curcol.type._size if hasattr(curcol.type, '_size') else None
2721
header_tuple.append((curcol.column_name, curcol.type.get_typeid(), None, size, precision, scale, curcol.column_nullable))
2722
info.description = tuple(header_tuple)
2725
def process_param(self):
2726
""" Reads and processes RETURNVALUE stream.
2728
This stream is used to send OUTPUT parameters from RPC to client.
2729
Stream format url: http://msdn.microsoft.com/en-us/library/dd303881.aspx
2732
if IS_TDS72_PLUS(self):
2733
ordinal = r.get_usmallint()
2735
r.get_usmallint() # ignore size
2736
ordinal = self._out_params_indexes[self.return_value_index]
2737
name = r.read_ucs2(r.get_byte())
2738
r.get_byte() # 1 - OUTPUT of sp, 2 - result of udf
2740
param.column_name = name
2741
self.get_type_info(param)
2742
param.value = param.type.read(r)
2743
self.output_params[ordinal] = param
2744
self.return_value_index += 1
2746
def process_cancel(self):
2748
Process the incoming token stream until it finds
2749
an end token DONE with the cancel flag set.
2750
At that point the connection should be ready to handle a new query.
2752
In case when no cancel request is pending this function does nothing.
2754
# silly cases, nothing to do
2755
if not self.in_cancel:
2759
token_id = self.get_token_id()
2760
self.process_token(token_id)
2761
if not self.in_cancel:
2764
def process_msg(self, marker):
2765
""" Reads and processes ERROR/INFO streams
2769
- ERROR: http://msdn.microsoft.com/en-us/library/dd304156.aspx
2770
- INFO: http://msdn.microsoft.com/en-us/library/dd303398.aspx
2772
:param marker: TDS_ERROR_TOKEN or TDS_INFO_TOKEN
2775
r.get_smallint() # size
2777
msg['marker'] = marker
2778
msg['msgno'] = r.get_int()
2779
msg['state'] = r.get_byte()
2780
msg['severity'] = r.get_byte()
2781
msg['sql_state'] = None
2783
if marker == TDS_EED_TOKEN:
2784
if msg['severity'] <= 10:
2785
msg['priv_msg_type'] = 0
2787
msg['priv_msg_type'] = 1
2788
len_sqlstate = r.get_byte()
2789
msg['sql_state'] = readall(r, len_sqlstate)
2790
has_eed = r.get_byte()
2791
# junk status and transaction state
2793
elif marker == TDS_INFO_TOKEN:
2794
msg['priv_msg_type'] = 0
2795
elif marker == TDS_ERROR_TOKEN:
2796
msg['priv_msg_type'] = 1
2798
logger.error('tds_process_msg() called with unknown marker "{0}"'.format(marker))
2799
#logger.debug('tds_process_msg() reading message {0} from server'.format(msg['msgno']))
2800
msg['message'] = r.read_ucs2(r.get_smallint())
2802
msg['server'] = r.read_ucs2(r.get_byte())
2803
# stored proc name if available
2804
msg['proc_name'] = r.read_ucs2(r.get_byte())
2805
msg['line_number'] = r.get_int() if IS_TDS72_PLUS(self) else r.get_smallint()
2806
if not msg['sql_state']:
2807
#msg['sql_state'] = tds_alloc_lookup_sqlstate(self, msg['msgno'])
2809
# in case extended error data is sent, we just try to discard it
2812
next_marker = r.get_byte()
2813
if next_marker in (TDS5_PARAMFMT_TOKEN, TDS5_PARAMFMT2_TOKEN, TDS5_PARAMS_TOKEN):
2814
self.process_token(next_marker)
2820
self.messages.append(msg)
2822
def process_row(self):
2823
""" Reads and handles ROW stream.
2825
This stream contains list of values of one returned row.
2826
Stream format url: http://msdn.microsoft.com/en-us/library/dd357254.aspx
2829
info = self.res_info
2831
for i, curcol in enumerate(info.columns):
2832
curcol.value = self.row[i] = curcol.type.read(r)
2834
def process_nbcrow(self):
2835
""" Reads and handles NBCROW stream.
2837
This stream contains list of values of one returned row in a compressed way,
2838
introduced in TDS 7.3.B
2839
Stream format url: http://msdn.microsoft.com/en-us/library/dd304783.aspx
2842
info = self.res_info
2844
self.bad_stream('got row without info')
2845
assert len(info.columns) > 0
2848
# reading bitarray for nulls, 1 represent null values for
2849
# corresponding fields
2850
nbc = readall(r, (len(info.columns) + 7) // 8)
2851
for i, curcol in enumerate(info.columns):
2852
if _ord(nbc[i // 8]) & (1 << (i % 8)):
2855
value = curcol.type.read(r)
2858
def process_orderby(self):
2859
""" Reads and processes ORDER stream
2861
Used to inform client by which column dataset is ordered.
2862
Stream format url: http://msdn.microsoft.com/en-us/library/dd303317.aspx
2865
skipall(r, r.get_smallint())
2867
def process_orderby2(self):
2869
skipall(r, r.get_int())
2871
def process_end(self, marker):
2872
""" Reads and processes DONE/DONEINPROC/DONEPROC streams
2876
- DONE: http://msdn.microsoft.com/en-us/library/dd340421.aspx
2877
- DONEINPROC: http://msdn.microsoft.com/en-us/library/dd340553.aspx
2878
- DONEPROC: http://msdn.microsoft.com/en-us/library/dd340753.aspx
2880
:param marker: Can be TDS_DONE_TOKEN or TDS_DONEINPROC_TOKEN or TDS_DONEPROC_TOKEN
2882
self.more_rows = False
2884
status = r.get_usmallint()
2885
r.get_usmallint() # cur_cmd
2886
more_results = status & TDS_DONE_MORE_RESULTS != 0
2887
was_cancelled = status & TDS_DONE_CANCELLED != 0
2888
#error = status & TDS_DONE_ERROR != 0
2889
done_count_valid = status & TDS_DONE_COUNT != 0
2891
# 'process_end: more_results = {0}\n'
2892
# '\t\twas_cancelled = {1}\n'
2893
# '\t\terror = {2}\n'
2894
# '\t\tdone_count_valid = {3}'.format(more_results, was_cancelled, error, done_count_valid))
2896
self.res_info.more_results = more_results
2897
rows_affected = r.get_int8() if IS_TDS72_PLUS(self) else r.get_int()
2898
#logger.debug('\t\trows_affected = {0}'.format(rows_affected))
2899
if was_cancelled or (not more_results and not self.in_cancel):
2900
#logger.debug('process_end() state set to TDS_IDLE')
2901
self.in_cancel = False
2902
self.set_state(TDS_IDLE)
2903
if done_count_valid:
2904
self.rows_affected = rows_affected
2906
self.rows_affected = -1
2907
self.done_flags = status
2908
if self.done_flags & TDS_DONE_ERROR and not was_cancelled and not self.in_cancel:
2909
self.raise_db_exception()
2911
def process_env_chg(self):
2912
""" Reads and processes ENVCHANGE stream.
2914
Stream info url: http://msdn.microsoft.com/en-us/library/dd303449.aspx
2917
size = r.get_smallint()
2919
#logger.debug("process_env_chg: type: {0}".format(type))
2920
if type == TDS_ENV_SQLCOLLATION:
2922
#logger.debug("process_env_chg(): {0} bytes of collation data received".format(size))
2923
#logger.debug("self.collation was {0}".format(self.conn.collation))
2924
self.conn.collation = r.get_collation()
2925
skipall(r, size - 5)
2926
#tds7_srv_charset_changed(tds, tds.conn.collation)
2927
#logger.debug("self.collation now {0}".format(self.conn.collation))
2929
skipall(r, r.get_byte())
2930
elif type == TDS_ENV_BEGINTRANS:
2932
# TODO: parse transaction
2933
self.conn.tds72_transaction = r.get_uint8()
2934
skipall(r, r.get_byte())
2935
elif type == TDS_ENV_COMMITTRANS or type == TDS_ENV_ROLLBACKTRANS:
2936
self.conn.tds72_transaction = 0
2937
skipall(r, r.get_byte())
2938
skipall(r, r.get_byte())
2939
elif type == TDS_ENV_PACKSIZE:
2940
newval = r.read_ucs2(r.get_byte())
2941
r.read_ucs2(r.get_byte())
2942
new_block_size = int(newval)
2943
if new_block_size >= 512:
2944
#logger.info("changing block size from {0} to {1}".format(oldval, new_block_size))
2946
# Is possible to have a shrink if server limits packet
2947
# size more than what we specified
2949
# Reallocate buffer if possible (strange values from server or out of memory) use older buffer */
2950
self._writer.bufsize = new_block_size
2951
elif type == TDS_ENV_DATABASE:
2952
newval = r.read_ucs2(r.get_byte())
2953
r.read_ucs2(r.get_byte())
2954
self.conn.env.database = newval
2955
elif type == TDS_ENV_LANG:
2956
newval = r.read_ucs2(r.get_byte())
2957
r.read_ucs2(r.get_byte())
2958
self.conn.env.language = newval
2959
elif type == TDS_ENV_CHARSET:
2960
newval = r.read_ucs2(r.get_byte())
2961
r.read_ucs2(r.get_byte())
2962
#logger.debug("server indicated charset change to \"{0}\"\n".format(newval))
2963
self.conn.env.charset = newval
2964
remap = {'iso_1': 'iso8859-1'}
2965
self.conn.server_codec = codecs.lookup(remap.get(newval, newval))
2966
#tds_srv_charset_changed(self, newval)
2967
elif type == TDS_ENV_DB_MIRRORING_PARTNER:
2968
r.read_ucs2(r.get_byte())
2969
r.read_ucs2(r.get_byte())
2970
elif type == TDS_ENV_LCID:
2971
lcid = int(r.read_ucs2(r.get_byte()))
2972
self.conn.server_codec = codecs.lookup(lcid2charset(lcid))
2973
r.read_ucs2(r.get_byte())
2975
logger.warning("unknown env type: {0}, skipping".format(type))
2976
# discard byte values, not still supported
2977
skipall(r, size - 1)
2979
def process_auth(self):
2980
""" Reads and processes SSPI stream.
2982
Stream info: http://msdn.microsoft.com/en-us/library/dd302844.aspx
2986
pdu_size = r.get_smallint()
2987
if not self.authentication:
2988
raise Error('Got unexpected token')
2989
packet = self.authentication.handle_next(readall(r, pdu_size))
2994
def is_connected(self):
2996
:return: True if transport is connected
2998
return self._transport.is_connected()
3000
def bad_stream(self, msg):
3001
""" Called when input stream contains unexpected data.
3003
Will close stream and raise :class:`InterfaceError`
3004
:param msg: Message for InterfaceError exception.
3005
:return: Never returns, always raises exception.
3008
raise InterfaceError(msg)
3011
def tds_version(self):
3012
""" Returns integer encoded current TDS protocol version
3014
return self._tds.tds_version
3018
""" Reference to owning :class:`_TdsSocket`
3023
self._transport.close()
3025
def set_state(self, state):
3026
""" Switches state of the TDS session.
3028
It also does state transitions checks.
3029
:param state: New state, one of TDS_PENDING/TDS_READING/TDS_IDLE/TDS_DEAD/TDS_QUERING
3031
prior_state = self.state
3032
if state == prior_state:
3034
if state == TDS_PENDING:
3035
if prior_state in (TDS_READING, TDS_QUERYING):
3036
self.state = TDS_PENDING
3038
raise InterfaceError('logic error: cannot chage query state from {0} to {1}'.
3039
format(state_names[prior_state], state_names[state]))
3040
elif state == TDS_READING:
3041
# transition to READING are valid only from PENDING
3042
if self.state != TDS_PENDING:
3043
raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
3044
format(state_names[prior_state], state_names[state]))
3047
elif state == TDS_IDLE:
3048
if prior_state == TDS_DEAD:
3049
raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
3050
format(state_names[prior_state], state_names[state]))
3052
elif state == TDS_DEAD:
3054
elif state == TDS_QUERYING:
3055
if self.state == TDS_DEAD:
3056
raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
3057
format(state_names[prior_state], state_names[state]))
3058
elif self.state != TDS_IDLE:
3059
raise InterfaceError('logic error: cannot change query state from {0} to {1}'.
3060
format(state_names[prior_state], state_names[state]))
3062
self.rows_affected = TDS_NO_COUNT
3063
self.internal_sp_called = 0
3070
def querying_context(self, packet_type):
3071
""" Context manager for querying.
3073
Sets state to TDS_QUERYING, and reverts it to TDS_IDLE if exception happens inside managed block,
3074
and to TDS_PENDING if managed block succeeds and flushes buffer.
3076
if self.set_state(TDS_QUERYING) != TDS_QUERYING:
3077
raise Error("Couldn't switch to state")
3078
self._writer.begin_packet(packet_type)
3082
if self.state != TDS_DEAD:
3083
self.set_state(TDS_IDLE)
3086
self.set_state(TDS_PENDING)
3087
self._writer.flush()
3089
def _autodetect_column_type(self, value, value_type):
3090
""" Function guesses type of the parameter from the type of value.
3092
:param value: value to be passed to db, can be None
3093
:param value_type: value type, if value is None, type is used instead of it
3094
:return: An instance of subclass of :class:`BaseType`
3096
if value is None and value_type is None:
3097
return self.conn.NVarChar(1, collation=self.conn.collation)
3098
assert value_type is not None
3099
assert value is None or isinstance(value, value_type)
3101
if issubclass(value_type, bool):
3102
return BitN.instance
3103
elif issubclass(value_type, six.integer_types):
3106
if -2 ** 31 <= value <= 2 ** 31 - 1:
3108
elif -2 ** 63 <= value <= 2 ** 63 - 1:
3110
elif -10 ** 38 + 1 <= value <= 10 ** 38 - 1:
3111
return MsDecimal(0, 38)
3113
raise DataError('Numeric value out of range')
3114
elif issubclass(value_type, float):
3116
elif issubclass(value_type, Binary):
3117
return self.conn.long_binary_type()
3118
elif issubclass(value_type, six.binary_type):
3119
if self._tds.login.bytes_to_unicode:
3120
return self.conn.long_string_type(collation=self.conn.collation)
3122
return self.conn.long_varchar_type(collation=self.conn.collation)
3123
elif issubclass(value_type, six.string_types):
3124
return self.conn.long_string_type(collation=self.conn.collation)
3125
elif issubclass(value_type, datetime):
3126
if IS_TDS73_PLUS(self):
3127
if value != None and value.tzinfo and not self.use_tz:
3128
return DateTimeOffset()
3133
elif issubclass(value_type, date):
3134
if IS_TDS73_PLUS(self):
3135
return MsDate.instance
3138
elif issubclass(value_type, time):
3139
if not IS_TDS73_PLUS(self):
3140
raise DataError('Time type is not supported on MSSQL 2005 and lower')
3142
elif issubclass(value_type, Decimal):
3144
return MsDecimal.from_value(value)
3147
elif issubclass(value_type, uuid.UUID):
3148
return MsUnique.instance
3150
raise DataError('Parameter type is not supported: {!r} {!r}'.format(value, value_type))
3152
def make_param(self, name, value):
3153
""" Generates instance of :class:`Column` from value and name
3155
Value can also be of a special types:
3157
- An instance of :class:`Column`, in which case it is just returned.
3158
- An instance of :class:`output`, in which case parameter will become
3159
an output parameter.
3160
- A singleton :var:`default`, in which case default value will be passed
3163
:param name: Name of the parameter, will populate column_name property of returned column.
3164
:param value: Value of the parameter, also used to guess the type of parameter.
3165
:return: An instance of :class:`Column`
3167
if isinstance(value, Column):
3168
value.column_name = name
3171
column.column_name = name
3174
if isinstance(value, output):
3175
column.flags |= fByRefValue
3176
if isinstance(value.type, six.string_types):
3177
column.type = self._tds.type_by_declaration(value.type, True)
3178
value_type = value.type or type(value.value)
3181
value_type = type(value)
3183
if value_type is type(None):
3186
if value is default:
3187
column.flags |= fDefaultValue
3189
if value_type is _Default:
3192
column.value = value
3193
if column.type is None:
3194
column.type = self._autodetect_column_type(value, value_type)
3197
def _convert_params(self, parameters):
3198
""" Converts a dict of list of parameters into a list of :class:`Column` instances.
3200
:param parameters: Can be a list of parameter values, or a dict of parameter names to values.
3201
:return: A list of :class:`Column` instances.
3203
if isinstance(parameters, dict):
3204
return [self.make_param(name, value)
3205
for name, value in parameters.items()]
3208
for parameter in parameters:
3209
params.append(self.make_param('', parameter))
3212
def cancel_if_pending(self):
3213
""" Cancels current pending request.
3215
Does nothing if no request is pending, otherwise sends cancel request,
3216
and waits for response.
3218
if self.state == TDS_IDLE:
3220
if not self.in_cancel:
3222
self.process_cancel()
3224
def submit_rpc(self, rpc_name, params, flags):
3225
""" Sends an RPC request.
3227
This call will transition session into pending state.
3228
If some operation is currently pending on the session, it will be
3229
cancelled before sending this request.
3231
Spec: http://msdn.microsoft.com/en-us/library/dd357576.aspx
3233
:param rpc_name: Name of the RPC to call, can be an instance of :class:`InternalProc`
3234
:param params: Stored proc parameters, should be a list of :class:`Column` instances.
3235
:param flags: See spec for possible flags.
3238
self.output_params = {}
3239
self.cancel_if_pending()
3240
self.res_info = None
3242
with self.querying_context(TDS_RPC):
3244
if IS_TDS71_PLUS(self) and isinstance(rpc_name, InternalProc):
3246
w.put_smallint(rpc_name.proc_id)
3248
if isinstance(rpc_name, InternalProc):
3249
rpc_name = rpc_name.name
3250
w.put_smallint(len(rpc_name))
3251
w.write_ucs2(rpc_name)
3253
# TODO support flags
3254
# bit 0 (1 as flag) in TDS7/TDS5 is "recompile"
3255
# bit 1 (2 as flag) in TDS7+ is "no metadata" bit this will prevent sending of column infos
3257
w.put_usmallint(flags)
3258
self._out_params_indexes = []
3259
for i, param in enumerate(params):
3260
if param.flags & fByRefValue:
3261
self._out_params_indexes.append(i)
3262
w.put_byte(len(param.column_name))
3263
w.write_ucs2(param.column_name)
3265
# TODO support other flags (use defaul null/no metadata)
3266
# bit 1 (2 as flag) in TDS7+ is "default value" bit
3267
# (what's the meaning of "default value" ?)
3269
w.put_byte(param.flags)
3270
# FIXME: column_type is wider than one byte. Do something sensible, not just lop off the high byte.
3271
w.put_byte(param.type.type)
3272
param.type.write_info(w)
3273
param.type.write(w, param.value)
3275
def submit_plain_query(self, operation):
3276
""" Sends a plain query to server.
3278
This call will transition session into pending state.
3279
If some operation is currently pending on the session, it will be
3280
cancelled before sending this request.
3282
Spec: http://msdn.microsoft.com/en-us/library/dd358575.aspx
3284
:param operation: A string representing sql statement.
3286
#logger.debug('submit_plain_query(%s)', operation)
3288
self.cancel_if_pending()
3289
self.res_info = None
3291
with self.querying_context(TDS_QUERY):
3293
w.write_ucs2(operation)
3295
def submit_bulk(self, metadata, rows):
3296
""" Sends insert bulk command.
3298
Spec: http://msdn.microsoft.com/en-us/library/dd358082.aspx
3300
:param metadata: A list of :class:`Column` instances.
3301
:param rows: A collection of rows, each row is a collection of values.
3304
num_cols = len(metadata)
3306
with self.querying_context(TDS_BULK):
3307
w.put_byte(TDS7_RESULT_TOKEN)
3308
w.put_usmallint(num_cols)
3309
for col in metadata:
3310
if IS_TDS72_PLUS(self):
3311
w.put_uint(col.column_usertype)
3313
w.put_usmallint(col.column_usertype)
3314
w.put_usmallint(col.flags)
3315
w.put_byte(col.type.type)
3316
col.type.write_info(w)
3317
w.put_byte(len(col.column_name))
3318
w.write_ucs2(col.column_name)
3320
w.put_byte(TDS_ROW_TOKEN)
3321
for i, col in enumerate(metadata):
3322
col.type.write(w, row[i])
3324
w.put_byte(TDS_DONE_TOKEN)
3325
w.put_usmallint(TDS_DONE_FINAL)
3326
w.put_usmallint(0) # curcmd
3327
if IS_TDS72_PLUS(self):
3332
def _put_cancel(self):
3333
""" Sends a cancel request to the server.
3335
Switches connection to IN_CANCEL state.
3337
self._writer.begin_packet(TDS_CANCEL)
3338
self._writer.flush()
3341
_begin_tran_struct_72 = struct.Struct('<HBB')
3343
def begin_tran(self, isolation_level=0):
3344
self.submit_begin_tran(isolation_level=isolation_level)
3345
self.process_simple_request()
3347
def submit_begin_tran(self, isolation_level=0):
3348
#logger.debug('submit_begin_tran()')
3349
if IS_TDS72_PLUS(self):
3351
self.cancel_if_pending()
3353
with self.querying_context(TDS7_TRANS):
3355
w.pack(self._begin_tran_struct_72,
3358
0, # new transaction name
3361
self.submit_plain_query("BEGIN TRANSACTION")
3362
self.conn.tds72_transaction = 1
3364
_commit_rollback_tran_struct72_hdr = struct.Struct('<HBB')
3365
_continue_tran_struct72 = struct.Struct('<BB')
3367
def rollback(self, cont, isolation_level=0):
3368
self.submit_rollback(cont, isolation_level=isolation_level)
3369
prev_timeout = self._tds._sock.gettimeout()
3370
self._tds._sock.settimeout(None)
3372
self.process_simple_request()
3374
self._tds._sock.settimeout(prev_timeout)
3376
def submit_rollback(self, cont, isolation_level=0):
3377
#logger.debug('submit_rollback(%s, %s)', id(self), cont)
3378
if IS_TDS72_PLUS(self):
3380
self.cancel_if_pending()
3382
with self.querying_context(TDS7_TRANS):
3387
w.pack(self._commit_rollback_tran_struct72_hdr,
3388
8, # TM_ROLLBACK_XACT
3389
0, # transaction name
3393
w.pack(self._continue_tran_struct72,
3395
0, # new transaction name
3398
self.submit_plain_query("IF @@TRANCOUNT > 0 ROLLBACK BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 ROLLBACK")
3399
self.conn.tds72_transaction = 1 if cont else 0
3401
def commit(self, cont, isolation_level=0):
3402
self.submit_commit(cont, isolation_level=isolation_level)
3403
prev_timeout = self._tds._sock.gettimeout()
3404
self._tds._sock.settimeout(None)
3406
self.process_simple_request()
3408
self._tds._sock.settimeout(prev_timeout)
3410
def submit_commit(self, cont, isolation_level=0):
3411
#logger.debug('submit_commit(%s)', cont)
3412
if IS_TDS72_PLUS(self):
3414
self.cancel_if_pending()
3416
with self.querying_context(TDS7_TRANS):
3421
w.pack(self._commit_rollback_tran_struct72_hdr,
3423
0, # transaction name
3427
w.pack(self._continue_tran_struct72,
3429
0, # new transaction name
3432
self.submit_plain_query("IF @@TRANCOUNT > 0 COMMIT BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 COMMIT")
3433
self.conn.tds72_transaction = 1 if cont else 0
3435
def _START_QUERY(self):
3436
if IS_TDS72_PLUS(self):
3439
_tds72_query_start = struct.Struct('<IIHQI')
3441
def _start_query(self):
3443
w.pack(_TdsSession._tds72_query_start,
3444
0x16, # total length
3447
self.conn.tds72_transaction,
3459
def _send_prelogin(self, login):
3460
instance_name = login.instance_name or 'MSSQLServer'
3461
instance_name = instance_name.encode('ascii')
3462
encryption_level = login.encryption_level
3463
if len(instance_name) > 65490:
3464
raise ValueError('Instance name is too long')
3465
if encryption_level >= TDS_ENCRYPTION_REQUIRE:
3466
raise NotSupportedError('Client requested encryption but it is not supported')
3467
if IS_TDS72_PLUS(self):
3470
b'>BHHBHHBHHBHHBHHB',
3472
self.VERSION, START_POS, 6,
3474
self.ENCRYPTION, START_POS + 6, 1,
3476
self.INSTOPT, START_POS + 6 + 1, len(instance_name) + 1,
3478
self.THREADID, START_POS + 6 + 1 + len(instance_name) + 1, 4,
3480
self.MARS, START_POS + 6 + 1 + len(instance_name) + 1 + 4, 1,
3489
self.VERSION, START_POS, 6,
3491
self.ENCRYPTION, START_POS + 6, 1,
3493
self.INSTOPT, START_POS + 6 + 1, len(instance_name) + 1,
3495
self.THREADID, START_POS + 6 + 1 + len(instance_name) + 1, 4,
3499
assert START_POS == len(buf)
3501
w.begin_packet(TDS71_PRELOGIN)
3503
from . import intversion
3504
w.put_uint_be(intversion)
3505
w.put_usmallint_be(0) # build number
3507
if ENCRYPTION_ENABLED and encryption_supported:
3508
w.put_byte(1 if encryption_level >= TDS_ENCRYPTION_REQUIRE else 0)
3512
w.write(instance_name)
3513
w.put_byte(0) # zero terminate instance_name
3514
w.put_int(0) # TODO: change this to thread id
3515
if IS_TDS72_PLUS(self):
3517
w.put_byte(1 if login.use_mars else 0)
3520
def _process_prelogin(self, login):
3521
p = self._reader.read_whole_packet()
3523
if size <= 0 or self._reader.packet_type != 4:
3524
self.bad_stream('Invalid packet type: {0}, expected PRELOGIN(4)'.format(self._reader.packet_type))
3525
# default 2, no certificate, no encryptption
3528
byte_struct = struct.Struct('B')
3529
off_len_struct = struct.Struct('>HH')
3530
prod_version_struct = struct.Struct('>LH')
3533
self.bad_stream('Invalid size of PRELOGIN structure')
3534
type, = byte_struct.unpack_from(p, i)
3538
self.bad_stream('Invalid size of PRELOGIN structure')
3539
off, l = off_len_struct.unpack_from(p, i + 1)
3540
if off > size or off + l > size:
3541
self.bad_stream('Invalid offset in PRELOGIN structure')
3542
if type == self.VERSION:
3543
self.conn.server_library_version = prod_version_struct.unpack_from(p, off)
3544
elif type == self.ENCRYPTION and l >= 1:
3545
crypt_flag, = byte_struct.unpack_from(p, off)
3546
elif type == self.MARS:
3547
self.conn._mars_enabled = bool(byte_struct.unpack_from(p, off)[0])
3548
elif type == self.INSTOPT:
3549
# ignore instance name mismatch
3552
# if server do not has certificate do normal login
3554
if login.encryption_level >= TDS_ENCRYPTION_REQUIRE:
3555
raise Error('Server required encryption but it is not supported')
3557
self._sock = ssl.wrap_socket(self._sock, ssl_version=ssl.PROTOCOL_SSLv3)
3559
def tds7_send_login(self, login):
3560
option_flag2 = login.option_flag2
3561
user_name = login.user_name
3562
if len(user_name) > 128:
3563
raise ValueError('User name should be no longer that 128 characters')
3564
if len(login.password) > 128:
3565
raise ValueError('Password should be not longer than 128 characters')
3566
if len(login.change_password) > 128:
3567
raise ValueError('Password should be not longer than 128 characters')
3568
if len(login.client_host_name) > 128:
3569
raise ValueError('Host name should be not longer than 128 characters')
3570
if len(login.app_name) > 128:
3571
raise ValueError('App name should be not longer than 128 characters')
3572
if len(login.server_name) > 128:
3573
raise ValueError('Server name should be not longer than 128 characters')
3574
if len(login.database) > 128:
3575
raise ValueError('Database name should be not longer than 128 characters')
3576
if len(login.language) > 128:
3577
raise ValueError('Language should be not longer than 128 characters')
3578
if len(login.attach_db_file) > 260:
3579
raise ValueError('File path should be not longer than 260 characters')
3581
w.begin_packet(TDS7_LOGIN)
3582
self.authentication = None
3583
current_pos = 86 + 8 if IS_TDS72_PLUS(self) else 86
3584
client_host_name = login.client_host_name
3585
login.client_host_name = client_host_name
3586
packet_size = current_pos + (len(client_host_name) + len(login.app_name) + len(login.server_name) + len(login.library) + len(login.language) + len(login.database)) * 2
3588
self.authentication = login.auth
3589
auth_packet = login.auth.create_packet()
3590
packet_size += len(auth_packet)
3593
packet_size += (len(user_name) + len(login.password)) * 2
3594
w.put_int(packet_size)
3595
w.put_uint(login.tds_version)
3596
w.put_int(w.bufsize)
3597
from . import intversion
3598
w.put_uint(intversion)
3599
w.put_int(login.pid)
3600
w.put_uint(0) # connection id
3601
option_flag1 = TDS_SET_LANG_ON | TDS_USE_DB_NOTIFY | TDS_INIT_DB_FATAL
3602
if not login.bulk_copy:
3603
option_flag1 |= TDS_DUMPLOAD_OFF
3604
w.put_byte(option_flag1)
3605
if self.authentication:
3606
option_flag2 |= TDS_INTEGRATED_SECURITY_ON
3607
w.put_byte(option_flag2)
3610
type_flags |= (2 << 5)
3611
w.put_byte(type_flags)
3612
option_flag3 = TDS_UNKNOWN_COLLATION_HANDLING
3613
w.put_byte(option_flag3 if IS_TDS73_PLUS(self) else 0)
3614
mins_fix = int(total_seconds(login.client_tz.utcoffset(datetime.now()))) // 60
3616
w.put_int(login.client_lcid)
3617
w.put_smallint(current_pos)
3618
w.put_smallint(len(client_host_name))
3619
current_pos += len(client_host_name) * 2
3620
if self.authentication:
3626
w.put_smallint(current_pos)
3627
w.put_smallint(len(user_name))
3628
current_pos += len(user_name) * 2
3629
w.put_smallint(current_pos)
3630
w.put_smallint(len(login.password))
3631
current_pos += len(login.password) * 2
3632
w.put_smallint(current_pos)
3633
w.put_smallint(len(login.app_name))
3634
current_pos += len(login.app_name) * 2
3636
w.put_smallint(current_pos)
3637
w.put_smallint(len(login.server_name))
3638
current_pos += len(login.server_name) * 2
3643
w.put_smallint(current_pos)
3644
w.put_smallint(len(login.library))
3645
current_pos += len(login.library) * 2
3647
w.put_smallint(current_pos)
3648
w.put_smallint(len(login.language))
3649
current_pos += len(login.language) * 2
3651
w.put_smallint(current_pos)
3652
w.put_smallint(len(login.database))
3653
current_pos += len(login.database) * 2
3655
client_id = struct.pack('>Q', login.client_id)[2:]
3658
w.put_smallint(current_pos)
3659
w.put_smallint(len(auth_packet))
3660
current_pos += len(auth_packet)
3662
w.put_smallint(current_pos)
3663
w.put_smallint(len(login.attach_db_file))
3664
current_pos += len(login.attach_db_file) * 2
3665
if IS_TDS72_PLUS(self):
3667
w.put_smallint(current_pos)
3668
w.put_smallint(len(login.change_password))
3671
w.write_ucs2(client_host_name)
3672
if not self.authentication:
3673
w.write_ucs2(user_name)
3674
w.write(tds7_crypt_pass(login.password))
3675
w.write_ucs2(login.app_name)
3676
w.write_ucs2(login.server_name)
3677
w.write_ucs2(login.library)
3678
w.write_ucs2(login.language)
3679
w.write_ucs2(login.database)
3680
if self.authentication:
3681
w.write(auth_packet)
3682
w.write_ucs2(login.attach_db_file)
3683
w.write_ucs2(login.change_password)
3686
_SERVER_TO_CLIENT_MAPPING = {
3689
0x71000001: TDS71rev1,
3696
def process_login_tokens(self):
3699
#logger.debug('process_login_tokens()')
3701
marker = r.get_byte()
3702
#logger.debug('looking for login token, got {0:x}({1})'.format(marker, tds_token_name(marker)))
3703
if marker == TDS_LOGINACK_TOKEN:
3705
size = r.get_smallint()
3706
r.get_byte() # interface
3707
version = r.get_uint_be()
3708
self.conn.tds_version = self._SERVER_TO_CLIENT_MAPPING.get(version, version)
3709
#logger.debug('server reports TDS version {0:x}'.format(version))
3710
if not IS_TDS7_PLUS(self):
3711
self.bad_stream('Only TDS 7.0 and higher are supported')
3712
# get server product name
3713
# ignore product name length, some servers seem to set it incorrectly
3716
self.conn.product_name = r.read_ucs2(size // 2)
3717
product_version = r.get_uint_be()
3718
# MSSQL 6.5 and 7.0 seem to return strange values for this
3719
# using TDS 4.2, something like 5F 06 32 FF for 6.50
3720
self.conn.product_version = product_version
3721
#logger.debug('Product version {0:x}'.format(product_version))
3722
if self.conn.authentication:
3723
self.conn.authentication.close()
3724
self.conn.authentication = None
3726
self.process_token(marker)
3727
if marker == TDS_DONE_TOKEN:
3731
def process_returnstatus(self):
3732
self.ret_status = self._reader.get_int()
3733
self.has_status = True
3735
def process_token(self, marker):
3736
handler = _token_map.get(marker)
3738
self.bad_stream('Invalid TDS marker: {0}({0:x})'.format(marker))
3739
return handler(self)
3741
def get_token_id(self):
3742
self.set_state(TDS_READING)
3744
marker = self._reader.get_byte()
3745
except TimeoutError:
3746
self.set_state(TDS_PENDING)
3753
def process_simple_request(self):
3755
marker = self.get_token_id()
3756
if marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN, TDS_DONEINPROC_TOKEN):
3757
self.process_end(marker)
3758
if self.done_flags & TDS_DONE_MORE_RESULTS:
3759
# skip results that don't event have rowcount
3763
self.process_token(marker)
3766
while self.more_rows:
3768
if self.state == TDS_IDLE:
3770
if self.find_result_or_done():
3774
if self.res_info is None:
3775
raise Error("Previous statement didn't produce any results")
3777
if self.skipped_to_status:
3778
raise Error("Unable to fetch any rows after accessing return_status")
3780
if not self.next_row():
3786
if not self.more_rows:
3789
marker = self.get_token_id()
3790
if marker in (TDS_ROW_TOKEN, TDS_NBC_ROW_TOKEN):
3791
self.process_token(marker)
3793
elif marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN, TDS_DONEINPROC_TOKEN):
3794
self.process_end(marker)
3797
self.process_token(marker)
3799
def find_result_or_done(self):
3802
marker = self.get_token_id()
3803
if marker == TDS7_RESULT_TOKEN:
3804
self.process_token(marker)
3806
elif marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN, TDS_DONEINPROC_TOKEN):
3807
self.process_end(marker)
3808
if self.done_flags & TDS_DONE_MORE_RESULTS:
3809
if self.done_flags & TDS_DONE_COUNT:
3812
# skip results without rowcount
3817
self.process_token(marker)
3819
def process_rpc(self):
3821
self.return_value_index = 0
3823
marker = self.get_token_id()
3824
if marker == TDS7_RESULT_TOKEN:
3825
self.process_token(marker)
3827
elif marker in (TDS_DONE_TOKEN, TDS_DONEPROC_TOKEN):
3828
self.process_end(marker)
3829
if self.done_flags & TDS_DONE_MORE_RESULTS and not self.done_flags & TDS_DONE_COUNT:
3830
# skip results that don't event have rowcount
3834
self.process_token(marker)
3836
def find_return_status(self):
3837
self.skipped_to_status = True
3839
marker = self.get_token_id()
3840
self.process_token(marker)
3841
if marker == TDS_RETURNSTATUS_TOKEN:
3846
TDS_AUTH_TOKEN: _TdsSession.process_auth,
3847
TDS_ENVCHANGE_TOKEN: _TdsSession.process_env_chg,
3848
TDS_DONE_TOKEN: lambda self: self.process_end(TDS_DONE_TOKEN),
3849
TDS_DONEPROC_TOKEN: lambda self: self.process_end(TDS_DONEPROC_TOKEN),
3850
TDS_DONEINPROC_TOKEN: lambda self: self.process_end(TDS_DONEINPROC_TOKEN),
3851
TDS_ERROR_TOKEN: lambda self: self.process_msg(TDS_ERROR_TOKEN),
3852
TDS_INFO_TOKEN: lambda self: self.process_msg(TDS_INFO_TOKEN),
3853
TDS_EED_TOKEN: lambda self: self.process_msg(TDS_EED_TOKEN),
3854
TDS_CAPABILITY_TOKEN: lambda self: self.process_msg(TDS_CAPABILITY_TOKEN),
3855
TDS_PARAM_TOKEN: lambda self: self.process_param(),
3856
TDS7_RESULT_TOKEN: lambda self: self.tds7_process_result(),
3857
TDS_ROW_TOKEN: lambda self: self.process_row(),
3858
TDS_NBC_ROW_TOKEN: lambda self: self.process_nbcrow(),
3859
TDS_ORDERBY2_TOKEN: lambda self: self.process_orderby2(),
3860
TDS_ORDERBY_TOKEN: lambda self: self.process_orderby(),
3861
TDS_RETURNSTATUS_TOKEN: lambda self: self.process_returnstatus(),
3865
class _TdsSocket(object):
3866
def __init__(self, use_tz=None):
3867
self._is_connected = False
3868
self.env = _TdsEnv()
3869
self.collation = None
3870
self.tds72_transaction = 0
3871
self.authentication = None
3872
self._mars_enabled = False
3873
self.chunk_handler = MemoryChunkedHandler()
3875
self._bufsize = 4096
3876
self.tds_version = TDS74
3877
self.use_tz = use_tz
3880
fmt = "<_TdsSocket tran={} mars={} tds_version={} use_tz={}>"
3881
return fmt.format(self.tds72_transaction, self._mars_enabled,
3882
self.tds_version, self.use_tz)
3884
def login(self, login, sock, tzinfo_factory):
3886
self._bufsize = login.blocksize
3887
self.query_timeout = login.query_timeout
3888
self._main_session = _TdsSession(self, self, tzinfo_factory)
3890
self.tds_version = login.tds_version
3891
if IS_TDS71_PLUS(self):
3892
self._main_session._send_prelogin(login)
3893
self._main_session._process_prelogin(login)
3894
if IS_TDS7_PLUS(self):
3895
self._main_session.tds7_send_login(login)
3897
raise ValueError('This TDS version is not supported')
3898
if not self._main_session.process_login_tokens():
3899
self._main_session.raise_db_exception()
3900
if IS_TDS72_PLUS(self):
3901
self._type_map = _type_map72
3902
elif IS_TDS71_PLUS(self):
3903
self._type_map = _type_map71
3905
self._type_map = _type_map
3906
text_size = login.text_size
3907
if self._mars_enabled:
3908
from .smp import SmpManager
3909
self._smp_manager = SmpManager(self)
3910
self._main_session = _TdsSession(
3912
self._smp_manager.create_session(),
3914
self._is_connected = True
3917
q.append('set textsize {0}'.format(int(text_size)))
3918
if login.database and self.env.database != login.database:
3919
q.append('use ' + tds_quote_id(login.database))
3921
self._main_session.submit_plain_query(''.join(q))
3922
self._main_session.process_simple_request()
3925
def mars_enabled(self):
3926
return self._mars_enabled
3929
def main_session(self):
3930
return self._main_session
3932
def create_session(self, tzinfo_factory):
3934
self, self._smp_manager.create_session(),
3937
def read(self, size):
3938
buf = self._sock.recv(size)
3941
raise ClosedConnectionError()
3944
def _write(self, data, final):
3947
if hasattr(socket, 'MSG_NOSIGNAL'):
3948
flags |= socket.MSG_NOSIGNAL
3950
if hasattr(socket, 'MSG_MORE'):
3951
flags |= socket.MSG_MORE
3952
self._sock.sendall(data, flags)
3953
if final and USE_CORK:
3954
self._sock.setsockopt(socket.SOL_TCP, socket.TCP_CORK, 0)
3955
self._sock.setsockopt(socket.SOL_TCP, socket.TCP_CORK, 1)
3962
def is_connected(self):
3963
return self._is_connected
3966
self._is_connected = False
3967
if self._sock is not None:
3969
if hasattr(self, '_smp_manager'):
3970
self._smp_manager._transport_closed()
3971
self._main_session.state = TDS_DEAD
3972
if self.authentication:
3973
self.authentication.close()
3974
self.authentication = None
3976
def NVarChar(self, size, collation=raw_collation):
3977
if IS_TDS72_PLUS(self):
3978
return NVarChar72(size, collation)
3979
elif IS_TDS71_PLUS(self):
3980
return NVarChar71(size, collation)
3982
return NVarChar70(size)
3984
def VarChar(self, size, collation=raw_collation):
3985
if IS_TDS72_PLUS(self):
3986
return VarChar72(size, collation)
3987
elif IS_TDS71_PLUS(self):
3988
return VarChar71(size, collation)
3990
return VarChar70(size, codec=self.server_codec)
3992
def Text(self, size=0, collation=raw_collation):
3993
if IS_TDS72_PLUS(self):
3994
return Text72(size, collation=collation)
3995
elif IS_TDS71_PLUS(self):
3996
return Text71(size, collation=collation)
3998
return Text70(size, codec=self.server_codec)
4000
def NText(self, size=0, collation=raw_collation):
4001
if IS_TDS72_PLUS(self):
4002
return NText72(size, collation=collation)
4003
elif IS_TDS71_PLUS(self):
4004
return NText71(size, collation=collation)
4006
return NText70(size)
4008
def VarBinary(self, size):
4009
if IS_TDS72_PLUS(self):
4010
return VarBinary72(size)
4012
return VarBinary(size)
4014
def Image(self, size=0):
4015
if IS_TDS72_PLUS(self):
4016
return Image72(size)
4018
return Image70(size)
4021
BitN = BitN.instance
4022
TinyInt = TinyInt.instance
4023
SmallInt = SmallInt.instance
4025
BigInt = BigInt.instance
4027
Real = Real.instance
4028
Float = Float.instance
4030
SmallDateTime = SmallDateTime.instance
4031
DateTime = DateTime.instance
4032
DateTimeN = DateTimeN
4033
Date = MsDate.instance
4035
DateTime2 = DateTime2
4036
DateTimeOffset = DateTimeOffset
4038
SmallMoney = Money4.instance
4039
Money = Money8.instance
4041
UniqueIdentifier = MsUnique.instance
4042
SqlVariant = Variant
4045
def long_binary_type(self):
4046
if IS_TDS72_PLUS(self):
4047
return VarBinaryMax()
4051
def long_varchar_type(self, collation=raw_collation):
4052
if IS_TDS72_PLUS(self):
4053
return VarCharMax(collation)
4054
elif IS_TDS71_PLUS(self):
4055
return Text71(-1, '', collation)
4057
return Text70(codec=self.server_codec)
4059
def long_string_type(self, collation=raw_collation):
4060
if IS_TDS72_PLUS(self):
4061
return NVarCharMax(0, collation)
4062
elif IS_TDS71_PLUS(self):
4063
return NText71(-1, '', collation)
4067
def type_by_declaration(self, declaration, nullable):
4068
declaration = declaration.strip().upper()
4069
for type_class in self._type_map.values():
4070
type_inst = type_class.from_declaration(declaration, nullable, self)
4073
raise ValueError('Unable to parse type declaration', declaration)
4076
class Column(object):
4083
def __init__(self, name='', type=None, flags=0, value=None):
4084
self.char_codec = None
4085
self.column_name = name
4086
self.column_usertype = 0
4092
return '<Column(name={0}, value={1}, type={2})>'.format(repr(self.column_name), repr(self.value), repr(self.type))
4095
class _Results(object):
4101
def _parse_instances(msg):
4103
if len(msg) > 3 and _ord(msg[0]) == 5:
4104
tokens = msg[3:].decode('ascii').split(';')
4108
for token in tokens:
4110
instdict[name] = token
4117
results[instdict['InstanceName'].upper()] = instdict
4125
# Get port of all instances
4126
# @return default port number or 0 if error
4127
# @remark experimental, cf. MC-SQLR.pdf.
4129
def tds7_get_instances(ip_addr, timeout=5):
4130
s = socket.socket(type=socket.SOCK_DGRAM)
4131
s.settimeout(timeout)
4134
s.sendto(b'\x03', (ip_addr, 1434))
4135
msg = s.recv(16 * 1024 - 1)
4136
# got data, read and parse
4137
return _parse_instances(msg)
4142
def _applytz(dt, tz):
4145
dt = dt.replace(tzinfo=tz)