1
# Written by Bram Cohen
2
# see LICENSE.txt for license information
4
from bisect import insort
6
from cStringIO import StringIO
7
from traceback import print_exc
8
from errno import EWOULDBLOCK, ENOBUFS
10
from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
13
from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
15
from threading import Thread, Event
16
from time import time, sleep
18
from random import randrange
20
all = POLLIN | POLLOUT
23
socketpair = socket.socketpair
24
except AttributeError:
26
dummy_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27
dummy_server.bind( ('127.0.0.1', 0) )
28
dummy_server.listen(1)
29
server_address = dummy_server.getsockname()
30
first = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
31
first.connect(server_address)
32
second, address = dummy_server.accept()
37
def __init__(self, raw_server, sock, handler):
38
self.raw_server = raw_server
40
self.handler = handler
42
self.last_hit = time()
43
self.fileno = sock.fileno()
44
self.connected = False
48
return self.socket.getpeername()[0]
50
return 'no connection'
56
del self.raw_server.single_sockets[self.fileno]
57
self.raw_server.poll.unregister(sock)
60
def shutdown(self, val):
61
self.socket.shutdown(val)
64
return len(self.buffer) == 0
67
assert self.socket is not None
69
if len(self.buffer) == 1:
75
while self.buffer != []:
76
amount = self.socket.send(self.buffer[0])
77
if amount != len(self.buffer[0]):
79
self.buffer[0] = self.buffer[0][amount:]
82
except socket.error, e:
84
if code != EWOULDBLOCK:
85
self.raw_server.dead_from_write.append(self)
88
self.raw_server.poll.register(self.socket, POLLIN)
90
self.raw_server.poll.register(self.socket, all)
92
def default_error_handler(x):
96
def __init__(self, doneflag, timeout_check_interval, timeout, noisy = True,
97
errorfunc = default_error_handler, maxconnects = 55):
98
self.timeout_check_interval = timeout_check_interval
99
self.timeout = timeout
101
# {socket: SingleSocket}
102
self.single_sockets = {}
103
self.dead_from_write = []
104
self.doneflag = doneflag
106
self.errorfunc = errorfunc
107
self.maxconnects = maxconnects
109
self.unscheduled_tasks = []
110
self.add_task(self.scan_for_timeouts, timeout_check_interval)
111
self.wakeup_receiver, self.wakeup_sender = socketpair()
112
self.poll.register(self.wakeup_receiver, POLLIN)
115
self.wakeup_sender.send("a")
117
def add_task(self, func, delay):
118
self.unscheduled_tasks.append((func, delay))
120
def scan_for_timeouts(self):
121
self.add_task(self.scan_for_timeouts, self.timeout_check_interval)
122
t = time() - self.timeout
124
for s in self.single_sockets.values():
128
if k.socket is not None:
129
self._close_socket(k)
131
def bind(self, port, bind = '', reuse = False):
133
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
135
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
136
server.setblocking(0)
138
server.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 32)
141
server.bind((bind, port))
143
self.poll.register(server, POLLIN)
146
def start_connection(self, dns, handler = None):
148
handler = self.handler
149
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
152
sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 32)
155
sock.bind((self.bindaddr, 0))
161
raise socket.error(str(e))
162
self.poll.register(sock, POLLIN)
163
s = SingleSocket(self, sock, handler)
164
self.single_sockets[sock.fileno()] = s
167
def handle_events(self, events):
168
for sock, event in events:
169
if sock == self.wakeup_receiver.fileno():
170
# nothing to do here, simply accepting the event woke us from
171
# do_poll(). reading 1024 bytes should be enough to clear-
172
# wakeup_receiver's buffer.
173
self.wakeup_receiver.recv(1024)
175
if sock == self.server.fileno():
176
if event & (POLLHUP | POLLERR) != 0:
177
self.poll.unregister(self.server)
179
self.errorfunc('lost server socket')
182
newsock, addr = self.server.accept()
183
newsock.setblocking(0)
184
if len(self.single_sockets) >= self.maxconnects:
187
nss = SingleSocket(self, newsock, self.handler)
188
self.single_sockets[newsock.fileno()] = nss
189
self.poll.register(newsock, POLLIN)
190
self.handler.external_connection_made(nss)
194
s = self.single_sockets.get(sock)
198
if (event & (POLLHUP | POLLERR)) != 0:
199
self._close_socket(s)
201
if (event & POLLIN) != 0:
204
data = s.socket.recv(100000)
206
self._close_socket(s)
208
s.handler.data_came_in(s, data)
209
except socket.error, e:
211
if code != EWOULDBLOCK:
212
self._close_socket(s)
214
if (event & POLLOUT) != 0 and s.socket is not None and not s.is_flushed():
217
s.handler.connection_flushed(s)
219
def pop_unscheduled(self):
222
(func, delay) = self.unscheduled_tasks.pop()
223
insort(self.funcs, (time() + delay, func))
227
def listen_forever(self, handler):
228
self.handler = handler
230
while not self.doneflag.isSet():
232
self.pop_unscheduled()
233
if len(self.funcs) == 0:
236
period = self.funcs[0][0] - time()
239
events = self.poll.poll(period * timemult)
240
if self.doneflag.isSet():
242
while len(self.funcs) > 0 and self.funcs[0][0] <= time():
243
garbage, func = self.funcs[0]
247
except KeyboardInterrupt:
253
print_exc(file = data)
254
self.errorfunc(data.getvalue())
256
self.handle_events(events)
257
if self.doneflag.isSet():
261
if self.doneflag.isSet():
263
# I can't find a coherent explanation for what the behavior should be here,
264
# and people report conflicting behavior, so I'll just try all the possibilities
273
self.errorfunc("Have to exit due to the TCP stack flaking out")
275
except KeyboardInterrupt:
280
print_exc(file = data)
281
self.errorfunc(data.getvalue())
283
for ss in self.single_sockets.values():
287
def _close_dead(self):
288
while len(self.dead_from_write) > 0:
289
old = self.dead_from_write
290
self.dead_from_write = []
292
if s.socket is not None:
293
self._close_socket(s)
295
def _close_socket(self, s):
296
sock = s.socket.fileno()
298
self.poll.unregister(sock)
299
del self.single_sockets[sock]
301
s.handler.connection_lost(s)
303
# everything below is for testing
307
self.external_made = []
311
def external_connection_made(self, s):
312
self.external_made.append(s)
314
def data_came_in(self, s, data):
315
self.data_in.append((s, data))
317
def connection_lost(self, s):
320
def connection_flushed(self, s):
323
def sl(rs, handler, port):
325
Thread(target = rs.listen_forever, args = [handler]).start()
329
def r(rs = rs, x = x):
330
rs.add_task(x[0], .1)
334
beginport = 5000 + randrange(10000)
336
def test_starting_side_close():
341
sa = RawServer(fa, 100, 100)
343
sl(sa, da, beginport)
345
sb = RawServer(fb, 100, 100)
347
sl(sb, db, beginport + 1)
350
ca = sa.start_connection(('127.0.0.1', beginport + 1))
353
assert da.external_made == []
354
assert da.data_in == []
356
assert len(db.external_made) == 1
357
cb = db.external_made[0]
358
del db.external_made[:]
359
assert db.data_in == []
366
assert da.external_made == []
367
assert da.data_in == [(ca, 'bbb')]
370
assert db.external_made == []
371
assert db.data_in == [(cb, 'aaa')]
379
assert da.external_made == []
380
assert da.data_in == [(ca, 'ddd')]
383
assert db.external_made == []
384
assert db.data_in == [(cb, 'ccc')]
391
assert da.external_made == []
392
assert da.data_in == []
394
assert db.external_made == []
395
assert db.data_in == []
396
assert db.lost == [cb]
402
def test_receiving_side_close():
406
sa = RawServer(fa, 100, 100)
408
sl(sa, da, beginport + 2)
411
sb = RawServer(fb, 100, 100)
413
sl(sb, db, beginport + 3)
416
ca = sa.start_connection(('127.0.0.1', beginport + 3))
419
assert da.external_made == []
420
assert da.data_in == []
422
assert len(db.external_made) == 1
423
cb = db.external_made[0]
424
del db.external_made[:]
425
assert db.data_in == []
432
assert da.external_made == []
433
assert da.data_in == [(ca, 'bbb')]
436
assert db.external_made == []
437
assert db.data_in == [(cb, 'aaa')]
445
assert da.external_made == []
446
assert da.data_in == [(ca, 'ddd')]
449
assert db.external_made == []
450
assert db.data_in == [(cb, 'ccc')]
457
assert da.external_made == []
458
assert da.data_in == []
459
assert da.lost == [ca]
461
assert db.external_made == []
462
assert db.data_in == []
468
def test_connection_refused():
472
sa = RawServer(fa, 100, 100)
474
sl(sa, da, beginport + 6)
477
ca = sa.start_connection(('127.0.0.1', beginport + 15))
480
assert da.external_made == []
481
assert da.data_in == []
482
assert da.lost == [ca]
487
def test_both_close():
491
sa = RawServer(fa, 100, 100)
493
sl(sa, da, beginport + 4)
498
sb = RawServer(fb, 100, 100)
500
sl(sb, db, beginport + 5)
503
ca = sa.start_connection(('127.0.0.1', beginport + 5))
506
assert da.external_made == []
507
assert da.data_in == []
509
assert len(db.external_made) == 1
510
cb = db.external_made[0]
511
del db.external_made[:]
512
assert db.data_in == []
519
assert da.external_made == []
520
assert da.data_in == [(ca, 'bbb')]
523
assert db.external_made == []
524
assert db.data_in == [(cb, 'aaa')]
532
assert da.external_made == []
533
assert da.data_in == [(ca, 'ddd')]
536
assert db.external_made == []
537
assert db.data_in == [(cb, 'ccc')]
545
assert da.external_made == []
546
assert da.data_in == []
548
assert db.external_made == []
549
assert db.data_in == []
558
s = RawServer(f, 100, 100)
560
sl(s, DummyHandler(), beginport + 7)
561
s.add_task(lambda l = l: l.append('b'), 2)
562
s.add_task(lambda l = l: l.append('a'), 1)
563
s.add_task(lambda l = l: l.append('d'), 4)
565
s.add_task(lambda l = l: l.append('c'), 1.5)
567
assert l == ['a', 'b', 'c', 'd']
570
def test_catch_exception():
573
s = RawServer(f, 100, 100, False)
575
sl(s, DummyHandler(), beginport + 9)
576
s.add_task(lambda l = l: l.append('b'), 2)
577
s.add_task(lambda: 4/0, 1)
582
def test_closes_if_not_hit():
586
sa = RawServer(fa, 2, 2)
588
sl(sa, da, beginport + 14)
593
sb = RawServer(fb, 100, 100)
595
sl(sb, db, beginport + 13)
598
sa.start_connection(('127.0.0.1', beginport + 13))
601
assert da.external_made == []
602
assert da.data_in == []
604
assert len(db.external_made) == 1
605
del db.external_made[:]
606
assert db.data_in == []
611
assert len(da.lost) == 1
612
assert len(db.lost) == 1
617
def test_does_not_close_if_hit():
622
sa = RawServer(fa, 2, 2)
624
sl(sa, da, beginport + 12)
628
sb = RawServer(fb, 100, 100)
630
sl(sb, db, beginport + 13)
633
sa.start_connection(('127.0.0.1', beginport + 13))
636
assert da.external_made == []
637
assert da.data_in == []
639
assert len(db.external_made) == 1
640
cb = db.external_made[0]
641
del db.external_made[:]
642
assert db.data_in == []