~ubuntu-branches/ubuntu/intrepid/miro/intrepid

« back to all changes in this revision

Viewing changes to portable/BitTorrent/RawServer.py

  • Committer: Bazaar Package Importer
  • Author(s): Christopher James Halse Rogers
  • Date: 2008-02-09 13:37:10 UTC
  • mfrom: (1.1.2 upstream)
  • Revision ID: james.westby@ubuntu.com-20080209133710-9rs90q6gckvp1b6i
Tags: 1.1.2-0ubuntu1
New upstream release

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Written by Bram Cohen
2
 
# see LICENSE.txt for license information
3
 
 
4
 
from bisect import insort
5
 
import socket
6
 
from cStringIO import StringIO
7
 
from traceback import print_exc
8
 
from errno import EWOULDBLOCK, ENOBUFS
9
 
try:
10
 
    from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
11
 
    timemult = 1000
12
 
except ImportError:
13
 
    from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
14
 
    timemult = 1
15
 
from threading import Thread, Event
16
 
from time import time, sleep
17
 
import sys
18
 
from random import randrange
19
 
 
20
 
all = POLLIN | POLLOUT
21
 
 
22
 
try:
23
 
    socketpair = socket.socketpair
24
 
except AttributeError:
25
 
    def socketpair():
26
 
        dummy_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27
 
        dummy_server.bind( ('127.0.0.1', 0) )
28
 
        dummy_server.listen(1)
29
 
        server_address = dummy_server.getsockname()
30
 
        first = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
31
 
        first.connect(server_address)
32
 
        second, address = dummy_server.accept()
33
 
        dummy_server.close()
34
 
        return first, second
35
 
 
36
 
class SingleSocket:
37
 
    def __init__(self, raw_server, sock, handler):
38
 
        self.raw_server = raw_server
39
 
        self.socket = sock
40
 
        self.handler = handler
41
 
        self.buffer = []
42
 
        self.last_hit = time()
43
 
        self.fileno = sock.fileno()
44
 
        self.connected = False
45
 
        
46
 
    def get_ip(self):
47
 
        try:
48
 
            return self.socket.getpeername()[0]
49
 
        except socket.error:
50
 
            return 'no connection'
51
 
        
52
 
    def close(self):
53
 
        sock = self.socket
54
 
        self.socket = None
55
 
        self.buffer = []
56
 
        del self.raw_server.single_sockets[self.fileno]
57
 
        self.raw_server.poll.unregister(sock)
58
 
        sock.close()
59
 
 
60
 
    def shutdown(self, val):
61
 
        self.socket.shutdown(val)
62
 
 
63
 
    def is_flushed(self):
64
 
        return len(self.buffer) == 0
65
 
 
66
 
    def write(self, s):
67
 
        assert self.socket is not None
68
 
        self.buffer.append(s)
69
 
        if len(self.buffer) == 1:
70
 
            self.try_write()
71
 
 
72
 
    def try_write(self):
73
 
        if self.connected:
74
 
            try:
75
 
                while self.buffer != []:
76
 
                    amount = self.socket.send(self.buffer[0])
77
 
                    if amount != len(self.buffer[0]):
78
 
                        if amount != 0:
79
 
                            self.buffer[0] = self.buffer[0][amount:]
80
 
                        break
81
 
                    del self.buffer[0]
82
 
            except socket.error, e:
83
 
                code, msg = e
84
 
                if code != EWOULDBLOCK:
85
 
                    self.raw_server.dead_from_write.append(self)
86
 
                    return
87
 
        if self.buffer == []:
88
 
            self.raw_server.poll.register(self.socket, POLLIN)
89
 
        else:
90
 
            self.raw_server.poll.register(self.socket, all)
91
 
 
92
 
def default_error_handler(x):
93
 
    print x
94
 
 
95
 
class RawServer:
96
 
    def __init__(self, doneflag, timeout_check_interval, timeout, noisy = True,
97
 
            errorfunc = default_error_handler, maxconnects = 55):
98
 
        self.timeout_check_interval = timeout_check_interval
99
 
        self.timeout = timeout
100
 
        self.poll = poll()
101
 
        # {socket: SingleSocket}
102
 
        self.single_sockets = {}
103
 
        self.dead_from_write = []
104
 
        self.doneflag = doneflag
105
 
        self.noisy = noisy
106
 
        self.errorfunc = errorfunc
107
 
        self.maxconnects = maxconnects
108
 
        self.funcs = []
109
 
        self.unscheduled_tasks = []
110
 
        self.add_task(self.scan_for_timeouts, timeout_check_interval)
111
 
        self.wakeup_receiver, self.wakeup_sender = socketpair()
112
 
        self.poll.register(self.wakeup_receiver, POLLIN)
113
 
 
114
 
    def wakeup(self):
115
 
        self.wakeup_sender.send("a")
116
 
 
117
 
    def add_task(self, func, delay):
118
 
        self.unscheduled_tasks.append((func, delay))
119
 
 
120
 
    def scan_for_timeouts(self):
121
 
        self.add_task(self.scan_for_timeouts, self.timeout_check_interval)
122
 
        t = time() - self.timeout
123
 
        tokill = []
124
 
        for s in self.single_sockets.values():
125
 
            if s.last_hit < t:
126
 
                tokill.append(s)
127
 
        for k in tokill:
128
 
            if k.socket is not None:
129
 
                self._close_socket(k)
130
 
 
131
 
    def bind(self, port, bind = '', reuse = False):
132
 
        self.bindaddr = bind
133
 
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
134
 
        if reuse:
135
 
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
136
 
        server.setblocking(0)
137
 
        try:
138
 
            server.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 32)
139
 
        except:
140
 
            pass
141
 
        server.bind((bind, port))
142
 
        server.listen(5)
143
 
        self.poll.register(server, POLLIN)
144
 
        self.server = server
145
 
 
146
 
    def start_connection(self, dns, handler = None):
147
 
        if handler is None:
148
 
            handler = self.handler
149
 
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
150
 
        sock.setblocking(0)
151
 
        try:
152
 
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 32)
153
 
        except:
154
 
            pass
155
 
        sock.bind((self.bindaddr, 0))
156
 
        try:
157
 
            sock.connect_ex(dns)
158
 
        except socket.error:
159
 
            raise
160
 
        except Exception, e:
161
 
            raise socket.error(str(e))
162
 
        self.poll.register(sock, POLLIN)
163
 
        s = SingleSocket(self, sock, handler)
164
 
        self.single_sockets[sock.fileno()] = s
165
 
        return s
166
 
        
167
 
    def handle_events(self, events):
168
 
        for sock, event in events:
169
 
            if sock == self.wakeup_receiver.fileno():
170
 
                # nothing to do here, simply accepting the event woke us from
171
 
                # do_poll().  reading 1024 bytes should be enough to clear-
172
 
                # wakeup_receiver's buffer.
173
 
                self.wakeup_receiver.recv(1024)
174
 
                continue
175
 
            if sock == self.server.fileno():
176
 
                if event & (POLLHUP | POLLERR) != 0:
177
 
                    self.poll.unregister(self.server)
178
 
                    self.server.close()
179
 
                    self.errorfunc('lost server socket')
180
 
                else:
181
 
                    try:
182
 
                        newsock, addr = self.server.accept()
183
 
                        newsock.setblocking(0)
184
 
                        if len(self.single_sockets) >= self.maxconnects:
185
 
                            newsock.close()
186
 
                            continue
187
 
                        nss = SingleSocket(self, newsock, self.handler)
188
 
                        self.single_sockets[newsock.fileno()] = nss
189
 
                        self.poll.register(newsock, POLLIN)
190
 
                        self.handler.external_connection_made(nss)
191
 
                    except socket.error:
192
 
                        sleep(1)
193
 
            else:
194
 
                s = self.single_sockets.get(sock)
195
 
                if s is None:
196
 
                    continue
197
 
                s.connected = True
198
 
                if (event & (POLLHUP | POLLERR)) != 0:
199
 
                    self._close_socket(s)
200
 
                    continue
201
 
                if (event & POLLIN) != 0:
202
 
                    try:
203
 
                        s.last_hit = time()
204
 
                        data = s.socket.recv(100000)
205
 
                        if data == '':
206
 
                            self._close_socket(s)
207
 
                        else:
208
 
                            s.handler.data_came_in(s, data)
209
 
                    except socket.error, e:
210
 
                        code, msg = e
211
 
                        if code != EWOULDBLOCK:
212
 
                            self._close_socket(s)
213
 
                            continue
214
 
                if (event & POLLOUT) != 0 and s.socket is not None and not s.is_flushed():
215
 
                    s.try_write()
216
 
                    if s.is_flushed():
217
 
                        s.handler.connection_flushed(s)
218
 
 
219
 
    def pop_unscheduled(self):
220
 
        try:
221
 
            while True:
222
 
                (func, delay) = self.unscheduled_tasks.pop()
223
 
                insort(self.funcs, (time() + delay, func))
224
 
        except IndexError:
225
 
            pass
226
 
 
227
 
    def listen_forever(self, handler):
228
 
        self.handler = handler
229
 
        try:
230
 
            while not self.doneflag.isSet():
231
 
                try:
232
 
                    self.pop_unscheduled()
233
 
                    if len(self.funcs) == 0:
234
 
                        period = 2 ** 30
235
 
                    else:
236
 
                        period = self.funcs[0][0] - time()
237
 
                    if period < 0:
238
 
                        period = 0
239
 
                    events = self.poll.poll(period * timemult)
240
 
                    if self.doneflag.isSet():
241
 
                        return
242
 
                    while len(self.funcs) > 0 and self.funcs[0][0] <= time():
243
 
                        garbage, func = self.funcs[0]
244
 
                        del self.funcs[0]
245
 
                        try:
246
 
                            func()
247
 
                        except KeyboardInterrupt:
248
 
                            print_exc()
249
 
                            return
250
 
                        except:
251
 
                            if self.noisy:
252
 
                                data = StringIO()
253
 
                                print_exc(file = data)
254
 
                                self.errorfunc(data.getvalue())
255
 
                    self._close_dead()
256
 
                    self.handle_events(events)
257
 
                    if self.doneflag.isSet():
258
 
                        return
259
 
                    self._close_dead()
260
 
                except error, e:
261
 
                    if self.doneflag.isSet():
262
 
                        return
263
 
                    # I can't find a coherent explanation for what the behavior should be here,
264
 
                    # and people report conflicting behavior, so I'll just try all the possibilities
265
 
                    try:
266
 
                        code, msg, desc = e
267
 
                    except:
268
 
                        try:
269
 
                            code, msg = e
270
 
                        except:
271
 
                            code = ENOBUFS
272
 
                    if code == ENOBUFS:
273
 
                        self.errorfunc("Have to exit due to the TCP stack flaking out")
274
 
                        return
275
 
                except KeyboardInterrupt:
276
 
                    print_exc()
277
 
                    return
278
 
                except:
279
 
                    data = StringIO()
280
 
                    print_exc(file = data)
281
 
                    self.errorfunc(data.getvalue())
282
 
        finally:
283
 
            for ss in self.single_sockets.values():
284
 
                ss.close()
285
 
            self.server.close()
286
 
 
287
 
    def _close_dead(self):
288
 
        while len(self.dead_from_write) > 0:
289
 
            old = self.dead_from_write
290
 
            self.dead_from_write = []
291
 
            for s in old:
292
 
                if s.socket is not None:
293
 
                    self._close_socket(s)
294
 
 
295
 
    def _close_socket(self, s):
296
 
        sock = s.socket.fileno()
297
 
        s.socket.close()
298
 
        self.poll.unregister(sock)
299
 
        del self.single_sockets[sock]
300
 
        s.socket = None
301
 
        s.handler.connection_lost(s)
302
 
 
303
 
# everything below is for testing
304
 
 
305
 
class DummyHandler:
306
 
    def __init__(self):
307
 
        self.external_made = []
308
 
        self.data_in = []
309
 
        self.lost = []
310
 
 
311
 
    def external_connection_made(self, s):
312
 
        self.external_made.append(s)
313
 
    
314
 
    def data_came_in(self, s, data):
315
 
        self.data_in.append((s, data))
316
 
    
317
 
    def connection_lost(self, s):
318
 
        self.lost.append(s)
319
 
 
320
 
    def connection_flushed(self, s):
321
 
        pass
322
 
 
323
 
def sl(rs, handler, port):
324
 
    rs.bind(port)
325
 
    Thread(target = rs.listen_forever, args = [handler]).start()
326
 
 
327
 
def loop(rs):
328
 
    x = []
329
 
    def r(rs = rs, x = x):
330
 
        rs.add_task(x[0], .1)
331
 
    x.append(r)
332
 
    rs.add_task(r, .1)
333
 
 
334
 
beginport = 5000 + randrange(10000)
335
 
 
336
 
def test_starting_side_close():
337
 
    try:
338
 
        fa = Event()
339
 
        fb = Event()
340
 
        da = DummyHandler()
341
 
        sa = RawServer(fa, 100, 100)
342
 
        loop(sa)
343
 
        sl(sa, da, beginport)
344
 
        db = DummyHandler()
345
 
        sb = RawServer(fb, 100, 100)
346
 
        loop(sb)
347
 
        sl(sb, db, beginport + 1)
348
 
 
349
 
        sleep(.5)
350
 
        ca = sa.start_connection(('127.0.0.1', beginport + 1))
351
 
        sleep(1)
352
 
        
353
 
        assert da.external_made == []
354
 
        assert da.data_in == []
355
 
        assert da.lost == []
356
 
        assert len(db.external_made) == 1
357
 
        cb = db.external_made[0]
358
 
        del db.external_made[:]
359
 
        assert db.data_in == []
360
 
        assert db.lost == []
361
 
 
362
 
        ca.write('aaa')
363
 
        cb.write('bbb')
364
 
        sleep(1)
365
 
        
366
 
        assert da.external_made == []
367
 
        assert da.data_in == [(ca, 'bbb')]
368
 
        del da.data_in[:]
369
 
        assert da.lost == []
370
 
        assert db.external_made == []
371
 
        assert db.data_in == [(cb, 'aaa')]
372
 
        del db.data_in[:]
373
 
        assert db.lost == []
374
 
 
375
 
        ca.write('ccc')
376
 
        cb.write('ddd')
377
 
        sleep(1)
378
 
        
379
 
        assert da.external_made == []
380
 
        assert da.data_in == [(ca, 'ddd')]
381
 
        del da.data_in[:]
382
 
        assert da.lost == []
383
 
        assert db.external_made == []
384
 
        assert db.data_in == [(cb, 'ccc')]
385
 
        del db.data_in[:]
386
 
        assert db.lost == []
387
 
 
388
 
        ca.close()
389
 
        sleep(1)
390
 
 
391
 
        assert da.external_made == []
392
 
        assert da.data_in == []
393
 
        assert da.lost == []
394
 
        assert db.external_made == []
395
 
        assert db.data_in == []
396
 
        assert db.lost == [cb]
397
 
        del db.lost[:]
398
 
    finally:
399
 
        fa.set()
400
 
        fb.set()
401
 
 
402
 
def test_receiving_side_close():
403
 
    try:
404
 
        da = DummyHandler()
405
 
        fa = Event()
406
 
        sa = RawServer(fa, 100, 100)
407
 
        loop(sa)
408
 
        sl(sa, da, beginport + 2)
409
 
        db = DummyHandler()
410
 
        fb = Event()
411
 
        sb = RawServer(fb, 100, 100)
412
 
        loop(sb)
413
 
        sl(sb, db, beginport + 3)
414
 
        
415
 
        sleep(.5)
416
 
        ca = sa.start_connection(('127.0.0.1', beginport + 3))
417
 
        sleep(1)
418
 
        
419
 
        assert da.external_made == []
420
 
        assert da.data_in == []
421
 
        assert da.lost == []
422
 
        assert len(db.external_made) == 1
423
 
        cb = db.external_made[0]
424
 
        del db.external_made[:]
425
 
        assert db.data_in == []
426
 
        assert db.lost == []
427
 
 
428
 
        ca.write('aaa')
429
 
        cb.write('bbb')
430
 
        sleep(1)
431
 
        
432
 
        assert da.external_made == []
433
 
        assert da.data_in == [(ca, 'bbb')]
434
 
        del da.data_in[:]
435
 
        assert da.lost == []
436
 
        assert db.external_made == []
437
 
        assert db.data_in == [(cb, 'aaa')]
438
 
        del db.data_in[:]
439
 
        assert db.lost == []
440
 
 
441
 
        ca.write('ccc')
442
 
        cb.write('ddd')
443
 
        sleep(1)
444
 
        
445
 
        assert da.external_made == []
446
 
        assert da.data_in == [(ca, 'ddd')]
447
 
        del da.data_in[:]
448
 
        assert da.lost == []
449
 
        assert db.external_made == []
450
 
        assert db.data_in == [(cb, 'ccc')]
451
 
        del db.data_in[:]
452
 
        assert db.lost == []
453
 
 
454
 
        cb.close()
455
 
        sleep(1)
456
 
 
457
 
        assert da.external_made == []
458
 
        assert da.data_in == []
459
 
        assert da.lost == [ca]
460
 
        del da.lost[:]
461
 
        assert db.external_made == []
462
 
        assert db.data_in == []
463
 
        assert db.lost == []
464
 
    finally:
465
 
        fa.set()
466
 
        fb.set()
467
 
 
468
 
def test_connection_refused():
469
 
    try:
470
 
        da = DummyHandler()
471
 
        fa = Event()
472
 
        sa = RawServer(fa, 100, 100)
473
 
        loop(sa)
474
 
        sl(sa, da, beginport + 6)
475
 
 
476
 
        sleep(.5)
477
 
        ca = sa.start_connection(('127.0.0.1', beginport + 15))
478
 
        sleep(1)
479
 
        
480
 
        assert da.external_made == []
481
 
        assert da.data_in == []
482
 
        assert da.lost == [ca]
483
 
        del da.lost[:]
484
 
    finally:
485
 
        fa.set()
486
 
 
487
 
def test_both_close():
488
 
    try:
489
 
        da = DummyHandler()
490
 
        fa = Event()
491
 
        sa = RawServer(fa, 100, 100)
492
 
        loop(sa)
493
 
        sl(sa, da, beginport + 4)
494
 
 
495
 
        sleep(1)
496
 
        db = DummyHandler()
497
 
        fb = Event()
498
 
        sb = RawServer(fb, 100, 100)
499
 
        loop(sb)
500
 
        sl(sb, db, beginport + 5)
501
 
 
502
 
        sleep(.5)
503
 
        ca = sa.start_connection(('127.0.0.1', beginport + 5))
504
 
        sleep(1)
505
 
        
506
 
        assert da.external_made == []
507
 
        assert da.data_in == []
508
 
        assert da.lost == []
509
 
        assert len(db.external_made) == 1
510
 
        cb = db.external_made[0]
511
 
        del db.external_made[:]
512
 
        assert db.data_in == []
513
 
        assert db.lost == []
514
 
 
515
 
        ca.write('aaa')
516
 
        cb.write('bbb')
517
 
        sleep(1)
518
 
        
519
 
        assert da.external_made == []
520
 
        assert da.data_in == [(ca, 'bbb')]
521
 
        del da.data_in[:]
522
 
        assert da.lost == []
523
 
        assert db.external_made == []
524
 
        assert db.data_in == [(cb, 'aaa')]
525
 
        del db.data_in[:]
526
 
        assert db.lost == []
527
 
 
528
 
        ca.write('ccc')
529
 
        cb.write('ddd')
530
 
        sleep(1)
531
 
        
532
 
        assert da.external_made == []
533
 
        assert da.data_in == [(ca, 'ddd')]
534
 
        del da.data_in[:]
535
 
        assert da.lost == []
536
 
        assert db.external_made == []
537
 
        assert db.data_in == [(cb, 'ccc')]
538
 
        del db.data_in[:]
539
 
        assert db.lost == []
540
 
 
541
 
        ca.close()
542
 
        cb.close()
543
 
        sleep(1)
544
 
 
545
 
        assert da.external_made == []
546
 
        assert da.data_in == []
547
 
        assert da.lost == []
548
 
        assert db.external_made == []
549
 
        assert db.data_in == []
550
 
        assert db.lost == []
551
 
    finally:
552
 
        fa.set()
553
 
        fb.set()
554
 
 
555
 
def test_normal():
556
 
    l = []
557
 
    f = Event()
558
 
    s = RawServer(f, 100, 100)
559
 
    loop(s)
560
 
    sl(s, DummyHandler(), beginport + 7)
561
 
    s.add_task(lambda l = l: l.append('b'), 2)
562
 
    s.add_task(lambda l = l: l.append('a'), 1)
563
 
    s.add_task(lambda l = l: l.append('d'), 4)
564
 
    sleep(1.5)
565
 
    s.add_task(lambda l = l: l.append('c'), 1.5)
566
 
    sleep(3)
567
 
    assert l == ['a', 'b', 'c', 'd']
568
 
    f.set()
569
 
 
570
 
def test_catch_exception():
571
 
    l = []
572
 
    f = Event()
573
 
    s = RawServer(f, 100, 100, False)
574
 
    loop(s)
575
 
    sl(s, DummyHandler(), beginport + 9)
576
 
    s.add_task(lambda l = l: l.append('b'), 2)
577
 
    s.add_task(lambda: 4/0, 1)
578
 
    sleep(3)
579
 
    assert l == ['b']
580
 
    f.set()
581
 
 
582
 
def test_closes_if_not_hit():
583
 
    try:
584
 
        da = DummyHandler()
585
 
        fa = Event()
586
 
        sa = RawServer(fa, 2, 2)
587
 
        loop(sa)
588
 
        sl(sa, da, beginport + 14)
589
 
 
590
 
        sleep(1)
591
 
        db = DummyHandler()
592
 
        fb = Event()
593
 
        sb = RawServer(fb, 100, 100)
594
 
        loop(sb)
595
 
        sl(sb, db, beginport + 13)
596
 
        
597
 
        sleep(.5)
598
 
        sa.start_connection(('127.0.0.1', beginport + 13))
599
 
        sleep(1)
600
 
        
601
 
        assert da.external_made == []
602
 
        assert da.data_in == []
603
 
        assert da.lost == []
604
 
        assert len(db.external_made) == 1
605
 
        del db.external_made[:]
606
 
        assert db.data_in == []
607
 
        assert db.lost == []
608
 
 
609
 
        sleep(3.1)
610
 
        
611
 
        assert len(da.lost) == 1
612
 
        assert len(db.lost) == 1
613
 
    finally:
614
 
        fa.set()
615
 
        fb.set()
616
 
 
617
 
def test_does_not_close_if_hit():
618
 
    try:
619
 
        fa = Event()
620
 
        fb = Event()
621
 
        da = DummyHandler()
622
 
        sa = RawServer(fa, 2, 2)
623
 
        loop(sa)
624
 
        sl(sa, da, beginport + 12)
625
 
 
626
 
        sleep(1)
627
 
        db = DummyHandler()
628
 
        sb = RawServer(fb, 100, 100)
629
 
        loop(sb)
630
 
        sl(sb, db, beginport + 13)
631
 
        
632
 
        sleep(.5)
633
 
        sa.start_connection(('127.0.0.1', beginport + 13))
634
 
        sleep(1)
635
 
        
636
 
        assert da.external_made == []
637
 
        assert da.data_in == []
638
 
        assert da.lost == []
639
 
        assert len(db.external_made) == 1
640
 
        cb = db.external_made[0]
641
 
        del db.external_made[:]
642
 
        assert db.data_in == []
643
 
        assert db.lost == []
644
 
 
645
 
        cb.write('bbb')
646
 
        sleep(.5)
647
 
        
648
 
        assert da.lost == []
649
 
        assert db.lost == []
650
 
    finally:
651
 
        fa.set()
652
 
        fb.set()