~certify-web-dev/twisted/certify-trunk

« back to all changes in this revision

Viewing changes to twisted/names/dns.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-17 14:52:35 UTC
  • mfrom: (1.1.5 upstream) (2.1.2 etch)
  • Revision ID: james.westby@ubuntu.com-20070117145235-btmig6qfmqfen0om
Tags: 2.5.0-0ubuntu1
New upstream version, compatible with python2.5.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.names.test.test_dns -*-
 
2
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
 
 
6
"""
 
7
DNS protocol implementation.
 
8
 
 
9
API Stability: Unstable
 
10
 
 
11
Future Plans:
 
12
    - Get rid of some toplevels, maybe.
 
13
    - Put in a better lookupRecordType implementation.
 
14
 
 
15
@author: U{Moshe Zadka<mailto:moshez@twistedmatrix.com>},
 
16
         U{Jp Calderone<mailto:exarkun@twistedmatrix.com>}
 
17
"""
 
18
 
 
19
# System imports
 
20
import warnings
 
21
 
 
22
import struct, random, types, socket
 
23
 
 
24
try:
 
25
    import cStringIO as StringIO
 
26
except ImportError:
 
27
    import StringIO
 
28
 
 
29
AF_INET6 = socket.AF_INET6
 
30
 
 
31
try:
 
32
    from Crypto.Util import randpool
 
33
except ImportError:
 
34
    for randSource in ('urandom',):
 
35
        try:
 
36
            f = file('/dev/' + randSource)
 
37
            f.read(2)
 
38
            f.close()
 
39
        except:
 
40
            pass
 
41
        else:
 
42
            def randomSource(r = file('/dev/' + randSource, 'rb').read):
 
43
                return struct.unpack('H', r(2))[0]
 
44
            break
 
45
    else:
 
46
        warnings.warn(
 
47
            "PyCrypto not available - proceeding with non-cryptographically "
 
48
            "secure random source",
 
49
            RuntimeWarning,
 
50
            1
 
51
        )
 
52
 
 
53
        def randomSource():
 
54
            return random.randint(0, 65535)
 
55
else:
 
56
    def randomSource(r = randpool.RandomPool().get_bytes):
 
57
        return struct.unpack('H', r(2))[0]
 
58
from zope.interface import implements, Interface
 
59
 
 
60
 
 
61
# Twisted imports
 
62
from twisted.internet import protocol, defer
 
63
from twisted.python import log, failure
 
64
from twisted.python import util as tputil
 
65
 
 
66
PORT = 53
 
67
 
 
68
(A, NS, MD, MF, CNAME, SOA, MB, MG, MR, NULL, WKS, PTR, HINFO, MINFO, MX, TXT,
 
69
 RP, AFSDB) = range(1, 19)
 
70
AAAA = 28
 
71
SRV = 33
 
72
A6 = 38
 
73
DNAME = 39
 
74
 
 
75
QUERY_TYPES = {
 
76
    A: 'A',
 
77
    NS: 'NS',
 
78
    MD: 'MD',
 
79
    MF: 'MF',
 
80
    CNAME: 'CNAME',
 
81
    SOA: 'SOA',
 
82
    MB: 'MB',
 
83
    MG: 'MG',
 
84
    MR: 'MR',
 
85
    NULL: 'NULL',
 
86
    WKS: 'WKS',
 
87
    PTR: 'PTR',
 
88
    HINFO: 'HINFO',
 
89
    MINFO: 'MINFO',
 
90
    MX: 'MX',
 
91
    TXT: 'TXT',
 
92
    RP: 'RP',
 
93
    AFSDB: 'AFSDB',
 
94
 
 
95
    # 19 through 27?  Eh, I'll get to 'em.
 
96
 
 
97
    AAAA: 'AAAA',
 
98
    SRV: 'SRV',
 
99
 
 
100
    A6: 'A6',
 
101
    DNAME: 'DNAME'
 
102
}
 
103
 
 
104
IXFR, AXFR, MAILB, MAILA, ALL_RECORDS = range(251, 256)
 
105
 
 
106
# "Extended" queries (Hey, half of these are deprecated, good job)
 
107
EXT_QUERIES = {
 
108
    IXFR: 'IXFR',
 
109
    AXFR: 'AXFR',
 
110
    MAILB: 'MAILB',
 
111
    MAILA: 'MAILA',
 
112
    ALL_RECORDS: 'ALL_RECORDS'
 
113
}
 
114
 
 
115
REV_TYPES = dict([
 
116
    (v, k) for (k, v) in QUERY_TYPES.items() + EXT_QUERIES.items()
 
117
])
 
118
 
 
119
IN, CS, CH, HS = range(1, 5)
 
120
ANY = 255
 
121
 
 
122
QUERY_CLASSES = {
 
123
    IN: 'IN',
 
124
    CS: 'CS',
 
125
    CH: 'CH',
 
126
    HS: 'HS',
 
127
    ANY: 'ANY'
 
128
}
 
129
REV_CLASSES = dict([
 
130
    (v, k) for (k, v) in QUERY_CLASSES.items()
 
131
])
 
132
 
 
133
 
 
134
# Opcodes
 
135
OP_QUERY, OP_INVERSE, OP_STATUS, OP_NOTIFY = range(4)
 
136
 
 
137
# Response Codes
 
138
OK, EFORMAT, ESERVER, ENAME, ENOTIMP, EREFUSED = range(6)
 
139
 
 
140
class IRecord(Interface):
 
141
    """An single entry in a zone of authority.
 
142
 
 
143
    @cvar TYPE: An indicator of what kind of record this is.
 
144
    """
 
145
 
 
146
 
 
147
# Backwards compatibility aliases - these should be deprecated or something I
 
148
# suppose. -exarkun
 
149
from twisted.names.error import DomainError, AuthoritativeDomainError
 
150
from twisted.names.error import DNSQueryTimeoutError
 
151
 
 
152
 
 
153
def str2time(s):
 
154
    suffixes = (
 
155
        ('S', 1), ('M', 60), ('H', 60 * 60), ('D', 60 * 60 * 24),
 
156
        ('W', 60 * 60 * 24 * 7), ('Y', 60 * 60 * 24 * 365)
 
157
    )
 
158
    if isinstance(s, types.StringType):
 
159
        s = s.upper().strip()
 
160
        for (suff, mult) in suffixes:
 
161
            if s.endswith(suff):
 
162
                return int(float(s[:-1]) * mult)
 
163
        try:
 
164
            s = int(s)
 
165
        except ValueError:
 
166
            raise ValueError, "Invalid time interval specifier: " + s
 
167
    return s
 
168
 
 
169
 
 
170
def readPrecisely(file, l):
 
171
    buff = file.read(l)
 
172
    if len(buff) < l:
 
173
        raise EOFError
 
174
    return buff
 
175
 
 
176
 
 
177
class IEncodable(Interface):
 
178
    """
 
179
    Interface for something which can be encoded to and decoded
 
180
    from a file object.
 
181
    """
 
182
    def encode(strio, compDict = None):
 
183
        """
 
184
        Write a representation of this object to the given
 
185
        file object.
 
186
 
 
187
        @type strio: File-like object
 
188
        @param strio: The stream to which to write bytes
 
189
 
 
190
        @type compDict: C{dict} or C{None}
 
191
        @param compDict: A dictionary of backreference addresses that have
 
192
        have already been written to this stream and that may be used for
 
193
        compression.
 
194
        """
 
195
 
 
196
    def decode(strio, length = None):
 
197
        """
 
198
        Reconstruct an object from data read from the given
 
199
        file object.
 
200
 
 
201
        @type strio: File-like object
 
202
        @param strio: The stream from which bytes may be read
 
203
 
 
204
        @type length: C{int} or C{None}
 
205
        @param length: The number of bytes in this RDATA field.  Most
 
206
        implementations can ignore this value.  Only in the case of
 
207
        records similar to TXT where the total length is in no way
 
208
        encoded in the data is it necessary.
 
209
        """
 
210
 
 
211
 
 
212
class Name:
 
213
    implements(IEncodable)
 
214
 
 
215
    def __init__(self, name=''):
 
216
        assert isinstance(name, types.StringTypes), "%r is not a string" % (name,)
 
217
        self.name = name
 
218
 
 
219
    def encode(self, strio, compDict=None):
 
220
        """
 
221
        Encode this Name into the appropriate byte format.
 
222
 
 
223
        @type strio: file
 
224
        @param strio: The byte representation of this Name will be written to
 
225
        this file.
 
226
 
 
227
        @type compDict: dict
 
228
        @param compDict: dictionary of Names that have already been encoded
 
229
        and whose addresses may be backreferenced by this Name (for the purpose
 
230
        of reducing the message size).
 
231
        """
 
232
        name = self.name
 
233
        while name:
 
234
            if compDict is not None:
 
235
                if compDict.has_key(name):
 
236
                    strio.write(
 
237
                        struct.pack("!H", 0xc000 | compDict[name]))
 
238
                    return
 
239
                else:
 
240
                    compDict[name] = strio.tell() + Message.headerSize
 
241
            ind = name.find('.')
 
242
            if ind > 0:
 
243
                label, name = name[:ind], name[ind + 1:]
 
244
            else:
 
245
                label, name = name, ''
 
246
                ind = len(label)
 
247
            strio.write(chr(ind))
 
248
            strio.write(label)
 
249
        strio.write(chr(0))
 
250
 
 
251
 
 
252
    def decode(self, strio, length = None):
 
253
        """
 
254
        Decode a byte string into this Name.
 
255
 
 
256
        @type strio: file
 
257
        @param strio: Bytes will be read from this file until the full Name
 
258
        is decoded.
 
259
 
 
260
        @raise EOFError: Raised when there are not enough bytes available
 
261
        from C{strio}.
 
262
        """
 
263
        self.name = ''
 
264
        off = 0
 
265
        while 1:
 
266
            l = ord(readPrecisely(strio, 1))
 
267
            if l == 0:
 
268
                if off > 0:
 
269
                    strio.seek(off)
 
270
                return
 
271
            if (l >> 6) == 3:
 
272
                new_off = ((l&63) << 8
 
273
                            | ord(readPrecisely(strio, 1)))
 
274
                if off == 0:
 
275
                    off = strio.tell()
 
276
                strio.seek(new_off)
 
277
                continue
 
278
            label = readPrecisely(strio, l)
 
279
            if self.name == '':
 
280
                self.name = label
 
281
            else:
 
282
                self.name = self.name + '.' + label
 
283
 
 
284
    def __eq__(self, other):
 
285
        if isinstance(other, Name):
 
286
            return str(self) == str(other)
 
287
        return 0
 
288
 
 
289
 
 
290
    def __hash__(self):
 
291
        return hash(str(self))
 
292
 
 
293
 
 
294
    def __str__(self):
 
295
        return self.name
 
296
 
 
297
class Query:
 
298
    """
 
299
    Represent a single DNS query.
 
300
 
 
301
    @ivar name: The name about which this query is requesting information.
 
302
    @ivar type: The query type.
 
303
    @ivar cls: The query class.
 
304
    """
 
305
 
 
306
    implements(IEncodable)
 
307
 
 
308
    name = None
 
309
    type = None
 
310
    cls = None
 
311
 
 
312
    def __init__(self, name='', type=A, cls=IN):
 
313
        """
 
314
        @type name: C{str}
 
315
        @param name: The name about which to request information.
 
316
 
 
317
        @type type: C{int}
 
318
        @param type: The query type.
 
319
 
 
320
        @type cls: C{int}
 
321
        @param cls: The query class.
 
322
        """
 
323
        self.name = Name(name)
 
324
        self.type = type
 
325
        self.cls = cls
 
326
 
 
327
 
 
328
    def encode(self, strio, compDict=None):
 
329
        self.name.encode(strio, compDict)
 
330
        strio.write(struct.pack("!HH", self.type, self.cls))
 
331
 
 
332
 
 
333
    def decode(self, strio, length = None):
 
334
        self.name.decode(strio)
 
335
        buff = readPrecisely(strio, 4)
 
336
        self.type, self.cls = struct.unpack("!HH", buff)
 
337
 
 
338
 
 
339
    def __hash__(self):
 
340
        return hash((str(self.name).lower(), self.type, self.cls))
 
341
 
 
342
 
 
343
    def __cmp__(self, other):
 
344
        return isinstance(other, Query) and cmp(
 
345
            (str(self.name).lower(), self.type, self.cls),
 
346
            (str(other.name).lower(), other.type, other.cls)
 
347
        ) or cmp(self.__class__, other.__class__)
 
348
 
 
349
 
 
350
    def __str__(self):
 
351
        t = QUERY_TYPES.get(self.type, EXT_QUERIES.get(self.type, 'UNKNOWN (%d)' % self.type))
 
352
        c = QUERY_CLASSES.get(self.cls, 'UNKNOWN (%d)' % self.cls)
 
353
        return '<Query %s %s %s>' % (self.name, t, c)
 
354
 
 
355
 
 
356
    def __repr__(self):
 
357
        return 'Query(%r, %r, %r)' % (str(self.name), self.type, self.cls)
 
358
 
 
359
 
 
360
class RRHeader:
 
361
    """
 
362
    A resource record header.
 
363
 
 
364
    @cvar fmt: C{str} specifying the byte format of an RR.
 
365
 
 
366
    @ivar name: The name about which this reply contains information.
 
367
    @ivar type: The query type of the original request.
 
368
    @ivar cls: The query class of the original request.
 
369
    @ivar ttl: The time-to-live for this record.
 
370
    @ivar payload: An object that implements the IEncodable interface
 
371
    @ivar auth: Whether this header is authoritative or not.
 
372
    """
 
373
 
 
374
    implements(IEncodable)
 
375
 
 
376
    fmt = "!HHIH"
 
377
 
 
378
    name = None
 
379
    type = None
 
380
    cls = None
 
381
    ttl = None
 
382
    payload = None
 
383
    rdlength = None
 
384
 
 
385
    cachedResponse = None
 
386
 
 
387
    def __init__(self, name='', type=A, cls=IN, ttl=0, payload=None, auth=False):
 
388
        """
 
389
        @type name: C{str}
 
390
        @param name: The name about which this reply contains information.
 
391
 
 
392
        @type type: C{int}
 
393
        @param type: The query type.
 
394
 
 
395
        @type cls: C{int}
 
396
        @param cls: The query class.
 
397
 
 
398
        @type ttl: C{int}
 
399
        @param ttl: Time to live for this record.
 
400
 
 
401
        @type payload: An object implementing C{IEncodable}
 
402
        @param payload: A Query Type specific data object.
 
403
        """
 
404
        assert (payload is None) or (payload.TYPE == type)
 
405
 
 
406
        self.name = Name(name)
 
407
        self.type = type
 
408
        self.cls = cls
 
409
        self.ttl = ttl
 
410
        self.payload = payload
 
411
        self.auth = auth
 
412
 
 
413
 
 
414
    def encode(self, strio, compDict=None):
 
415
        self.name.encode(strio, compDict)
 
416
        strio.write(struct.pack(self.fmt, self.type, self.cls, self.ttl, 0))
 
417
        if self.payload:
 
418
            prefix = strio.tell()
 
419
            self.payload.encode(strio, compDict)
 
420
            aft = strio.tell()
 
421
            strio.seek(prefix - 2, 0)
 
422
            strio.write(struct.pack('!H', aft - prefix))
 
423
            strio.seek(aft, 0)
 
424
 
 
425
 
 
426
    def decode(self, strio, length = None):
 
427
        self.name.decode(strio)
 
428
        l = struct.calcsize(self.fmt)
 
429
        buff = readPrecisely(strio, l)
 
430
        r = struct.unpack(self.fmt, buff)
 
431
        self.type, self.cls, self.ttl, self.rdlength = r
 
432
 
 
433
 
 
434
    def isAuthoritative(self):
 
435
        return self.auth
 
436
 
 
437
 
 
438
    def __str__(self):
 
439
        t = QUERY_TYPES.get(self.type, EXT_QUERIES.get(self.type, 'UNKNOWN (%d)' % self.type))
 
440
        c = QUERY_CLASSES.get(self.cls, 'UNKNOWN (%d)' % self.cls)
 
441
        return '<RR name=%s type=%s class=%s ttl=%ds auth=%s>' % (self.name, t, c, self.ttl, self.auth and 'True' or 'False')
 
442
 
 
443
 
 
444
    __repr__ = __str__
 
445
 
 
446
class SimpleRecord(tputil.FancyStrMixin, tputil.FancyEqMixin):
 
447
    """
 
448
    A Resource Record which consists of a single RFC 1035 domain-name.
 
449
    """
 
450
    TYPE = None
 
451
 
 
452
    implements(IEncodable, IRecord)
 
453
    name = None
 
454
 
 
455
    showAttributes = (('name', 'name', '%s'), 'ttl')
 
456
    compareAttributes = ('name', 'ttl')
 
457
 
 
458
    def __init__(self, name='', ttl=None):
 
459
        self.name = Name(name)
 
460
        self.ttl = str2time(ttl)
 
461
 
 
462
 
 
463
    def encode(self, strio, compDict = None):
 
464
        self.name.encode(strio, compDict)
 
465
 
 
466
 
 
467
    def decode(self, strio, length = None):
 
468
        self.name = Name()
 
469
        self.name.decode(strio)
 
470
 
 
471
 
 
472
    def __hash__(self):
 
473
        return hash(self.name)
 
474
 
 
475
 
 
476
# Kinds of RRs - oh my!
 
477
class Record_NS(SimpleRecord):
 
478
    TYPE = NS
 
479
 
 
480
class Record_MD(SimpleRecord):       # OBSOLETE
 
481
    TYPE = MD
 
482
 
 
483
class Record_MF(SimpleRecord):       # OBSOLETE
 
484
    TYPE = MF
 
485
 
 
486
class Record_CNAME(SimpleRecord):
 
487
    TYPE = CNAME
 
488
 
 
489
class Record_MB(SimpleRecord):       # EXPERIMENTAL
 
490
    TYPE = MB
 
491
 
 
492
class Record_MG(SimpleRecord):       # EXPERIMENTAL
 
493
    TYPE = MG
 
494
 
 
495
class Record_MR(SimpleRecord):       # EXPERIMENTAL
 
496
    TYPE = MR
 
497
 
 
498
class Record_PTR(SimpleRecord):
 
499
    TYPE = PTR
 
500
 
 
501
class Record_DNAME(SimpleRecord):
 
502
    TYPE = DNAME
 
503
 
 
504
class Record_A(tputil.FancyEqMixin):
 
505
    implements(IEncodable, IRecord)
 
506
 
 
507
    TYPE = A
 
508
    address = None
 
509
 
 
510
    compareAttributes = ('address', 'ttl')
 
511
 
 
512
    def __init__(self, address='0.0.0.0', ttl=None):
 
513
        address = socket.inet_aton(address)
 
514
        self.address = address
 
515
        self.ttl = str2time(ttl)
 
516
 
 
517
 
 
518
    def encode(self, strio, compDict = None):
 
519
        strio.write(self.address)
 
520
 
 
521
 
 
522
    def decode(self, strio, length = None):
 
523
        self.address = readPrecisely(strio, 4)
 
524
 
 
525
 
 
526
    def __hash__(self):
 
527
        return hash(self.address)
 
528
 
 
529
 
 
530
    def __str__(self):
 
531
        return '<A %s ttl=%s>' % (self.dottedQuad(), self.ttl)
 
532
 
 
533
 
 
534
    def dottedQuad(self):
 
535
        return socket.inet_ntoa(self.address)
 
536
 
 
537
 
 
538
class Record_SOA(tputil.FancyEqMixin, tputil.FancyStrMixin):
 
539
    implements(IEncodable, IRecord)
 
540
 
 
541
    compareAttributes = ('serial', 'mname', 'rname', 'refresh', 'expire', 'retry', 'ttl')
 
542
    showAttributes = (('mname', 'mname', '%s'), ('rname', 'rname', '%s'), 'serial', 'refresh', 'retry', 'expire', 'minimum', 'ttl')
 
543
 
 
544
    TYPE = SOA
 
545
 
 
546
    def __init__(self, mname='', rname='', serial=0, refresh=0, retry=0, expire=0, minimum=0, ttl=None):
 
547
        self.mname, self.rname = Name(mname), Name(rname)
 
548
        self.serial, self.refresh = str2time(serial), str2time(refresh)
 
549
        self.minimum, self.expire = str2time(minimum), str2time(expire)
 
550
        self.retry = str2time(retry)
 
551
        self.ttl = str2time(ttl)
 
552
 
 
553
 
 
554
    def encode(self, strio, compDict = None):
 
555
        self.mname.encode(strio, compDict)
 
556
        self.rname.encode(strio, compDict)
 
557
        strio.write(
 
558
            struct.pack(
 
559
                '!LlllL',
 
560
                self.serial, self.refresh, self.retry, self.expire,
 
561
                self.minimum
 
562
            )
 
563
        )
 
564
 
 
565
 
 
566
    def decode(self, strio, length = None):
 
567
        self.mname, self.rname = Name(), Name()
 
568
        self.mname.decode(strio)
 
569
        self.rname.decode(strio)
 
570
        r = struct.unpack('!LlllL', readPrecisely(strio, 20))
 
571
        self.serial, self.refresh, self.retry, self.expire, self.minimum = r
 
572
 
 
573
 
 
574
    def __hash__(self):
 
575
        return hash((
 
576
            self.serial, self.mname, self.rname,
 
577
            self.refresh, self.expire, self.retry
 
578
        ))
 
579
 
 
580
 
 
581
class Record_NULL:                   # EXPERIMENTAL
 
582
    implements(IEncodable, IRecord)
 
583
    TYPE = NULL
 
584
 
 
585
    def __init__(self, payload=None, ttl=None):
 
586
        self.payload = payload
 
587
        self.ttl = str2time(ttl)
 
588
 
 
589
 
 
590
    def encode(self, strio, compDict = None):
 
591
        strio.write(self.payload)
 
592
 
 
593
 
 
594
    def decode(self, strio, length = None):
 
595
        self.payload = readPrecisely(strio, length)
 
596
 
 
597
 
 
598
    def __hash__(self):
 
599
        return hash(self.payload)
 
600
 
 
601
 
 
602
class Record_WKS(tputil.FancyEqMixin, tputil.FancyStrMixin):                    # OBSOLETE
 
603
    implements(IEncodable, IRecord)
 
604
    TYPE = WKS
 
605
 
 
606
    compareAttributes = ('address', 'protocol', 'map', 'ttl')
 
607
    showAttributes = ('address', 'protocol', 'ttl')
 
608
 
 
609
    def __init__(self, address='0.0.0.0', protocol=0, map='', ttl=None):
 
610
        self.address = socket.inet_aton(address)
 
611
        self.protocol, self.map = protocol, map
 
612
        self.ttl = str2time(ttl)
 
613
 
 
614
 
 
615
    def encode(self, strio, compDict = None):
 
616
        strio.write(self.address)
 
617
        strio.write(struct.pack('!B', self.protocol))
 
618
        strio.write(self.map)
 
619
 
 
620
 
 
621
    def decode(self, strio, length = None):
 
622
        self.address = readPrecisely(strio, 4)
 
623
        self.protocol = struct.unpack('!B', readPrecisely(strio, 1))[0]
 
624
        self.map = readPrecisely(strio, length - 5)
 
625
 
 
626
 
 
627
    def __hash__(self):
 
628
        return hash((self.address, self.protocol, self.map))
 
629
 
 
630
 
 
631
class Record_AAAA(tputil.FancyEqMixin):               # OBSOLETE (or headed there)
 
632
    implements(IEncodable, IRecord)
 
633
    TYPE = AAAA
 
634
 
 
635
    compareAttributes = ('address', 'ttl')
 
636
 
 
637
    def __init__(self, address = '::', ttl=None):
 
638
        self.address = socket.inet_pton(AF_INET6, address)
 
639
        self.ttl = str2time(ttl)
 
640
 
 
641
 
 
642
    def encode(self, strio, compDict = None):
 
643
        strio.write(self.address)
 
644
 
 
645
 
 
646
    def decode(self, strio, length = None):
 
647
        self.address = readPrecisely(strio, 16)
 
648
 
 
649
 
 
650
    def __hash__(self):
 
651
        return hash(self.address)
 
652
 
 
653
 
 
654
    def __str__(self):
 
655
        return '<AAAA %s ttl=%s>' % (socket.inet_ntop(AF_INET6, self.address), self.ttl)
 
656
 
 
657
 
 
658
class Record_A6:
 
659
    implements(IEncodable, IRecord)
 
660
    TYPE = A6
 
661
 
 
662
    def __init__(self, prefixLen=0, suffix='::', prefix='', ttl=None):
 
663
        self.prefixLen = prefixLen
 
664
        self.suffix = socket.inet_pton(AF_INET6, suffix)
 
665
        self.prefix = Name(prefix)
 
666
        self.bytes = int((128 - self.prefixLen) / 8.0)
 
667
        self.ttl = str2time(ttl)
 
668
 
 
669
 
 
670
    def encode(self, strio, compDict = None):
 
671
        strio.write(struct.pack('!B', self.prefixLen))
 
672
        if self.bytes:
 
673
            strio.write(self.suffix[-self.bytes:])
 
674
        if self.prefixLen:
 
675
            # This may not be compressed
 
676
            self.prefix.encode(strio, None)
 
677
 
 
678
 
 
679
    def decode(self, strio, length = None):
 
680
        self.prefixLen = struct.unpack('!B', readPrecisely(strio, 1))[0]
 
681
        self.bytes = int((128 - self.prefixLen) / 8.0)
 
682
        if self.bytes:
 
683
            self.suffix = '\x00' * (16 - self.bytes) + readPrecisely(strio, self.bytes)
 
684
        if self.prefixLen:
 
685
            self.prefix.decode(strio)
 
686
 
 
687
 
 
688
    def __eq__(self, other):
 
689
        if isinstance(other, Record_A6):
 
690
            return (self.prefixLen == other.prefixLen and
 
691
                    self.suffix[-self.bytes:] == other.suffix[-self.bytes:] and
 
692
                    self.prefix == other.prefix and
 
693
                    self.ttl == other.ttl)
 
694
        return 0
 
695
 
 
696
 
 
697
    def __hash__(self):
 
698
        return hash((self.prefixLen, self.suffix[-self.bytes:], self.prefix))
 
699
 
 
700
 
 
701
    def __str__(self):
 
702
        return '<A6 %s %s (%d) ttl=%s>' % (
 
703
            self.prefix,
 
704
            socket.inet_ntop(AF_INET6, self.suffix),
 
705
            self.prefixLen, self.ttl
 
706
        )
 
707
 
 
708
 
 
709
class Record_SRV(tputil.FancyEqMixin, tputil.FancyStrMixin):                # EXPERIMENTAL
 
710
    implements(IEncodable, IRecord)
 
711
    TYPE = SRV
 
712
 
 
713
    compareAttributes = ('priority', 'weight', 'target', 'port', 'ttl')
 
714
    showAttributes = ('priority', 'weight', ('target', 'target', '%s'), 'port', 'ttl')
 
715
 
 
716
    def __init__(self, priority=0, weight=0, port=0, target='', ttl=None):
 
717
        self.priority = int(priority)
 
718
        self.weight = int(weight)
 
719
        self.port = int(port)
 
720
        self.target = Name(target)
 
721
        self.ttl = str2time(ttl)
 
722
 
 
723
 
 
724
    def encode(self, strio, compDict = None):
 
725
        strio.write(struct.pack('!HHH', self.priority, self.weight, self.port))
 
726
        # This can't be compressed
 
727
        self.target.encode(strio, None)
 
728
 
 
729
 
 
730
    def decode(self, strio, length = None):
 
731
        r = struct.unpack('!HHH', readPrecisely(strio, struct.calcsize('!HHH')))
 
732
        self.priority, self.weight, self.port = r
 
733
        self.target = Name()
 
734
        self.target.decode(strio)
 
735
 
 
736
 
 
737
    def __hash__(self):
 
738
        return hash((self.priority, self.weight, self.port, self.target))
 
739
 
 
740
 
 
741
 
 
742
class Record_AFSDB(tputil.FancyStrMixin, tputil.FancyEqMixin):
 
743
    implements(IEncodable, IRecord)
 
744
    TYPE = AFSDB
 
745
 
 
746
    compareAttributes = ('subtype', 'hostname', 'ttl')
 
747
    showAttributes = ('subtype', ('hostname', 'hostname', '%s'), 'ttl')
 
748
 
 
749
    def __init__(self, subtype=0, hostname='', ttl=None):
 
750
        self.subtype = int(subtype)
 
751
        self.hostname = Name(hostname)
 
752
        self.ttl = str2time(ttl)
 
753
 
 
754
 
 
755
    def encode(self, strio, compDict = None):
 
756
        strio.write(struct.pack('!H', self.subtype))
 
757
        self.hostname.encode(strio, compDict)
 
758
 
 
759
 
 
760
    def decode(self, strio, length = None):
 
761
        r = struct.unpack('!H', readPrecisely(strio, struct.calcsize('!H')))
 
762
        self.subtype, = r
 
763
        self.hostname.decode(strio)
 
764
 
 
765
 
 
766
    def __hash__(self):
 
767
        return hash((self.subtype, self.hostname))
 
768
 
 
769
 
 
770
 
 
771
class Record_RP(tputil.FancyEqMixin, tputil.FancyStrMixin):
 
772
    implements(IEncodable, IRecord)
 
773
    TYPE = RP
 
774
 
 
775
    compareAttributes = ('mbox', 'txt', 'ttl')
 
776
    showAttributes = (('mbox', 'mbox', '%s'), ('txt', 'txt', '%s'), 'ttl')
 
777
 
 
778
    def __init__(self, mbox='', txt='', ttl=None):
 
779
        self.mbox = Name(mbox)
 
780
        self.txt = Name(txt)
 
781
        self.ttl = str2time(ttl)
 
782
 
 
783
 
 
784
    def encode(self, strio, compDict = None):
 
785
        self.mbox.encode(strio, compDict)
 
786
        self.txt.encode(strio, compDict)
 
787
 
 
788
 
 
789
    def decode(self, strio, length = None):
 
790
        self.mbox = Name()
 
791
        self.txt = Name()
 
792
        self.mbox.decode(strio)
 
793
        self.txt.decode(strio)
 
794
 
 
795
 
 
796
    def __hash__(self):
 
797
        return hash((self.mbox, self.txt))
 
798
 
 
799
 
 
800
 
 
801
class Record_HINFO(tputil.FancyStrMixin):
 
802
    implements(IEncodable, IRecord)
 
803
    TYPE = HINFO
 
804
 
 
805
    showAttributes = ('cpu', 'os', 'ttl')
 
806
 
 
807
    def __init__(self, cpu='', os='', ttl=None):
 
808
        self.cpu, self.os = cpu, os
 
809
        self.ttl = str2time(ttl)
 
810
 
 
811
 
 
812
    def encode(self, strio, compDict = None):
 
813
        strio.write(struct.pack('!B', len(self.cpu)) + self.cpu)
 
814
        strio.write(struct.pack('!B', len(self.os)) + self.os)
 
815
 
 
816
 
 
817
    def decode(self, strio, length = None):
 
818
        cpu = struct.unpack('!B', readPrecisely(strio, 1))[0]
 
819
        self.cpu = readPrecisely(strio, cpu)
 
820
        os = struct.unpack('!B', readPrecisely(strio, 1))[0]
 
821
        self.os = readPrecisely(strio, os)
 
822
 
 
823
 
 
824
    def __eq__(self, other):
 
825
        if isinstance(other, Record_HINFO):
 
826
            return (self.os.lower() == other.os.lower() and
 
827
                    self.cpu.lower() == other.cpu.lower() and
 
828
                    self.ttl == other.ttl)
 
829
        return 0
 
830
 
 
831
 
 
832
    def __hash__(self):
 
833
        return hash((self.os.lower(), self.cpu.lower()))
 
834
 
 
835
 
 
836
 
 
837
class Record_MINFO(tputil.FancyEqMixin, tputil.FancyStrMixin):                 # EXPERIMENTAL
 
838
    implements(IEncodable, IRecord)
 
839
    TYPE = MINFO
 
840
 
 
841
    rmailbx = None
 
842
    emailbx = None
 
843
 
 
844
    compareAttributes = ('rmailbx', 'emailbx', 'ttl')
 
845
    showAttributes = (('rmailbx', 'responsibility', '%s'),
 
846
                      ('emailbx', 'errors', '%s'),
 
847
                      'ttl')
 
848
 
 
849
    def __init__(self, rmailbx='', emailbx='', ttl=None):
 
850
        self.rmailbx, self.emailbx = Name(rmailbx), Name(emailbx)
 
851
        self.ttl = str2time(ttl)
 
852
 
 
853
 
 
854
    def encode(self, strio, compDict = None):
 
855
        self.rmailbx.encode(strio, compDict)
 
856
        self.emailbx.encode(strio, compDict)
 
857
 
 
858
 
 
859
    def decode(self, strio, length = None):
 
860
        self.rmailbx, self.emailbx = Name(), Name()
 
861
        self.rmailbx.decode(strio)
 
862
        self.emailbx.decode(strio)
 
863
 
 
864
 
 
865
    def __hash__(self):
 
866
        return hash((self.rmailbx, self.emailbx))
 
867
 
 
868
 
 
869
class Record_MX(tputil.FancyStrMixin, tputil.FancyEqMixin):
 
870
    implements(IEncodable, IRecord)
 
871
    TYPE = MX
 
872
 
 
873
    compareAttributes = ('preference', 'name', 'ttl')
 
874
    showAttributes = ('preference', ('name', 'name', '%s'), 'ttl')
 
875
 
 
876
    def __init__(self, preference=0, name='', ttl=None, **kwargs):
 
877
        self.preference, self.name = int(preference), Name(kwargs.get('exchange', name))
 
878
        self.ttl = str2time(ttl)
 
879
 
 
880
    def encode(self, strio, compDict = None):
 
881
        strio.write(struct.pack('!H', self.preference))
 
882
        self.name.encode(strio, compDict)
 
883
 
 
884
 
 
885
    def decode(self, strio, length = None):
 
886
        self.preference = struct.unpack('!H', readPrecisely(strio, 2))[0]
 
887
        self.name = Name()
 
888
        self.name.decode(strio)
 
889
 
 
890
    def exchange(self):
 
891
        warnings.warn("use Record_MX.name instead", DeprecationWarning, stacklevel=2)
 
892
        return self.name
 
893
 
 
894
    exchange = property(exchange)
 
895
 
 
896
    def __hash__(self):
 
897
        return hash((self.preference, self.name))
 
898
 
 
899
 
 
900
 
 
901
# Oh god, Record_TXT how I hate thee.
 
902
class Record_TXT(tputil.FancyEqMixin, tputil.FancyStrMixin):
 
903
    implements(IEncodable, IRecord)
 
904
 
 
905
    TYPE = TXT
 
906
 
 
907
    showAttributes = compareAttributes = ('data', 'ttl')
 
908
 
 
909
    def __init__(self, *data, **kw):
 
910
        self.data = list(data)
 
911
        # arg man python sucks so bad
 
912
        self.ttl = str2time(kw.get('ttl', None))
 
913
 
 
914
 
 
915
    def encode(self, strio, compDict = None):
 
916
        for d in self.data:
 
917
            strio.write(struct.pack('!B', len(d)) + d)
 
918
 
 
919
 
 
920
    def decode(self, strio, length = None):
 
921
        soFar = 0
 
922
        self.data = []
 
923
        while soFar < length:
 
924
            L = struct.unpack('!B', readPrecisely(strio, 1))[0]
 
925
            self.data.append(readPrecisely(strio, L))
 
926
            soFar += L + 1
 
927
        if soFar != length:
 
928
            log.msg(
 
929
                "Decoded %d bytes in TXT record, but rdlength is %d" % (
 
930
                    soFar, length
 
931
                )
 
932
            )
 
933
 
 
934
 
 
935
    def __hash__(self):
 
936
        return hash(tuple(self.data))
 
937
 
 
938
 
 
939
 
 
940
class Message:
 
941
    headerFmt = "!H2B4H"
 
942
    headerSize = struct.calcsize( headerFmt )
 
943
 
 
944
    # Question, answer, additional, and nameserver lists
 
945
    queries = answers = add = ns = None
 
946
 
 
947
    def __init__(self, id=0, answer=0, opCode=0, recDes=0, recAv=0,
 
948
                       auth=0, rCode=OK, trunc=0, maxSize=512):
 
949
        self.maxSize = maxSize
 
950
        self.id = id
 
951
        self.answer = answer
 
952
        self.opCode = opCode
 
953
        self.auth = auth
 
954
        self.trunc = trunc
 
955
        self.recDes = recDes
 
956
        self.recAv = recAv
 
957
        self.rCode = rCode
 
958
        self.queries = []
 
959
        self.answers = []
 
960
        self.authority = []
 
961
        self.additional = []
 
962
 
 
963
 
 
964
    def addQuery(self, name, type=ALL_RECORDS, cls=IN):
 
965
        """
 
966
        Add another query to this Message.
 
967
 
 
968
        @type name: C{str}
 
969
        @param name: The name to query.
 
970
 
 
971
        @type type: C{int}
 
972
        @param type: Query type
 
973
 
 
974
        @type cls: C{int}
 
975
        @param cls: Query class
 
976
        """
 
977
        self.queries.append(Query(name, type, cls))
 
978
 
 
979
 
 
980
    def encode(self, strio):
 
981
        compDict = {}
 
982
        body_tmp = StringIO.StringIO()
 
983
        for q in self.queries:
 
984
            q.encode(body_tmp, compDict)
 
985
        for q in self.answers:
 
986
            q.encode(body_tmp, compDict)
 
987
        for q in self.authority:
 
988
            q.encode(body_tmp, compDict)
 
989
        for q in self.additional:
 
990
            q.encode(body_tmp, compDict)
 
991
        body = body_tmp.getvalue()
 
992
        size = len(body) + self.headerSize
 
993
        if self.maxSize and size > self.maxSize:
 
994
            self.trunc = 1
 
995
            body = body[:self.maxSize - self.headerSize]
 
996
        byte3 = (( ( self.answer & 1 ) << 7 )
 
997
                 | ((self.opCode & 0xf ) << 3 )
 
998
                 | ((self.auth & 1 ) << 2 )
 
999
                 | ((self.trunc & 1 ) << 1 )
 
1000
                 | ( self.recDes & 1 ) )
 
1001
        byte4 = ( ( (self.recAv & 1 ) << 7 )
 
1002
                  | (self.rCode & 0xf ) )
 
1003
 
 
1004
        strio.write(struct.pack(self.headerFmt, self.id, byte3, byte4,
 
1005
                                len(self.queries), len(self.answers),
 
1006
                                len(self.authority), len(self.additional)))
 
1007
        strio.write(body)
 
1008
 
 
1009
 
 
1010
    def decode(self, strio, length = None):
 
1011
        self.maxSize = 0
 
1012
        header = readPrecisely(strio, self.headerSize)
 
1013
        r = struct.unpack(self.headerFmt, header)
 
1014
        self.id, byte3, byte4, nqueries, nans, nns, nadd = r
 
1015
        self.answer = ( byte3 >> 7 ) & 1
 
1016
        self.opCode = ( byte3 >> 3 ) & 0xf
 
1017
        self.auth = ( byte3 >> 2 ) & 1
 
1018
        self.trunc = ( byte3 >> 1 ) & 1
 
1019
        self.recDes = byte3 & 1
 
1020
        self.recAv = ( byte4 >> 7 ) & 1
 
1021
        self.rCode = byte4 & 0xf
 
1022
 
 
1023
        self.queries = []
 
1024
        for i in range(nqueries):
 
1025
            q = Query()
 
1026
            try:
 
1027
                q.decode(strio)
 
1028
            except EOFError:
 
1029
                return
 
1030
            self.queries.append(q)
 
1031
 
 
1032
        items = ((self.answers, nans), (self.authority, nns), (self.additional, nadd))
 
1033
        for (l, n) in items:
 
1034
            self.parseRecords(l, n, strio)
 
1035
 
 
1036
 
 
1037
    def parseRecords(self, list, num, strio):
 
1038
        for i in range(num):
 
1039
            header = RRHeader()
 
1040
            try:
 
1041
                header.decode(strio)
 
1042
            except EOFError:
 
1043
                return
 
1044
            t = self.lookupRecordType(header.type)
 
1045
            if not t:
 
1046
                continue
 
1047
            header.payload = t(ttl=header.ttl)
 
1048
            try:
 
1049
                header.payload.decode(strio, header.rdlength)
 
1050
            except EOFError:
 
1051
                return
 
1052
            list.append(header)
 
1053
 
 
1054
 
 
1055
    def lookupRecordType(self, type):
 
1056
        return globals().get('Record_' + QUERY_TYPES.get(type, ''), None)
 
1057
 
 
1058
 
 
1059
    def toStr(self):
 
1060
        strio = StringIO.StringIO()
 
1061
        self.encode(strio)
 
1062
        return strio.getvalue()
 
1063
 
 
1064
 
 
1065
    def fromStr(self, str):
 
1066
        strio = StringIO.StringIO(str)
 
1067
        self.decode(strio)
 
1068
 
 
1069
 
 
1070
class DNSDatagramProtocol(protocol.DatagramProtocol):
 
1071
    id = None
 
1072
    liveMessages = None
 
1073
    resends = None
 
1074
 
 
1075
    timeout = 10
 
1076
    reissue = 2
 
1077
 
 
1078
    def __init__(self, controller):
 
1079
        self.controller = controller
 
1080
        self.id = random.randrange(2 ** 10, 2 ** 15)
 
1081
 
 
1082
    def pickID(self):
 
1083
        while 1:
 
1084
            self.id += randomSource() % (2 ** 10)
 
1085
            self.id %= 2 ** 16
 
1086
            if self.id not in self.liveMessages:
 
1087
                break
 
1088
        return self.id
 
1089
 
 
1090
    def stopProtocol(self):
 
1091
        self.liveMessages = {}
 
1092
        self.resends = {}
 
1093
        self.transport = None
 
1094
 
 
1095
    def startProtocol(self):
 
1096
        self.liveMessages = {}
 
1097
        self.resends = {}
 
1098
 
 
1099
    def writeMessage(self, message, address):
 
1100
        self.transport.write(message.toStr(), address)
 
1101
 
 
1102
    def startListening(self):
 
1103
        from twisted.internet import reactor
 
1104
        reactor.listenUDP(0, self, maxPacketSize=512)
 
1105
 
 
1106
    def datagramReceived(self, data, addr):
 
1107
        m = Message()
 
1108
        try:
 
1109
            m.fromStr(data)
 
1110
        except EOFError:
 
1111
            log.msg("Truncated packet (%d bytes) from %s" % (len(data), addr))
 
1112
            return
 
1113
        except:
 
1114
            # Nothing should trigger this, but since we're potentially
 
1115
            # invoking a lot of different decoding methods, we might as well
 
1116
            # be extra cautious.  Anything that triggers this is itself
 
1117
            # buggy.
 
1118
            log.err(failure.Failure(), "Unexpected decoding error")
 
1119
            return
 
1120
 
 
1121
        if m.id in self.liveMessages:
 
1122
            d, canceller = self.liveMessages[m.id]
 
1123
            del self.liveMessages[m.id]
 
1124
            canceller.cancel()
 
1125
            # XXX we shouldn't need this hack of catching exceptioon on callback()
 
1126
            try:
 
1127
                d.callback(m)
 
1128
            except:
 
1129
                log.err()
 
1130
        else:
 
1131
            if m.id not in self.resends:
 
1132
                self.controller.messageReceived(m, self, addr)
 
1133
 
 
1134
 
 
1135
    def removeResend(self, id):
 
1136
        """Mark message ID as no longer having duplication suppression."""
 
1137
        try:
 
1138
            del self.resends[id]
 
1139
        except:
 
1140
            pass
 
1141
 
 
1142
    def query(self, address, queries, timeout = 10, id = None):
 
1143
        """
 
1144
        Send out a message with the given queries.
 
1145
 
 
1146
        @type address: C{tuple} of C{str} and C{int}
 
1147
        @param address: The address to which to send the query
 
1148
 
 
1149
        @type queries: C{list} of C{Query} instances
 
1150
        @param queries: The queries to transmit
 
1151
 
 
1152
        @rtype: C{Deferred}
 
1153
        """
 
1154
        from twisted.internet import reactor
 
1155
 
 
1156
        if not self.transport:
 
1157
            # XXX transport might not get created automatically, use callLater?
 
1158
            self.startListening()
 
1159
 
 
1160
        if id is None:
 
1161
            id = self.pickID()
 
1162
        else:
 
1163
            self.resends[id] = 1
 
1164
        m = Message(id, recDes=1)
 
1165
        m.queries = queries
 
1166
 
 
1167
        resultDeferred = defer.Deferred()
 
1168
        cancelCall = reactor.callLater(timeout, self._clearFailed, resultDeferred, id)
 
1169
        self.liveMessages[id] = (resultDeferred, cancelCall)
 
1170
 
 
1171
        self.writeMessage(m, address)
 
1172
        return resultDeferred
 
1173
 
 
1174
 
 
1175
    def _clearFailed(self, deferred, id):
 
1176
        try:
 
1177
            del self.liveMessages[id]
 
1178
        except:
 
1179
            pass
 
1180
        deferred.errback(failure.Failure(DNSQueryTimeoutError(id)))
 
1181
 
 
1182
 
 
1183
class DNSProtocol(protocol.Protocol):
 
1184
    id = None
 
1185
    liveMessages = None
 
1186
 
 
1187
    length = None
 
1188
    buffer = ''
 
1189
    d = None
 
1190
 
 
1191
 
 
1192
    def __init__(self, controller):
 
1193
        self.controller = controller
 
1194
        self.liveMessages = {}
 
1195
        self.id = random.randrange(2 ** 10, 2 ** 15)
 
1196
 
 
1197
 
 
1198
    def pickID(self):
 
1199
        while 1:
 
1200
            self.id += randomSource() % (2 ** 10)
 
1201
            self.id %= 2 ** 16
 
1202
            if not self.liveMessages.has_key(self.id):
 
1203
                break
 
1204
        return self.id
 
1205
 
 
1206
 
 
1207
    def writeMessage(self, message):
 
1208
        s = message.toStr()
 
1209
        self.transport.write(struct.pack('!H', len(s)) + s)
 
1210
 
 
1211
 
 
1212
    def connectionMade(self):
 
1213
        self.controller.connectionMade(self)
 
1214
 
 
1215
 
 
1216
    def dataReceived(self, data):
 
1217
        self.buffer = self.buffer + data
 
1218
 
 
1219
        while self.buffer:
 
1220
            if self.length is None and len(self.buffer) >= 2:
 
1221
                self.length = struct.unpack('!H', self.buffer[:2])[0]
 
1222
                self.buffer = self.buffer[2:]
 
1223
 
 
1224
            if len(self.buffer) >= self.length:
 
1225
                myChunk = self.buffer[:self.length]
 
1226
                m = Message()
 
1227
                m.fromStr(myChunk)
 
1228
 
 
1229
                try:
 
1230
                    d = self.liveMessages[m.id]
 
1231
                except KeyError:
 
1232
                    self.controller.messageReceived(m, self)
 
1233
                else:
 
1234
                    del self.liveMessages[m.id]
 
1235
                    try:
 
1236
                        d.callback(m)
 
1237
                    except:
 
1238
                        log.err()
 
1239
 
 
1240
                self.buffer = self.buffer[self.length:]
 
1241
                self.length = None
 
1242
            else:
 
1243
                break
 
1244
 
 
1245
 
 
1246
 
 
1247
    def query(self, queries, timeout = None):
 
1248
        """
 
1249
        Send out a message with the given queries.
 
1250
 
 
1251
        @type queries: C{list} of C{Query} instances
 
1252
        @param queries: The queries to transmit
 
1253
 
 
1254
        @rtype: C{Deferred}
 
1255
        """
 
1256
        id = self.pickID()
 
1257
        d = self.liveMessages[id] = defer.Deferred()
 
1258
        if timeout is not None:
 
1259
            d.setTimeout(timeout)
 
1260
        m = Message(id, recDes=1)
 
1261
        m.queries = queries
 
1262
        self.writeMessage(m)
 
1263
        return d