1
# Written by Bram Cohen
2
# see LICENSE.txt for license information
4
from cStringIO import StringIO
5
from binascii import b2a_hex
6
from socket import error as socketerror
7
from urllib import quote
8
from traceback import print_exc
9
from BitTornado.BTcrypto import Crypto
16
bool = lambda x: not not x
22
protocol_name = 'BitTorrent protocol'
23
option_pattern = chr(0)*8
26
return long(b2a_hex(s), 16)
29
return chr((i >> 8) & 0xFF) + chr(i & 0xFF)
34
if quote(s).find('%') >= 0:
35
return b2a_hex(s).upper()
39
class IncompleteCounter:
47
return self.c >= MAX_INCOMPLETE
49
incompletecounter = IncompleteCounter()
52
# header, options, download id, my id, [length, message]
55
def __init__(self, Encoder, connection, id,
56
ext_handshake=False, encrypted = None, options = None):
57
self.Encoder = Encoder
58
self.connection = connection
59
self.connecter = Encoder.connecter
61
self.locally_initiated = (id != None)
62
self.readable_id = make_readable(id)
64
self.keepalive = lambda: None
69
self.read = self._read
70
self.write = self._write
73
if self.locally_initiated:
74
incompletecounter.increment()
77
self.encrypter = Crypto(True)
78
self.write(self.encrypter.pubkey+self.encrypter.padding())
80
self.encrypted = False
81
self.write(chr(len(protocol_name)) + protocol_name +
82
option_pattern + self.Encoder.download_id )
83
self.next_len, self.next_func = 1+len(protocol_name), self.read_header
85
self.Encoder.connecter.external_connection_made += 1
86
if encrypted: # passed an already running encrypter
87
self.encrypter = encrypted
90
self.next_len, self.next_func = 14, self.read_crypto_block3c
92
self.encrypted = False
93
self.options = options
94
self.write(self.Encoder.my_id)
95
self.next_len, self.next_func = 20, self.read_peer_id
97
self.encrypted = None # don't know yet
98
self.next_len, self.next_func = 1+len(protocol_name), self.read_header
99
self.Encoder.raw_server.add_task(self._auto_close, 30)
102
def _log_start(self): # only called with DEBUG = True
103
self.log = open('peerlog.'+self.get_ip()+'.txt','a')
104
self.log.write('connected - ')
105
if self.locally_initiated:
106
self.log.write('outgoing\n')
108
self.log.write('incoming\n')
109
self._logwritefunc = self.write
110
self.write = self._log_write
112
def _log_write(self, s):
113
self.log.write('w:'+b2a_hex(s)+'\n')
114
self._logwritefunc(s)
117
def get_ip(self, real=False):
118
return self.connection.get_ip(real)
123
def get_readable_id(self):
124
return self.readable_id
126
def is_locally_initiated(self):
127
return self.locally_initiated
129
def is_encrypted(self):
130
return bool(self.encrypted)
132
def is_flushed(self):
133
return self.connection.is_flushed()
135
def _read_header(self, s):
136
if s == chr(len(protocol_name))+protocol_name:
137
return 8, self.read_options
140
def read_header(self, s):
141
if self._read_header(s):
142
if self.encrypted or self.Encoder.config['crypto_stealth']:
144
return 8, self.read_options
145
if self.locally_initiated and not self.encrypted:
147
elif not self.Encoder.config['crypto_allowed']:
149
if not self.encrypted:
150
self.encrypted = True
151
self.encrypter = Crypto(self.locally_initiated)
152
self._write_buffer(s)
153
return self.encrypter.keylength, self.read_crypto_header
155
################## ENCRYPTION SUPPORT ######################
157
def _start_crypto(self):
158
self.encrypter.setrawaccess(self._read,self._write)
159
self.write = self.encrypter.write
160
self.read = self.encrypter.read
162
self.buffer = self.encrypter.decrypt(self.buffer)
164
def _end_crypto(self):
165
self.read = self._read
166
self.write = self._write
167
self.encrypter = None
169
def read_crypto_header(self, s):
170
self.encrypter.received_key(s)
171
self.encrypter.set_skey(self.Encoder.download_id)
172
if self.locally_initiated:
173
if self.Encoder.config['crypto_only']:
174
cryptmode = '\x00\x00\x00\x02' # full stream encryption
176
cryptmode = '\x00\x00\x00\x03' # header or full stream
177
padc = self.encrypter.padding()
178
self.write( self.encrypter.block3a
179
+ self.encrypter.block3b
180
+ self.encrypter.encrypt(
182
+ cryptmode # acceptable crypto modes
183
+ tobinary16(len(padc))
185
+ '\x00\x00' ) ) # no initial payload data
186
self._max_search = 520
187
return 1, self.read_crypto_block4a
188
self.write(self.encrypter.pubkey+self.encrypter.padding())
189
self._max_search = 520
190
return 0, self.read_crypto_block3a
192
def _search_for_pattern(self, s, pat):
195
if len(s) >= len(pat):
196
self._max_search -= len(s)+1-len(pat)
197
if self._max_search < 0:
200
self._write_buffer(s[1-len(pat):])
202
self._write_buffer(s[p+len(pat):])
205
### INCOMING CONNECTION ###
207
def read_crypto_block3a(self, s):
208
if not self._search_for_pattern(s,self.encrypter.block3a):
209
return -1, self.read_crypto_block3a # wait for more data
210
return len(self.encrypter.block3b), self.read_crypto_block3b
212
def read_crypto_block3b(self, s):
213
if s != self.encrypter.block3b:
215
self.Encoder.connecter.external_connection_made += 1
217
return 14, self.read_crypto_block3c
219
def read_crypto_block3c(self, s):
220
if s[:8] != ('\x00'*8): # check VC
222
self.cryptmode = toint(s[8:12]) % 4
223
if self.cryptmode == 0:
224
return None # no encryption selected
225
if ( self.cryptmode == 1 # only header encryption
226
and self.Encoder.config['crypto_only'] ):
228
padlen = (ord(s[12])<<8)+ord(s[13])
231
return padlen+2, self.read_crypto_pad3
233
def read_crypto_pad3(self, s):
235
ialen = (ord(s[0])<<8)+ord(s[1])
238
if self.cryptmode == 1:
239
cryptmode = '\x00\x00\x00\x01' # header only encryption
241
cryptmode = '\x00\x00\x00\x02' # full stream encryption
242
padd = self.encrypter.padding()
243
self.write( ('\x00'*8) # VC
244
+ cryptmode # encryption mode
245
+ tobinary16(len(padd))
248
return ialen, self.read_crypto_ia
249
return self.read_crypto_block3done()
251
def read_crypto_ia(self, s):
254
self.log.write('r:'+b2a_hex(s)+'(ia)\n')
256
self.log.write('r:'+b2a_hex(self.buffer)+'(buffer)\n')
257
return self.read_crypto_block3done(s)
259
def read_crypto_block3done(self, ia=''):
263
if self.cryptmode == 1: # only handshake encryption
264
assert not self.buffer # oops; check for exceptions to this
267
self._write_buffer(ia)
268
return 1+len(protocol_name), self.read_encrypted_header
270
### OUTGOING CONNECTION ###
272
def read_crypto_block4a(self, s):
273
if not self._search_for_pattern(s,self.encrypter.VC_pattern()):
274
return -1, self.read_crypto_block4a # wait for more data
276
return 6, self.read_crypto_block4b
278
def read_crypto_block4b(self, s):
279
self.cryptmode = toint(s[:4]) % 4
280
if self.cryptmode == 1: # only header encryption
281
if self.Encoder.config['crypto_only']:
283
elif self.cryptmode != 2:
284
return None # unknown encryption
285
padlen = (ord(s[4])<<8)+ord(s[5])
289
return padlen, self.read_crypto_pad4
290
return self.read_crypto_block4done()
292
def read_crypto_pad4(self, s):
294
return self.read_crypto_block4done()
296
def read_crypto_block4done(self):
299
if self.cryptmode == 1: # only handshake encryption
300
if not self.buffer: # oops; check for exceptions to this
303
self.write(chr(len(protocol_name)) + protocol_name +
304
option_pattern + self.Encoder.download_id)
305
return 1+len(protocol_name), self.read_encrypted_header
307
### START PROTOCOL OVER ENCRYPTED CONNECTION ###
309
def read_encrypted_header(self, s):
310
return self._read_header(s)
312
################################################
314
def read_options(self, s):
316
return 20, self.read_download_id
318
def read_download_id(self, s):
319
if ( s != self.Encoder.download_id
320
or not self.Encoder.check_ip(ip=self.get_ip()) ):
322
if not self.locally_initiated:
323
if not self.encrypted:
324
self.Encoder.connecter.external_connection_made += 1
325
self.write(chr(len(protocol_name)) + protocol_name +
326
option_pattern + self.Encoder.download_id + self.Encoder.my_id)
327
return 20, self.read_peer_id
329
def read_peer_id(self, s):
330
if not self.encrypted and self.Encoder.config['crypto_only']:
331
return None # allows older trackers to ping,
332
# but won't proceed w/ connections
335
self.readable_id = make_readable(s)
339
self.complete = self.Encoder.got_id(self)
340
if not self.complete:
342
if self.locally_initiated:
343
self.write(self.Encoder.my_id)
344
incompletecounter.decrement()
345
self._switch_to_read2()
346
c = self.Encoder.connecter.connection_made(self)
347
self.keepalive = c.send_keepalive
348
return 4, self.read_len
350
def read_len(self, s):
352
if l > self.Encoder.max_len:
354
return l, self.read_message
356
def read_message(self, s):
358
self.connecter.got_message(self, s)
359
return 4, self.read_len
361
def read_dead(self, s):
364
def _auto_close(self):
365
if not self.complete:
370
self.connection.close()
375
self.log.write('closed\n')
378
del self.Encoder.connections[self.connection]
380
self.connecter.connection_lost(self)
381
elif self.locally_initiated:
382
incompletecounter.decrement()
384
def send_message_raw(self, message):
387
def _write(self, message):
389
self.connection.write(message)
391
def data_came_in(self, connection, s):
394
def _write_buffer(self, s):
395
self.buffer = s+self.buffer
399
self.log.write('r:'+b2a_hex(s)+'\n')
400
self.Encoder.measurefunc(len(s))
405
# self.next_len = # of characters function expects
406
# or 0 = all characters in the buffer
407
# or -1 = wait for next read, then all characters in the buffer
408
# not compatible w/ keepalives, switch out after all negotiation complete
409
if self.next_len <= 0:
412
elif len(self.buffer) >= self.next_len:
413
m = self.buffer[:self.next_len]
414
self.buffer = self.buffer[self.next_len:]
418
x = self.next_func(m)
420
self.next_len, self.next_func = 1, self.read_dead
425
self.next_len, self.next_func = x
426
if self.next_len < 0: # already checked buffer
427
return # wait for additional data
428
if self.bufferlen is not None:
432
def _switch_to_read2(self):
433
self._write_buffer = None
435
self.encrypter.setrawaccess(self._read2,self._write)
437
self.read = self._read2
438
self.bufferlen = len(self.buffer)
439
self.buffer = [self.buffer]
441
def _read2(self, s): # more efficient, requires buffer['',''] & bufferlen
443
self.log.write('r:'+b2a_hex(s)+'\n')
444
self.Encoder.measurefunc(len(s))
448
p = self.next_len-self.bufferlen
449
if self.next_len == 0:
453
self.buffer.append(s)
454
self.bufferlen += len(s)
456
self.bufferlen = len(s)-p
457
self.buffer.append(s[:p])
458
m = ''.join(self.buffer)
465
# assert len(self.buffer) == 1
467
self.bufferlen = len(s)-self.next_len
468
m = s[:self.next_len]
472
self.buffer = [s[self.next_len:]]
477
x = self.next_func(m)
479
self.next_len, self.next_func = 1, self.read_dead
484
self.next_len, self.next_func = x
485
if self.next_len < 0: # already checked buffer
486
return # wait for additional data
489
def connection_flushed(self, connection):
491
self.connecter.connection_flushed(self)
493
def connection_lost(self, connection):
494
if self.Encoder.connections.has_key(connection):
498
class _dummy_banlist:
499
def includes(self, x):
503
def __init__(self, connecter, raw_server, my_id, max_len,
504
schedulefunc, keepalive_delay, download_id,
505
measurefunc, config, bans=_dummy_banlist() ):
506
self.raw_server = raw_server
507
self.connecter = connecter
509
self.max_len = max_len
510
self.schedulefunc = schedulefunc
511
self.keepalive_delay = keepalive_delay
512
self.download_id = download_id
513
self.measurefunc = measurefunc
515
self.connections = {}
517
self.external_bans = bans
520
if self.config['max_connections'] == 0:
521
self.max_connections = 2 ** 30
523
self.max_connections = self.config['max_connections']
524
schedulefunc(self.send_keepalives, keepalive_delay)
526
def send_keepalives(self):
527
self.schedulefunc(self.send_keepalives, self.keepalive_delay)
530
for c in self.connections.values():
533
def start_connections(self, list):
534
if not self.to_connect:
535
self.raw_server.add_task(self._start_connection_from_queue)
536
self.to_connect = list
538
def _start_connection_from_queue(self):
539
if self.connecter.external_connection_made:
540
max_initiate = self.config['max_initiate']
542
max_initiate = int(self.config['max_initiate']*1.5)
543
cons = len(self.connections)
544
if cons >= self.max_connections or cons >= max_initiate:
546
elif self.paused or incompletecounter.toomany():
550
dns, id, encrypted = self.to_connect.pop(0)
551
self.start_connection(dns, id, encrypted)
553
self.raw_server.add_task(self._start_connection_from_queue, delay)
555
def start_connection(self, dns, id, encrypted = None):
557
or len(self.connections) >= self.max_connections
559
or not self.check_ip(ip=dns[0]) ):
561
if self.config['crypto_only']:
562
if encrypted is None or encrypted: # fails on encrypted = 0
566
for v in self.connections.values():
569
if id and v.id == id:
572
if self.config['security'] and ip != 'unknown' and ip == dns[0]:
575
c = self.raw_server.start_connection(dns)
576
con = Connection(self, c, id, encrypted = encrypted)
577
self.connections[c] = con
583
def _start_connection(self, dns, id, encrypted = None):
584
def foo(self=self, dns=dns, id=id, encrypted=encrypted):
585
self.start_connection(dns, id, encrypted)
586
self.schedulefunc(foo, 0)
588
def check_ip(self, connection=None, ip=None):
590
ip = connection.get_ip(True)
591
if self.config['security'] and self.banned.has_key(ip):
593
if self.external_bans.includes(ip):
597
def got_id(self, connection):
598
if connection.id == self.my_id:
599
self.connecter.external_connection_made -= 1
601
ip = connection.get_ip(True)
602
for v in self.connections.values():
603
if connection is not v:
604
if connection.id == v.id:
605
if ip == v.get_ip(True):
609
if self.config['security'] and ip != 'unknown' and ip == v.get_ip(True):
613
def external_connection_made(self, connection):
614
if self.paused or len(self.connections) >= self.max_connections:
617
con = Connection(self, connection, None)
618
self.connections[connection] = con
619
connection.set_handler(con)
622
def externally_handshaked_connection_made(self, connection, options,
623
already_read, encrypted = None):
625
or len(self.connections) >= self.max_connections
626
or not self.check_ip(connection=connection) ):
629
con = Connection(self, connection, None,
630
ext_handshake = True, encrypted = encrypted, options = options)
631
self.connections[connection] = con
632
connection.set_handler(con)
634
con.data_came_in(con, already_read)
638
for c in self.connections.values():
640
self.connections = {}
645
def pause(self, flag):