~ntt-pf-lab/nova/monkey_patch_notification

« back to all changes in this revision

Viewing changes to vendor/Twisted-10.0.0/twisted/names/authority.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twisted.names.test.test_names -*-
 
2
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
 
3
# See LICENSE for details.
 
4
 
 
5
 
 
6
from __future__ import nested_scopes
 
7
 
 
8
import os
 
9
import time
 
10
 
 
11
from twisted.names import dns
 
12
from twisted.internet import defer
 
13
from twisted.python import failure
 
14
 
 
15
import common
 
16
 
 
17
def getSerial(filename = '/tmp/twisted-names.serial'):
 
18
    """Return a monotonically increasing (across program runs) integer.
 
19
 
 
20
    State is stored in the given file.  If it does not exist, it is
 
21
    created with rw-/---/--- permissions.
 
22
    """
 
23
    serial = time.strftime('%Y%m%d')
 
24
 
 
25
    o = os.umask(0177)
 
26
    try:
 
27
        if not os.path.exists(filename):
 
28
            f = file(filename, 'w')
 
29
            f.write(serial + ' 0')
 
30
            f.close()
 
31
    finally:
 
32
        os.umask(o)
 
33
 
 
34
    serialFile = file(filename, 'r')
 
35
    lastSerial, ID = serialFile.readline().split()
 
36
    ID = (lastSerial == serial) and (int(ID) + 1) or 0
 
37
    serialFile.close()
 
38
    serialFile = file(filename, 'w')
 
39
    serialFile.write('%s %d' % (serial, ID))
 
40
    serialFile.close()
 
41
    serial = serial + ('%02d' % (ID,))
 
42
    return serial
 
43
 
 
44
 
 
45
#class LookupCacherMixin(object):
 
46
#    _cache = None
 
47
#
 
48
#    def _lookup(self, name, cls, type, timeout = 10):
 
49
#        if not self._cache:
 
50
#            self._cache = {}
 
51
#            self._meth = super(LookupCacherMixin, self)._lookup
 
52
#
 
53
#        if self._cache.has_key((name, cls, type)):
 
54
#            return self._cache[(name, cls, type)]
 
55
#        else:
 
56
#            r = self._meth(name, cls, type, timeout)
 
57
#            self._cache[(name, cls, type)] = r
 
58
#            return r
 
59
 
 
60
 
 
61
class FileAuthority(common.ResolverBase):
 
62
    """An Authority that is loaded from a file."""
 
63
 
 
64
    soa = None
 
65
    records = None
 
66
 
 
67
    def __init__(self, filename):
 
68
        common.ResolverBase.__init__(self)
 
69
        self.loadFile(filename)
 
70
        self._cache = {}
 
71
 
 
72
 
 
73
    def __setstate__(self, state):
 
74
        self.__dict__ = state
 
75
#        print 'setstate ', self.soa
 
76
 
 
77
    def _lookup(self, name, cls, type, timeout = None):
 
78
        cnames = []
 
79
        results = []
 
80
        authority = []
 
81
        additional = []
 
82
        default_ttl = max(self.soa[1].minimum, self.soa[1].expire)
 
83
 
 
84
        domain_records = self.records.get(name.lower())
 
85
 
 
86
        if domain_records:
 
87
            for record in domain_records:
 
88
                if record.ttl is not None:
 
89
                    ttl = record.ttl
 
90
                else:
 
91
                    ttl = default_ttl
 
92
 
 
93
                if record.TYPE == type or type == dns.ALL_RECORDS:
 
94
                    results.append(
 
95
                        dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
 
96
                    )
 
97
                elif record.TYPE == dns.NS and type != dns.ALL_RECORDS:
 
98
                    authority.append(
 
99
                        dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
 
100
                    )
 
101
                if record.TYPE == dns.CNAME:
 
102
                    cnames.append(
 
103
                        dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
 
104
                    )
 
105
            if not results:
 
106
                results = cnames
 
107
 
 
108
            for record in results + authority:
 
109
                section = {dns.NS: additional, dns.CNAME: results, dns.MX: additional}.get(record.type)
 
110
                if section is not None:
 
111
                    n = str(record.payload.name)
 
112
                    for rec in self.records.get(n.lower(), ()):
 
113
                        if rec.TYPE == dns.A:
 
114
                            section.append(
 
115
                                dns.RRHeader(n, dns.A, dns.IN, rec.ttl or default_ttl, rec, auth=True)
 
116
                            )
 
117
 
 
118
            return defer.succeed((results, authority, additional))
 
119
        else:
 
120
            if name.lower().endswith(self.soa[0].lower()):
 
121
                # We are the authority and we didn't find it.  Goodbye.
 
122
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
 
123
            return defer.fail(failure.Failure(dns.DomainError(name)))
 
124
 
 
125
 
 
126
    def lookupZone(self, name, timeout = 10):
 
127
        if self.soa[0].lower() == name.lower():
 
128
            # Wee hee hee hooo yea
 
129
            default_ttl = max(self.soa[1].minimum, self.soa[1].expire)
 
130
            if self.soa[1].ttl is not None:
 
131
                soa_ttl = self.soa[1].ttl
 
132
            else:
 
133
                soa_ttl = default_ttl
 
134
            results = [dns.RRHeader(self.soa[0], dns.SOA, dns.IN, soa_ttl, self.soa[1], auth=True)]
 
135
            for (k, r) in self.records.items():
 
136
                for rec in r:
 
137
                    if rec.ttl is not None:
 
138
                        ttl = rec.ttl
 
139
                    else:
 
140
                        ttl = default_ttl
 
141
                    if rec.TYPE != dns.SOA:
 
142
                        results.append(dns.RRHeader(k, rec.TYPE, dns.IN, ttl, rec, auth=True))
 
143
            results.append(results[0])
 
144
            return defer.succeed((results, (), ()))
 
145
        return defer.fail(failure.Failure(dns.DomainError(name)))
 
146
 
 
147
    def _cbAllRecords(self, results):
 
148
        ans, auth, add = [], [], []
 
149
        for res in results:
 
150
            if res[0]:
 
151
                ans.extend(res[1][0])
 
152
                auth.extend(res[1][1])
 
153
                add.extend(res[1][2])
 
154
        return ans, auth, add
 
155
 
 
156
 
 
157
class PySourceAuthority(FileAuthority):
 
158
    """A FileAuthority that is built up from Python source code."""
 
159
 
 
160
    def loadFile(self, filename):
 
161
        g, l = self.setupConfigNamespace(), {}
 
162
        execfile(filename, g, l)
 
163
        if not l.has_key('zone'):
 
164
            raise ValueError, "No zone defined in " + filename
 
165
 
 
166
        self.records = {}
 
167
        for rr in l['zone']:
 
168
            if isinstance(rr[1], dns.Record_SOA):
 
169
                self.soa = rr
 
170
            self.records.setdefault(rr[0].lower(), []).append(rr[1])
 
171
 
 
172
 
 
173
    def wrapRecord(self, type):
 
174
        return lambda name, *arg, **kw: (name, type(*arg, **kw))
 
175
 
 
176
 
 
177
    def setupConfigNamespace(self):
 
178
        r = {}
 
179
        items = dns.__dict__.iterkeys()
 
180
        for record in [x for x in items if x.startswith('Record_')]:
 
181
            type = getattr(dns, record)
 
182
            f = self.wrapRecord(type)
 
183
            r[record[len('Record_'):]] = f
 
184
        return r
 
185
 
 
186
 
 
187
class BindAuthority(FileAuthority):
 
188
    """An Authority that loads BIND configuration files"""
 
189
 
 
190
    def loadFile(self, filename):
 
191
        self.origin = os.path.basename(filename) + '.' # XXX - this might suck
 
192
        lines = open(filename).readlines()
 
193
        lines = self.stripComments(lines)
 
194
        lines = self.collapseContinuations(lines)
 
195
        self.parseLines(lines)
 
196
 
 
197
 
 
198
    def stripComments(self, lines):
 
199
        return [
 
200
            a.find(';') == -1 and a or a[:a.find(';')] for a in [
 
201
                b.strip() for b in lines
 
202
            ]
 
203
        ]
 
204
 
 
205
 
 
206
    def collapseContinuations(self, lines):
 
207
        L = []
 
208
        state = 0
 
209
        for line in lines:
 
210
            if state == 0:
 
211
                if line.find('(') == -1:
 
212
                    L.append(line)
 
213
                else:
 
214
                    L.append(line[:line.find('(')])
 
215
                    state = 1
 
216
            else:
 
217
                if line.find(')') != -1:
 
218
                    L[-1] += ' ' + line[:line.find(')')]
 
219
                    state = 0
 
220
                else:
 
221
                    L[-1] += ' ' + line
 
222
        lines = L
 
223
        L = []
 
224
        for line in lines:
 
225
            L.append(line.split())
 
226
        return filter(None, L)
 
227
 
 
228
 
 
229
    def parseLines(self, lines):
 
230
        TTL = 60 * 60 * 3
 
231
        ORIGIN = self.origin
 
232
 
 
233
        self.records = {}
 
234
 
 
235
        for (line, index) in zip(lines, range(len(lines))):
 
236
            if line[0] == '$TTL':
 
237
                TTL = dns.str2time(line[1])
 
238
            elif line[0] == '$ORIGIN':
 
239
                ORIGIN = line[1]
 
240
            elif line[0] == '$INCLUDE': # XXX - oh, fuck me
 
241
                raise NotImplementedError('$INCLUDE directive not implemented')
 
242
            elif line[0] == '$GENERATE':
 
243
                raise NotImplementedError('$GENERATE directive not implemented')
 
244
            else:
 
245
                self.parseRecordLine(ORIGIN, TTL, line)
 
246
 
 
247
 
 
248
    def addRecord(self, owner, ttl, type, domain, cls, rdata):
 
249
        if not domain.endswith('.'):
 
250
            domain = domain + '.' + owner
 
251
        else:
 
252
            domain = domain[:-1]
 
253
        f = getattr(self, 'class_%s' % cls, None)
 
254
        if f:
 
255
            f(ttl, type, domain, rdata)
 
256
        else:
 
257
            raise NotImplementedError, "Record class %r not supported" % cls
 
258
 
 
259
 
 
260
    def class_IN(self, ttl, type, domain, rdata):
 
261
        record = getattr(dns, 'Record_%s' % type, None)
 
262
        if record:
 
263
            r = record(*rdata)
 
264
            r.ttl = ttl
 
265
            self.records.setdefault(domain.lower(), []).append(r)
 
266
 
 
267
            print 'Adding IN Record', domain, ttl, r
 
268
            if type == 'SOA':
 
269
                self.soa = (domain, r)
 
270
        else:
 
271
            raise NotImplementedError, "Record type %r not supported" % type
 
272
 
 
273
 
 
274
    #
 
275
    # This file ends here.  Read no further.
 
276
    #
 
277
    def parseRecordLine(self, origin, ttl, line):
 
278
        MARKERS = dns.QUERY_CLASSES.values() + dns.QUERY_TYPES.values()
 
279
        cls = 'IN'
 
280
        owner = origin
 
281
 
 
282
        if line[0] == '@':
 
283
            line = line[1:]
 
284
            owner = origin
 
285
#            print 'default owner'
 
286
        elif not line[0].isdigit() and line[0] not in MARKERS:
 
287
            owner = line[0]
 
288
            line = line[1:]
 
289
#            print 'owner is ', owner
 
290
 
 
291
        if line[0].isdigit() or line[0] in MARKERS:
 
292
            domain = owner
 
293
            owner = origin
 
294
#            print 'woops, owner is ', owner, ' domain is ', domain
 
295
        else:
 
296
            domain = line[0]
 
297
            line = line[1:]
 
298
#            print 'domain is ', domain
 
299
 
 
300
        if line[0] in dns.QUERY_CLASSES.values():
 
301
            cls = line[0]
 
302
            line = line[1:]
 
303
#            print 'cls is ', cls
 
304
            if line[0].isdigit():
 
305
                ttl = int(line[0])
 
306
                line = line[1:]
 
307
#                print 'ttl is ', ttl
 
308
        elif line[0].isdigit():
 
309
            ttl = int(line[0])
 
310
            line = line[1:]
 
311
#            print 'ttl is ', ttl
 
312
            if line[0] in dns.QUERY_CLASSES.values():
 
313
                cls = line[0]
 
314
                line = line[1:]
 
315
#                print 'cls is ', cls
 
316
 
 
317
        type = line[0]
 
318
#        print 'type is ', type
 
319
        rdata = line[1:]
 
320
#        print 'rdata is ', rdata
 
321
 
 
322
        self.addRecord(owner, ttl, type, domain, cls, rdata)