~thumper/bzr/alias-command

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Tim Penhey
  • Date: 2008-05-30 10:57:24 UTC
  • mfrom: (2900.1.561 +trunk)
  • Revision ID: tim@penhey.net-20080530105724-8494sid7i6ajilg4
Merge bzr.dev and resolve conflicts.

Show diffs side-by-side

added added

removed removed

Lines of Context:
21
21
import os
22
22
import socket
23
23
import threading
24
 
import urllib2
25
24
 
 
25
import bzrlib
26
26
from bzrlib import (
27
27
        bzrdir,
28
28
        errors,
33
33
from bzrlib.smart import (
34
34
        client,
35
35
        medium,
 
36
        message,
36
37
        protocol,
37
38
        request as _mod_request,
38
39
        server,
39
40
        vfs,
40
41
)
41
 
from bzrlib.tests.http_utils import (
42
 
        HTTPServerWithSmarts,
43
 
        SmartRequestHandler,
44
 
        )
45
42
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
46
43
from bzrlib.transport import (
47
44
        get_transport,
118
115
        sock.bind(('127.0.0.1', 0))
119
116
        sock.listen(1)
120
117
        port = sock.getsockname()[1]
121
 
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port)
 
118
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base')
122
119
        return sock, client_medium
123
120
 
124
121
    def receive_bytes_on_server(self, sock, bytes):
136
133
        t.start()
137
134
        return t
138
135
    
139
 
    def test_construct_smart_stream_medium_client(self):
140
 
        # make a new instance of the common base for Stream-like Mediums.
141
 
        # this just ensures that the constructor stays parameter-free which
142
 
        # is important for reuse : some subclasses will dynamically connect,
143
 
        # others are always on, etc.
144
 
        client_medium = medium.SmartClientStreamMedium()
145
 
 
146
 
    def test_construct_smart_client_medium(self):
147
 
        # the base client medium takes no parameters
148
 
        client_medium = medium.SmartClientMedium()
149
 
    
150
136
    def test_construct_smart_simple_pipes_client_medium(self):
151
137
        # the SimplePipes client medium takes two pipes:
152
138
        # readable pipe, writeable pipe.
153
139
        # Constructing one should just save these and do nothing.
154
140
        # We test this by passing in None.
155
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
141
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
156
142
        
157
143
    def test_simple_pipes_client_request_type(self):
158
144
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
159
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
145
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
160
146
        request = client_medium.get_request()
161
147
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
162
148
 
168
154
        # classes - as the sibling classes share this logic, they do not have
169
155
        # explicit tests for this.
170
156
        output = StringIO()
171
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
157
        client_medium = medium.SmartSimplePipesClientMedium(
 
158
            None, output, 'base')
172
159
        request = client_medium.get_request()
173
160
        request.finished_writing()
174
161
        request.finished_reading()
179
166
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
180
167
        # accept_bytes writes to the writeable pipe.
181
168
        output = StringIO()
182
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
169
        client_medium = medium.SmartSimplePipesClientMedium(
 
170
            None, output, 'base')
183
171
        client_medium._accept_bytes('abc')
184
172
        self.assertEqual('abc', output.getvalue())
185
173
    
187
175
        # calling disconnect does nothing.
188
176
        input = StringIO()
189
177
        output = StringIO()
190
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
178
        client_medium = medium.SmartSimplePipesClientMedium(
 
179
            input, output, 'base')
191
180
        # send some bytes to ensure disconnecting after activity still does not
192
181
        # close.
193
182
        client_medium._accept_bytes('abc')
200
189
        # accept_bytes writes to.
201
190
        input = StringIO()
202
191
        output = StringIO()
203
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
192
        client_medium = medium.SmartSimplePipesClientMedium(
 
193
            input, output, 'base')
204
194
        client_medium._accept_bytes('abc')
205
195
        client_medium.disconnect()
206
196
        client_medium._accept_bytes('abc')
211
201
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
212
202
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
213
203
        # does nothing.
214
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
204
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
215
205
        client_medium.disconnect()
216
206
 
217
207
    def test_simple_pipes_client_can_always_read(self):
218
208
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
219
209
        # always tries to read from the underlying pipe.
220
210
        input = StringIO('abcdef')
221
 
        client_medium = medium.SmartSimplePipesClientMedium(input, None)
 
211
        client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
222
212
        self.assertEqual('abc', client_medium.read_bytes(3))
223
213
        client_medium.disconnect()
224
214
        self.assertEqual('def', client_medium.read_bytes(3))
233
223
        flush_calls = []
234
224
        def logging_flush(): flush_calls.append('flush')
235
225
        output.flush = logging_flush
236
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
226
        client_medium = medium.SmartSimplePipesClientMedium(
 
227
            input, output, 'base')
237
228
        # this call is here to ensure we only flush once, not on every
238
229
        # _accept_bytes call.
239
230
        client_medium._accept_bytes('abc')
253
244
        # having vendor be invalid means that if it tries to connect via the
254
245
        # vendor it will blow up.
255
246
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
256
 
            username=None, password=None, vendor="not a vendor",
 
247
            username=None, password=None, base='base', vendor="not a vendor",
257
248
            bzr_remote_path='bzr')
258
249
        sock.close()
259
250
 
263
254
        output = StringIO()
264
255
        vendor = StringIOSSHVendor(StringIO(), output)
265
256
        client_medium = medium.SmartSSHClientMedium(
266
 
            'a hostname', 'a port', 'a username', 'a password', vendor, 'bzr')
 
257
            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
 
258
            'bzr')
267
259
        client_medium._accept_bytes('abc')
268
260
        self.assertEqual('abc', output.getvalue())
269
261
        self.assertEqual([('connect_ssh', 'a username', 'a password',
284
276
        client_medium = self.callDeprecated(
285
277
            ['bzr_remote_path is required as of bzr 0.92'],
286
278
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
287
 
            'a password', vendor)
 
279
            'a password', 'base', vendor)
288
280
        client_medium._accept_bytes('abc')
289
281
        self.assertEqual('abc', output.getvalue())
290
282
        self.assertEqual([('connect_ssh', 'a username', 'a password',
298
290
        output = StringIO()
299
291
        vendor = StringIOSSHVendor(StringIO(), output)
300
292
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
301
 
            'a username', 'a password', vendor, bzr_remote_path='fugly')
 
293
            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
302
294
        client_medium._accept_bytes('abc')
303
295
        self.assertEqual('abc', output.getvalue())
304
296
        self.assertEqual([('connect_ssh', 'a username', 'a password',
312
304
        input = StringIO()
313
305
        output = StringIO()
314
306
        vendor = StringIOSSHVendor(input, output)
315
 
        client_medium = medium.SmartSSHClientMedium('a hostname',
316
 
                                                    vendor=vendor,
317
 
                                                    bzr_remote_path='bzr')
 
307
        client_medium = medium.SmartSSHClientMedium(
 
308
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
318
309
        client_medium._accept_bytes('abc')
319
310
        client_medium.disconnect()
320
311
        self.assertTrue(input.closed)
334
325
        input = StringIO()
335
326
        output = StringIO()
336
327
        vendor = StringIOSSHVendor(input, output)
337
 
        client_medium = medium.SmartSSHClientMedium('a hostname',
338
 
            vendor=vendor, bzr_remote_path='bzr')
 
328
        client_medium = medium.SmartSSHClientMedium(
 
329
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
339
330
        client_medium._accept_bytes('abc')
340
331
        client_medium.disconnect()
341
332
        # the disconnect has closed output, so we need a new output for the
363
354
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
364
355
        # Doing a disconnect on a new (and thus unconnected) SSH medium
365
356
        # does not fail.  It's ok to disconnect an unconnected medium.
366
 
        client_medium = medium.SmartSSHClientMedium(None,
367
 
                                                    bzr_remote_path='bzr')
 
357
        client_medium = medium.SmartSSHClientMedium(
 
358
            None, base='base', bzr_remote_path='bzr')
368
359
        client_medium.disconnect()
369
360
 
370
361
    def test_ssh_client_raises_on_read_when_not_connected(self):
371
362
        # Doing a read on a new (and thus unconnected) SSH medium raises
372
363
        # MediumNotConnected.
373
 
        client_medium = medium.SmartSSHClientMedium(None,
374
 
                                                    bzr_remote_path='bzr')
 
364
        client_medium = medium.SmartSSHClientMedium(
 
365
            None, base='base', bzr_remote_path='bzr')
375
366
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
376
367
                          0)
377
368
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
388
379
        def logging_flush(): flush_calls.append('flush')
389
380
        output.flush = logging_flush
390
381
        vendor = StringIOSSHVendor(input, output)
391
 
        client_medium = medium.SmartSSHClientMedium('a hostname',
392
 
                                                    vendor=vendor,
393
 
                                                    bzr_remote_path='bzr')
 
382
        client_medium = medium.SmartSSHClientMedium(
 
383
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
394
384
        # this call is here to ensure we only flush once, not on every
395
385
        # _accept_bytes call.
396
386
        client_medium._accept_bytes('abc')
404
394
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
405
395
        sock.bind(('127.0.0.1', 0))
406
396
        unopened_port = sock.getsockname()[1]
407
 
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', unopened_port)
 
397
        client_medium = medium.SmartTCPClientMedium(
 
398
            '127.0.0.1', unopened_port, 'base')
408
399
        sock.close()
409
400
 
410
401
    def test_tcp_client_connects_on_first_use(self):
438
429
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
439
430
        # Doing a disconnect on a new (and thus unconnected) TCP medium
440
431
        # does not fail.  It's ok to disconnect an unconnected medium.
441
 
        client_medium = medium.SmartTCPClientMedium(None, None)
 
432
        client_medium = medium.SmartTCPClientMedium(None, None, None)
442
433
        client_medium.disconnect()
443
434
 
444
435
    def test_tcp_client_raises_on_read_when_not_connected(self):
445
436
        # Doing a read on a new (and thus unconnected) TCP medium raises
446
437
        # MediumNotConnected.
447
 
        client_medium = medium.SmartTCPClientMedium(None, None)
 
438
        client_medium = medium.SmartTCPClientMedium(None, None, None)
448
439
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
449
440
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
450
441
 
470
461
    def test_tcp_client_host_unknown_connection_error(self):
471
462
        self.requireFeature(InvalidHostnameFeature)
472
463
        client_medium = medium.SmartTCPClientMedium(
473
 
            'non_existent.invalid', 4155)
 
464
            'non_existent.invalid', 4155, 'base')
474
465
        self.assertRaises(
475
466
            errors.ConnectionError, client_medium._ensure_connection)
476
467
 
488
479
        # WritingCompleted to prevent bad assumptions on stream environments
489
480
        # breaking the needs of message-based environments.
490
481
        output = StringIO()
491
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
482
        client_medium = medium.SmartSimplePipesClientMedium(
 
483
            None, output, 'base')
492
484
        request = medium.SmartClientStreamMediumRequest(client_medium)
493
485
        request.finished_writing()
494
486
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
499
491
        # and checking that the pipes get the data.
500
492
        input = StringIO()
501
493
        output = StringIO()
502
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
494
        client_medium = medium.SmartSimplePipesClientMedium(
 
495
            input, output, 'base')
503
496
        request = medium.SmartClientStreamMediumRequest(client_medium)
504
497
        request.accept_bytes('123')
505
498
        request.finished_writing()
511
504
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
512
505
        # the current request to the new SmartClientStreamMediumRequest
513
506
        output = StringIO()
514
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
507
        client_medium = medium.SmartSimplePipesClientMedium(
 
508
            None, output, 'base')
515
509
        request = medium.SmartClientStreamMediumRequest(client_medium)
516
510
        self.assertIs(client_medium._current_request, request)
517
511
 
519
513
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
520
514
        # a non-None _current_request raises TooManyConcurrentRequests.
521
515
        output = StringIO()
522
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
516
        client_medium = medium.SmartSimplePipesClientMedium(
 
517
            None, output, 'base')
523
518
        client_medium._current_request = "a"
524
519
        self.assertRaises(errors.TooManyConcurrentRequests,
525
520
            medium.SmartClientStreamMediumRequest, client_medium)
528
523
        # calling finished_reading clears the current request from the requests
529
524
        # medium
530
525
        output = StringIO()
531
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
526
        client_medium = medium.SmartSimplePipesClientMedium(
 
527
            None, output, 'base')
532
528
        request = medium.SmartClientStreamMediumRequest(client_medium)
533
529
        request.finished_writing()
534
530
        request.finished_reading()
537
533
    def test_finished_read_before_finished_write_errors(self):
538
534
        # calling finished_reading before calling finished_writing triggers a
539
535
        # WritingNotComplete error.
540
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
536
        client_medium = medium.SmartSimplePipesClientMedium(
 
537
            None, None, 'base')
541
538
        request = medium.SmartClientStreamMediumRequest(client_medium)
542
539
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
543
540
        
550
547
        # smoke tests.
551
548
        input = StringIO('321')
552
549
        output = StringIO()
553
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
550
        client_medium = medium.SmartSimplePipesClientMedium(
 
551
            input, output, 'base')
554
552
        request = medium.SmartClientStreamMediumRequest(client_medium)
555
553
        request.finished_writing()
556
554
        self.assertEqual('321', request.read_bytes(3))
563
561
        # WritingNotComplete error because the Smart protocol is designed to be
564
562
        # compatible with strict message based protocols like HTTP where the
565
563
        # request cannot be submitted until the writing has completed.
566
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
564
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
567
565
        request = medium.SmartClientStreamMediumRequest(client_medium)
568
566
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
569
567
 
572
570
        # ReadingCompleted to prevent bad assumptions on stream environments
573
571
        # breaking the needs of message-based environments.
574
572
        output = StringIO()
575
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
573
        client_medium = medium.SmartSimplePipesClientMedium(
 
574
            None, output, 'base')
576
575
        request = medium.SmartClientStreamMediumRequest(client_medium)
577
576
        request.finished_writing()
578
577
        request.finished_reading()
610
609
        self.accepted_bytes = ''
611
610
        self._finished_reading = False
612
611
        self.expected_bytes = expected_bytes
613
 
        self.excess_buffer = ''
 
612
        self.unused_data = ''
614
613
 
615
614
    def accept_bytes(self, bytes):
616
615
        self.accepted_bytes += bytes
617
616
        if self.accepted_bytes.startswith(self.expected_bytes):
618
617
            self._finished_reading = True
619
 
            self.excess_buffer = self.accepted_bytes[len(self.expected_bytes):]
 
618
            self.unused_data = self.accepted_bytes[len(self.expected_bytes):]
620
619
 
621
620
    def next_read_size(self):
622
621
        if self._finished_reading:
944
943
        # A request that starts with "bzr request 2\n" is version two.
945
944
        server_protocol = self.build_protocol_socket('bzr request 2\n')
946
945
        self.assertProtocolTwo(server_protocol)
 
946
 
 
947
 
 
948
class TestGetProtocolFactoryForBytes(tests.TestCase):
 
949
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
 
950
    should use to decode a given request.  Any bytes not part of the version
 
951
    marker string (and thus part of the actual request) are returned alongside
 
952
    the protocol factory.
 
953
    """
 
954
 
 
955
    def test_version_three(self):
 
956
        result = medium._get_protocol_factory_for_bytes(
 
957
            'bzr message 3 (bzr 1.6)\nextra bytes')
 
958
        protocol_factory, remainder = result
 
959
        self.assertEqual(
 
960
            protocol.build_server_protocol_three, protocol_factory)
 
961
        self.assertEqual('extra bytes', remainder)
 
962
        
 
963
    def test_version_two(self):
 
964
        result = medium._get_protocol_factory_for_bytes(
 
965
            'bzr request 2\nextra bytes')
 
966
        protocol_factory, remainder = result
 
967
        self.assertEqual(
 
968
            protocol.SmartServerRequestProtocolTwo, protocol_factory)
 
969
        self.assertEqual('extra bytes', remainder)
 
970
        
 
971
    def test_version_one(self):
 
972
        """Version one requests have no version markers."""
 
973
        result = medium._get_protocol_factory_for_bytes('anything\n')
 
974
        protocol_factory, remainder = result
 
975
        self.assertEqual(
 
976
            protocol.SmartServerRequestProtocolOne, protocol_factory)
 
977
        self.assertEqual('anything\n', remainder)
947
978
        
948
979
 
949
980
class TestSmartTCPServer(tests.TestCase):
958
989
            def get_bytes(self, path):
959
990
                raise Exception("some random exception from inside server")
960
991
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
961
 
        smart_server.start_background_thread()
 
992
        smart_server.start_background_thread('-' + self.id())
962
993
        try:
963
994
            transport = remote.RemoteTCPTransport(smart_server.get_url())
964
995
            try:
994
1025
            self.real_backing_transport = self.backing_transport
995
1026
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
996
1027
        self.server = server.SmartTCPServer(self.backing_transport)
997
 
        self.server.start_background_thread()
 
1028
        self.server.start_background_thread('-' + self.id())
998
1029
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
999
1030
        self.addCleanup(self.tearDownServer)
1000
1031
 
1130
1161
    def test_server_started_hook_memory(self):
1131
1162
        """The server_started hook fires when the server is started."""
1132
1163
        self.hook_calls = []
1133
 
        server.SmartTCPServer.hooks.install_hook('server_started',
1134
 
            self.capture_server_call)
 
1164
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1165
            self.capture_server_call, None)
1135
1166
        self.setUpServer()
1136
1167
        # at this point, the server will be starting a thread up.
1137
1168
        # there is no indicator at the moment, so bodge it by doing a request.
1144
1175
    def test_server_started_hook_file(self):
1145
1176
        """The server_started hook fires when the server is started."""
1146
1177
        self.hook_calls = []
1147
 
        server.SmartTCPServer.hooks.install_hook('server_started',
1148
 
            self.capture_server_call)
 
1178
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1179
            self.capture_server_call, None)
1149
1180
        self.setUpServer(backing_transport=get_transport("."))
1150
1181
        # at this point, the server will be starting a thread up.
1151
1182
        # there is no indicator at the moment, so bodge it by doing a request.
1160
1191
    def test_server_stopped_hook_simple_memory(self):
1161
1192
        """The server_stopped hook fires when the server is stopped."""
1162
1193
        self.hook_calls = []
1163
 
        server.SmartTCPServer.hooks.install_hook('server_stopped',
1164
 
            self.capture_server_call)
 
1194
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1195
            self.capture_server_call, None)
1165
1196
        self.setUpServer()
1166
1197
        result = [([self.backing_transport.base], self.transport.base)]
1167
1198
        # check the stopping message isn't emitted up front.
1177
1208
    def test_server_stopped_hook_simple_file(self):
1178
1209
        """The server_stopped hook fires when the server is stopped."""
1179
1210
        self.hook_calls = []
1180
 
        server.SmartTCPServer.hooks.install_hook('server_stopped',
1181
 
            self.capture_server_call)
 
1211
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1212
            self.capture_server_call, None)
1182
1213
        self.setUpServer(backing_transport=get_transport("."))
1183
1214
        result = [(
1184
1215
            [self.backing_transport.base, self.backing_transport.external_url()]
1202
1233
    and the request dispatching.
1203
1234
 
1204
1235
    Note: these tests are rudimentary versions of the command object tests in
1205
 
    test_remote.py.
 
1236
    test_smart.py.
1206
1237
    """
1207
1238
        
1208
1239
    def test_hello(self):
1341
1372
        
1342
1373
    def test_use_connection_factory(self):
1343
1374
        # We want to be able to pass a client as a parameter to RemoteTransport.
1344
 
        input = StringIO("ok\n3\nbardone\n")
 
1375
        input = StringIO('ok\n3\nbardone\n')
1345
1376
        output = StringIO()
1346
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1377
        client_medium = medium.SmartSimplePipesClientMedium(
 
1378
            input, output, 'base')
1347
1379
        transport = remote.RemoteTransport(
1348
1380
            'bzr://localhost/', medium=client_medium)
 
1381
        # Disable version detection.
 
1382
        client_medium._protocol_version = 1
1349
1383
 
1350
1384
        # We want to make sure the client is used when the first remote
1351
1385
        # method is called.  No data should have been sent, or read.
1352
1386
        self.assertEqual(0, input.tell())
1353
1387
        self.assertEqual('', output.getvalue())
1354
1388
 
1355
 
        # Now call a method that should result in a single request : as the
 
1389
        # Now call a method that should result in one request: as the
1356
1390
        # transport makes its own protocol instances, we check on the wire.
1357
1391
        # XXX: TODO: give the transport a protocol factory, which can make
1358
1392
        # an instrumented protocol for us.
1363
1397
 
1364
1398
    def test__translate_error_readonly(self):
1365
1399
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1366
 
        client_medium = medium.SmartClientMedium()
 
1400
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
1367
1401
        transport = remote.RemoteTransport(
1368
1402
            'bzr://localhost/', medium=client_medium)
1369
1403
        self.assertRaises(errors.TransportNotPossible,
1385
1419
    Subclasses can override client_protocol_class and server_protocol_class.
1386
1420
    """
1387
1421
 
 
1422
    request_encoder = None
 
1423
    response_decoder = None
 
1424
    server_protocol_class = None
1388
1425
    client_protocol_class = None
1389
 
    server_protocol_class = None
1390
1426
 
1391
 
    def make_client_protocol(self):
 
1427
    def make_client_protocol_and_output(self, input_bytes=None):
 
1428
        """
 
1429
        :returns: a Request
 
1430
        """
 
1431
        # This is very similar to
 
1432
        # bzrlib.smart.client._SmartClient._build_client_protocol
 
1433
        # XXX: make this use _SmartClient!
 
1434
        if input_bytes is None:
 
1435
            input = StringIO()
 
1436
        else:
 
1437
            input = StringIO(input_bytes)
 
1438
        output = StringIO()
1392
1439
        client_medium = medium.SmartSimplePipesClientMedium(
1393
 
            StringIO(), StringIO())
1394
 
        return self.client_protocol_class(client_medium.get_request())
 
1440
            input, output, 'base')
 
1441
        request = client_medium.get_request()
 
1442
        if self.client_protocol_class is not None:
 
1443
            client_protocol = self.client_protocol_class(request)
 
1444
            return client_protocol, client_protocol, output
 
1445
        else:
 
1446
            self.assertNotEqual(None, self.request_encoder)
 
1447
            self.assertNotEqual(None, self.response_decoder)
 
1448
            requester = self.request_encoder(request)
 
1449
            response_handler = message.ConventionalResponseHandler()
 
1450
            response_protocol = self.response_decoder(
 
1451
                response_handler, expect_version_marker=True)
 
1452
            response_handler.setProtoAndMediumRequest(
 
1453
                response_protocol, request)
 
1454
            return requester, response_handler, output
 
1455
 
 
1456
    def make_client_protocol(self, input_bytes=None):
 
1457
        result = self.make_client_protocol_and_output(input_bytes=input_bytes)
 
1458
        requester, response_handler, output = result
 
1459
        return requester, response_handler
1395
1460
 
1396
1461
    def make_server_protocol(self):
1397
1462
        out_stream = StringIO()
1398
1463
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1399
1464
        return smart_protocol, out_stream
1400
1465
 
 
1466
    def setUp(self):
 
1467
        super(TestSmartProtocol, self).setUp()
 
1468
        self.response_marker = getattr(
 
1469
            self.client_protocol_class, 'response_marker', None)
 
1470
        self.request_marker = getattr(
 
1471
            self.client_protocol_class, 'request_marker', None)
 
1472
 
1401
1473
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1402
 
        client):
 
1474
        requester):
1403
1475
        """Check that smart (de)serialises offsets as expected.
1404
1476
        
1405
1477
        We check both serialisation and deserialisation at the same time
1414
1486
        readv_cmd = vfs.ReadvRequest(None, '/')
1415
1487
        offsets = readv_cmd._deserialise_offsets(expected_serialised)
1416
1488
        self.assertEqual(expected_offsets, offsets)
1417
 
        serialised = client._serialise_offsets(offsets)
 
1489
        serialised = requester._serialise_offsets(offsets)
1418
1490
        self.assertEqual(expected_serialised, serialised)
1419
1491
 
1420
1492
    def build_protocol_waiting_for_body(self):
1421
1493
        smart_protocol, out_stream = self.make_server_protocol()
1422
 
        smart_protocol.has_dispatched = True
 
1494
        smart_protocol._has_dispatched = True
1423
1495
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
1424
1496
            None, _mod_request.request_handlers, '/')
1425
1497
        class FakeCommand(object):
1448
1520
                _mod_request.SuccessfulSmartServerResponse(input_tuple))
1449
1521
            self.assertEqual(expected_bytes, server_output.getvalue())
1450
1522
        # check the decoding of the client smart_protocol from expected_bytes:
1451
 
        input = StringIO(expected_bytes)
1452
 
        output = StringIO()
1453
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1454
 
        request = client_medium.get_request()
1455
 
        smart_protocol = self.client_protocol_class(request)
1456
 
        smart_protocol.call('foo')
1457
 
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
 
1523
        requester, response_handler = self.make_client_protocol(expected_bytes)
 
1524
        requester.call('foo')
 
1525
        self.assertEqual(expected_tuple, response_handler.read_response_tuple())
1458
1526
 
1459
1527
 
1460
1528
class CommonSmartProtocolTestMixin(object):
1461
1529
 
1462
 
    def test_errors_are_logged(self):
1463
 
        """If an error occurs during testing, it is logged to the test log."""
1464
 
        smart_protocol, out_stream = self.make_server_protocol()
1465
 
        # This triggers a "bad request" error.
1466
 
        smart_protocol.accept_bytes('abc\n')
1467
 
        test_log = self._get_log(keep_log_file=True)
1468
 
        self.assertContainsRe(test_log, 'Traceback')
1469
 
        self.assertContainsRe(test_log, 'SmartProtocolError')
1470
 
 
1471
1530
    def test_connection_closed_reporting(self):
1472
 
        smart_protocol = self.make_client_protocol()
1473
 
        smart_protocol.call('hello')
 
1531
        requester, response_handler = self.make_client_protocol()
 
1532
        requester.call('hello')
1474
1533
        ex = self.assertRaises(errors.ConnectionReset,
1475
 
            smart_protocol.read_response_tuple)
 
1534
            response_handler.read_response_tuple)
1476
1535
        self.assertEqual("Connection closed: "
1477
1536
            "please check connectivity and permissions "
1478
1537
            "(and try -Dhpss if further diagnosis is required)", str(ex))
1484
1543
        one with the order of reads not increasing (an out of order read), and
1485
1544
        one that should coalesce.
1486
1545
        """
1487
 
        client_protocol = self.make_client_protocol()
1488
 
        self.assertOffsetSerialisation([], '', client_protocol)
1489
 
        self.assertOffsetSerialisation([(1,2)], '1,2', client_protocol)
 
1546
        requester, response_handler = self.make_client_protocol()
 
1547
        self.assertOffsetSerialisation([], '', requester)
 
1548
        self.assertOffsetSerialisation([(1,2)], '1,2', requester)
1490
1549
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1491
 
            client_protocol)
 
1550
            requester)
1492
1551
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1493
 
            '1,2\n3,4\n100,200', client_protocol)
1494
 
 
1495
 
 
1496
 
class TestSmartProtocolOne(TestSmartProtocol, CommonSmartProtocolTestMixin):
1497
 
    """Tests for the smart protocol version one."""
 
1552
            '1,2\n3,4\n100,200', requester)
 
1553
 
 
1554
 
 
1555
class TestVersionOneFeaturesInProtocolOne(
 
1556
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1557
    """Tests for version one smart protocol features as implemeted by version
 
1558
    one."""
1498
1559
 
1499
1560
    client_protocol_class = protocol.SmartClientRequestProtocolOne
1500
1561
    server_protocol_class = protocol.SmartServerRequestProtocolOne
1501
1562
 
1502
1563
    def test_construct_version_one_server_protocol(self):
1503
1564
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
1504
 
        self.assertEqual('', smart_protocol.excess_buffer)
 
1565
        self.assertEqual('', smart_protocol.unused_data)
1505
1566
        self.assertEqual('', smart_protocol.in_buffer)
1506
 
        self.assertFalse(smart_protocol.has_dispatched)
 
1567
        self.assertFalse(smart_protocol._has_dispatched)
1507
1568
        self.assertEqual(1, smart_protocol.next_read_size())
1508
1569
 
1509
1570
    def test_construct_version_one_client_protocol(self):
1510
1571
        # we can construct a client protocol from a client medium request
1511
1572
        output = StringIO()
1512
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1573
        client_medium = medium.SmartSimplePipesClientMedium(
 
1574
            None, output, 'base')
1513
1575
        request = client_medium.get_request()
1514
1576
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
1515
1577
 
1523
1585
        self.assertEqual(
1524
1586
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
1525
1587
            out_stream.getvalue())
1526
 
        self.assertTrue(smart_protocol.has_dispatched)
 
1588
        self.assertTrue(smart_protocol._has_dispatched)
1527
1589
        self.assertEqual(0, smart_protocol.next_read_size())
1528
1590
 
1529
1591
    def test_accept_body_bytes_to_protocol(self):
1546
1608
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1547
1609
        self.assertEqual(0, smart_protocol.next_read_size())
1548
1610
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1549
 
        self.assertEqual('', smart_protocol.excess_buffer)
 
1611
        self.assertEqual('', smart_protocol.unused_data)
1550
1612
        self.assertEqual('', smart_protocol.in_buffer)
1551
1613
 
1552
1614
    def test_accept_excess_bytes_are_preserved(self):
1555
1617
            None, out_stream.write)
1556
1618
        smart_protocol.accept_bytes('hello\nhello\n')
1557
1619
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1558
 
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1620
        self.assertEqual("hello\n", smart_protocol.unused_data)
1559
1621
        self.assertEqual("", smart_protocol.in_buffer)
1560
1622
 
1561
1623
    def test_accept_excess_bytes_after_body(self):
1562
1624
        protocol = self.build_protocol_waiting_for_body()
1563
1625
        protocol.accept_bytes('7\nabcdefgdone\nX')
1564
1626
        self.assertTrue(self.end_received)
1565
 
        self.assertEqual("X", protocol.excess_buffer)
 
1627
        self.assertEqual("X", protocol.unused_data)
1566
1628
        self.assertEqual("", protocol.in_buffer)
1567
1629
        protocol.accept_bytes('Y')
1568
 
        self.assertEqual("XY", protocol.excess_buffer)
 
1630
        self.assertEqual("XY", protocol.unused_data)
1569
1631
        self.assertEqual("", protocol.in_buffer)
1570
1632
 
1571
1633
    def test_accept_excess_bytes_after_dispatch(self):
1575
1637
        smart_protocol.accept_bytes('hello\n')
1576
1638
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1577
1639
        smart_protocol.accept_bytes('hel')
1578
 
        self.assertEqual("hel", smart_protocol.excess_buffer)
 
1640
        self.assertEqual("hel", smart_protocol.unused_data)
1579
1641
        smart_protocol.accept_bytes('lo\n')
1580
 
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1642
        self.assertEqual("hello\n", smart_protocol.unused_data)
1581
1643
        self.assertEqual("", smart_protocol.in_buffer)
1582
1644
 
1583
1645
    def test__send_response_sets_finished_reading(self):
1607
1669
        # the error if the response is a non-understood version.
1608
1670
        input = StringIO('ok\x012\n')
1609
1671
        output = StringIO()
1610
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1672
        client_medium = medium.SmartSimplePipesClientMedium(
 
1673
            input, output, 'base')
1611
1674
        request = client_medium.get_request()
1612
1675
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1613
1676
        self.assertEqual(2, smart_protocol.query_version())
1630
1693
        expected_bytes = "foo\n7\nabcdefgdone\n"
1631
1694
        input = StringIO("\n")
1632
1695
        output = StringIO()
1633
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1696
        client_medium = medium.SmartSimplePipesClientMedium(
 
1697
            input, output, 'base')
1634
1698
        request = client_medium.get_request()
1635
1699
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1636
1700
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1642
1706
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
1643
1707
        input = StringIO("\n")
1644
1708
        output = StringIO()
1645
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1709
        client_medium = medium.SmartSimplePipesClientMedium(
 
1710
            input, output, 'base')
1646
1711
        request = client_medium.get_request()
1647
1712
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1648
1713
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1652
1717
            server_bytes):
1653
1718
        input = StringIO(server_bytes)
1654
1719
        output = StringIO()
1655
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1720
        client_medium = medium.SmartSimplePipesClientMedium(
 
1721
            input, output, 'base')
1656
1722
        request = client_medium.get_request()
1657
1723
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1658
1724
        smart_protocol.call('foo')
1690
1756
        server_bytes = "ok\n7\n1234567done\n"
1691
1757
        input = StringIO(server_bytes)
1692
1758
        output = StringIO()
1693
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1759
        client_medium = medium.SmartSimplePipesClientMedium(
 
1760
            input, output, 'base')
1694
1761
        request = client_medium.get_request()
1695
1762
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1696
1763
        smart_protocol.call('foo')
1707
1774
        server_bytes = "ok\n7\n1234567done\n"
1708
1775
        input = StringIO(server_bytes)
1709
1776
        output = StringIO()
1710
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1777
        client_medium = medium.SmartSimplePipesClientMedium(
 
1778
            input, output, 'base')
1711
1779
        request = client_medium.get_request()
1712
1780
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1713
1781
        smart_protocol.call('foo')
1724
1792
        server_bytes = "ok\n7\n1234567done\n"
1725
1793
        input = StringIO(server_bytes)
1726
1794
        output = StringIO()
1727
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1795
        client_medium = medium.SmartSimplePipesClientMedium(
 
1796
            input, output, 'base')
1728
1797
        request = client_medium.get_request()
1729
1798
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1730
1799
        smart_protocol.call('foo')
1735
1804
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1736
1805
 
1737
1806
 
1738
 
class TestSmartProtocolTwo(TestSmartProtocol, CommonSmartProtocolTestMixin):
1739
 
    """Tests for the smart protocol version two.
1740
 
 
1741
 
    This test case is mostly the same as TestSmartProtocolOne.
 
1807
class TestVersionOneFeaturesInProtocolTwo(
 
1808
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1809
    """Tests for version one smart protocol features as implemeted by version
 
1810
    two.
1742
1811
    """
1743
1812
 
1744
1813
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
1746
1815
 
1747
1816
    def test_construct_version_two_server_protocol(self):
1748
1817
        smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
1749
 
        self.assertEqual('', smart_protocol.excess_buffer)
 
1818
        self.assertEqual('', smart_protocol.unused_data)
1750
1819
        self.assertEqual('', smart_protocol.in_buffer)
1751
 
        self.assertFalse(smart_protocol.has_dispatched)
 
1820
        self.assertFalse(smart_protocol._has_dispatched)
1752
1821
        self.assertEqual(1, smart_protocol.next_read_size())
1753
1822
 
1754
1823
    def test_construct_version_two_client_protocol(self):
1755
1824
        # we can construct a client protocol from a client medium request
1756
1825
        output = StringIO()
1757
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1826
        client_medium = medium.SmartSimplePipesClientMedium(
 
1827
            None, output, 'base')
1758
1828
        request = client_medium.get_request()
1759
1829
        client_protocol = protocol.SmartClientRequestProtocolTwo(request)
1760
1830
 
 
1831
    def test_accept_bytes_of_bad_request_to_protocol(self):
 
1832
        out_stream = StringIO()
 
1833
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1834
        smart_protocol.accept_bytes('abc')
 
1835
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1836
        smart_protocol.accept_bytes('\n')
 
1837
        self.assertEqual(
 
1838
            self.response_marker +
 
1839
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1840
            out_stream.getvalue())
 
1841
        self.assertTrue(smart_protocol._has_dispatched)
 
1842
        self.assertEqual(0, smart_protocol.next_read_size())
 
1843
 
 
1844
    def test_accept_body_bytes_to_protocol(self):
 
1845
        protocol = self.build_protocol_waiting_for_body()
 
1846
        self.assertEqual(6, protocol.next_read_size())
 
1847
        protocol.accept_bytes('7\nabc')
 
1848
        self.assertEqual(9, protocol.next_read_size())
 
1849
        protocol.accept_bytes('defgd')
 
1850
        protocol.accept_bytes('one\n')
 
1851
        self.assertEqual(0, protocol.next_read_size())
 
1852
        self.assertTrue(self.end_received)
 
1853
 
 
1854
    def test_accept_request_and_body_all_at_once(self):
 
1855
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1856
        mem_transport = memory.MemoryTransport()
 
1857
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1858
        out_stream = StringIO()
 
1859
        smart_protocol = self.server_protocol_class(
 
1860
            mem_transport, out_stream.write)
 
1861
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1862
        self.assertEqual(0, smart_protocol.next_read_size())
 
1863
        self.assertEqual(self.response_marker +
 
1864
                         'success\nreadv\n3\ndefdone\n',
 
1865
                         out_stream.getvalue())
 
1866
        self.assertEqual('', smart_protocol.unused_data)
 
1867
        self.assertEqual('', smart_protocol.in_buffer)
 
1868
 
 
1869
    def test_accept_excess_bytes_are_preserved(self):
 
1870
        out_stream = StringIO()
 
1871
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1872
        smart_protocol.accept_bytes('hello\nhello\n')
 
1873
        self.assertEqual(self.response_marker + "success\nok\x012\n",
 
1874
                         out_stream.getvalue())
 
1875
        self.assertEqual("hello\n", smart_protocol.unused_data)
 
1876
        self.assertEqual("", smart_protocol.in_buffer)
 
1877
 
 
1878
    def test_accept_excess_bytes_after_body(self):
 
1879
        # The excess bytes look like the start of another request.
 
1880
        server_protocol = self.build_protocol_waiting_for_body()
 
1881
        server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
 
1882
        self.assertTrue(self.end_received)
 
1883
        self.assertEqual(self.response_marker,
 
1884
                         server_protocol.unused_data)
 
1885
        self.assertEqual("", server_protocol.in_buffer)
 
1886
        server_protocol.accept_bytes('Y')
 
1887
        self.assertEqual(self.response_marker + "Y",
 
1888
                         server_protocol.unused_data)
 
1889
        self.assertEqual("", server_protocol.in_buffer)
 
1890
 
 
1891
    def test_accept_excess_bytes_after_dispatch(self):
 
1892
        out_stream = StringIO()
 
1893
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1894
        smart_protocol.accept_bytes('hello\n')
 
1895
        self.assertEqual(self.response_marker + "success\nok\x012\n",
 
1896
                         out_stream.getvalue())
 
1897
        smart_protocol.accept_bytes(self.request_marker + 'hel')
 
1898
        self.assertEqual(self.request_marker + "hel",
 
1899
                         smart_protocol.unused_data)
 
1900
        smart_protocol.accept_bytes('lo\n')
 
1901
        self.assertEqual(self.request_marker + "hello\n",
 
1902
                         smart_protocol.unused_data)
 
1903
        self.assertEqual("", smart_protocol.in_buffer)
 
1904
 
 
1905
    def test__send_response_sets_finished_reading(self):
 
1906
        smart_protocol = self.server_protocol_class(None, lambda x: None)
 
1907
        self.assertEqual(1, smart_protocol.next_read_size())
 
1908
        smart_protocol._send_response(
 
1909
            _mod_request.SuccessfulSmartServerResponse(('x',)))
 
1910
        self.assertEqual(0, smart_protocol.next_read_size())
 
1911
 
 
1912
    def test__send_response_errors_with_base_response(self):
 
1913
        """Ensure that only the Successful/Failed subclasses are used."""
 
1914
        smart_protocol = self.server_protocol_class(None, lambda x: None)
 
1915
        self.assertRaises(AttributeError, smart_protocol._send_response,
 
1916
            _mod_request.SmartServerResponse(('x',)))
 
1917
 
 
1918
    def test_query_version(self):
 
1919
        """query_version on a SmartClientProtocolTwo should return a number.
 
1920
        
 
1921
        The protocol provides the query_version because the domain level clients
 
1922
        may all need to be able to probe for capabilities.
 
1923
        """
 
1924
        # What we really want to test here is that SmartClientProtocolTwo calls
 
1925
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1926
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1927
        # the error if the response is a non-understood version.
 
1928
        input = StringIO(self.response_marker + 'success\nok\x012\n')
 
1929
        output = StringIO()
 
1930
        client_medium = medium.SmartSimplePipesClientMedium(
 
1931
            input, output, 'base')
 
1932
        request = client_medium.get_request()
 
1933
        smart_protocol = self.client_protocol_class(request)
 
1934
        self.assertEqual(2, smart_protocol.query_version())
 
1935
 
 
1936
    def test_client_call_empty_response(self):
 
1937
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1938
        # when the parsed line is an empty line, and results in a tuple with
 
1939
        # one element - an empty string.
 
1940
        self.assertServerToClientEncoding(
 
1941
            self.response_marker + 'success\n\n', ('', ), [(), ('', )])
 
1942
 
 
1943
    def test_client_call_three_element_response(self):
 
1944
        # protocol.call() can get back tuples of other lengths. A three element
 
1945
        # tuple should be unpacked as three strings.
 
1946
        self.assertServerToClientEncoding(
 
1947
            self.response_marker + 'success\na\x01b\x0134\n',
 
1948
            ('a', 'b', '34'),
 
1949
            [('a', 'b', '34')])
 
1950
 
 
1951
    def test_client_call_with_body_bytes_uploads(self):
 
1952
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
 
1953
        # wire.
 
1954
        expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
 
1955
        input = StringIO("\n")
 
1956
        output = StringIO()
 
1957
        client_medium = medium.SmartSimplePipesClientMedium(
 
1958
            input, output, 'base')
 
1959
        request = client_medium.get_request()
 
1960
        smart_protocol = self.client_protocol_class(request)
 
1961
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1962
        self.assertEqual(expected_bytes, output.getvalue())
 
1963
 
 
1964
    def test_client_call_with_body_readv_array(self):
 
1965
        # protocol.call_with_upload should encode the readv array and then
 
1966
        # length-prefix the bytes onto the wire.
 
1967
        expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
 
1968
        input = StringIO("\n")
 
1969
        output = StringIO()
 
1970
        client_medium = medium.SmartSimplePipesClientMedium(
 
1971
            input, output, 'base')
 
1972
        request = client_medium.get_request()
 
1973
        smart_protocol = self.client_protocol_class(request)
 
1974
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1975
        self.assertEqual(expected_bytes, output.getvalue())
 
1976
 
 
1977
    def test_client_read_body_bytes_all(self):
 
1978
        # read_body_bytes should decode the body bytes from the wire into
 
1979
        # a response.
 
1980
        expected_bytes = "1234567"
 
1981
        server_bytes = (self.response_marker +
 
1982
                        "success\nok\n7\n1234567done\n")
 
1983
        input = StringIO(server_bytes)
 
1984
        output = StringIO()
 
1985
        client_medium = medium.SmartSimplePipesClientMedium(
 
1986
            input, output, 'base')
 
1987
        request = client_medium.get_request()
 
1988
        smart_protocol = self.client_protocol_class(request)
 
1989
        smart_protocol.call('foo')
 
1990
        smart_protocol.read_response_tuple(True)
 
1991
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
 
1992
 
 
1993
    def test_client_read_body_bytes_incremental(self):
 
1994
        # test reading a few bytes at a time from the body
 
1995
        # XXX: possibly we should test dribbling the bytes into the stringio
 
1996
        # to make the state machine work harder: however, as we use the
 
1997
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
1998
        # that.
 
1999
        expected_bytes = "1234567"
 
2000
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
 
2001
        input = StringIO(server_bytes)
 
2002
        output = StringIO()
 
2003
        client_medium = medium.SmartSimplePipesClientMedium(
 
2004
            input, output, 'base')
 
2005
        request = client_medium.get_request()
 
2006
        smart_protocol = self.client_protocol_class(request)
 
2007
        smart_protocol.call('foo')
 
2008
        smart_protocol.read_response_tuple(True)
 
2009
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
2010
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
2011
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
2012
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
 
2013
 
 
2014
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
2015
        # cancelling the expected body needs to finish the request, but not
 
2016
        # read any more bytes.
 
2017
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
 
2018
        input = StringIO(server_bytes)
 
2019
        output = StringIO()
 
2020
        client_medium = medium.SmartSimplePipesClientMedium(
 
2021
            input, output, 'base')
 
2022
        request = client_medium.get_request()
 
2023
        smart_protocol = self.client_protocol_class(request)
 
2024
        smart_protocol.call('foo')
 
2025
        smart_protocol.read_response_tuple(True)
 
2026
        smart_protocol.cancel_read_body()
 
2027
        self.assertEqual(len(self.response_marker + 'success\nok\n'),
 
2028
                         input.tell())
 
2029
        self.assertRaises(
 
2030
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2031
 
 
2032
 
 
2033
class TestSmartProtocolTwoSpecificsMixin(object):
 
2034
 
1761
2035
    def assertBodyStreamSerialisation(self, expected_serialisation,
1762
2036
                                      body_stream):
1763
2037
        """Assert that body_stream is serialised as expected_serialisation."""
1810
2084
        self.assertBodyStreamSerialisation(expected_bytes, stream)
1811
2085
        self.assertBodyStreamRoundTrips(stream)
1812
2086
 
1813
 
    def test_accept_bytes_of_bad_request_to_protocol(self):
1814
 
        out_stream = StringIO()
1815
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1816
 
            None, out_stream.write)
1817
 
        smart_protocol.accept_bytes('abc')
1818
 
        self.assertEqual('abc', smart_protocol.in_buffer)
1819
 
        smart_protocol.accept_bytes('\n')
1820
 
        self.assertEqual(
1821
 
            protocol.RESPONSE_VERSION_TWO +
1822
 
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
1823
 
            out_stream.getvalue())
1824
 
        self.assertTrue(smart_protocol.has_dispatched)
1825
 
        self.assertEqual(0, smart_protocol.next_read_size())
1826
 
 
1827
 
    def test_accept_body_bytes_to_protocol(self):
1828
 
        protocol = self.build_protocol_waiting_for_body()
1829
 
        self.assertEqual(6, protocol.next_read_size())
1830
 
        protocol.accept_bytes('7\nabc')
1831
 
        self.assertEqual(9, protocol.next_read_size())
1832
 
        protocol.accept_bytes('defgd')
1833
 
        protocol.accept_bytes('one\n')
1834
 
        self.assertEqual(0, protocol.next_read_size())
1835
 
        self.assertTrue(self.end_received)
1836
 
 
1837
 
    def test_accept_request_and_body_all_at_once(self):
1838
 
        self._captureVar('BZR_NO_SMART_VFS', None)
1839
 
        mem_transport = memory.MemoryTransport()
1840
 
        mem_transport.put_bytes('foo', 'abcdefghij')
1841
 
        out_stream = StringIO()
1842
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(mem_transport,
1843
 
                out_stream.write)
1844
 
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1845
 
        self.assertEqual(0, smart_protocol.next_read_size())
1846
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO +
1847
 
                         'success\nreadv\n3\ndefdone\n',
1848
 
                         out_stream.getvalue())
1849
 
        self.assertEqual('', smart_protocol.excess_buffer)
1850
 
        self.assertEqual('', smart_protocol.in_buffer)
1851
 
 
1852
 
    def test_accept_excess_bytes_are_preserved(self):
1853
 
        out_stream = StringIO()
1854
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1855
 
            None, out_stream.write)
1856
 
        smart_protocol.accept_bytes('hello\nhello\n')
1857
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
1858
 
                         out_stream.getvalue())
1859
 
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
1860
 
        self.assertEqual("", smart_protocol.in_buffer)
1861
 
 
1862
 
    def test_accept_excess_bytes_after_body(self):
1863
 
        # The excess bytes look like the start of another request.
1864
 
        server_protocol = self.build_protocol_waiting_for_body()
1865
 
        server_protocol.accept_bytes(
1866
 
            '7\nabcdefgdone\n' + protocol.RESPONSE_VERSION_TWO)
1867
 
        self.assertTrue(self.end_received)
1868
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO,
1869
 
                         server_protocol.excess_buffer)
1870
 
        self.assertEqual("", server_protocol.in_buffer)
1871
 
        server_protocol.accept_bytes('Y')
1872
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "Y",
1873
 
                         server_protocol.excess_buffer)
1874
 
        self.assertEqual("", server_protocol.in_buffer)
1875
 
 
1876
 
    def test_accept_excess_bytes_after_dispatch(self):
1877
 
        out_stream = StringIO()
1878
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1879
 
            None, out_stream.write)
1880
 
        smart_protocol.accept_bytes('hello\n')
1881
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
1882
 
                         out_stream.getvalue())
1883
 
        smart_protocol.accept_bytes(protocol.REQUEST_VERSION_TWO + 'hel')
1884
 
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hel",
1885
 
                         smart_protocol.excess_buffer)
1886
 
        smart_protocol.accept_bytes('lo\n')
1887
 
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hello\n",
1888
 
                         smart_protocol.excess_buffer)
1889
 
        self.assertEqual("", smart_protocol.in_buffer)
1890
 
 
1891
 
    def test__send_response_sets_finished_reading(self):
1892
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1893
 
            None, lambda x: None)
1894
 
        self.assertEqual(1, smart_protocol.next_read_size())
1895
 
        smart_protocol._send_response(
1896
 
            _mod_request.SuccessfulSmartServerResponse(('x',)))
1897
 
        self.assertEqual(0, smart_protocol.next_read_size())
1898
 
 
1899
 
    def test__send_response_with_body_stream_sets_finished_reading(self):
1900
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1901
 
            None, lambda x: None)
1902
 
        self.assertEqual(1, smart_protocol.next_read_size())
1903
 
        smart_protocol._send_response(
1904
 
            _mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
1905
 
        self.assertEqual(0, smart_protocol.next_read_size())
1906
 
 
1907
 
    def test__send_response_errors_with_base_response(self):
1908
 
        """Ensure that only the Successful/Failed subclasses are used."""
1909
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1910
 
            None, lambda x: None)
1911
 
        self.assertRaises(AttributeError, smart_protocol._send_response,
1912
 
            _mod_request.SmartServerResponse(('x',)))
1913
 
 
1914
2087
    def test__send_response_includes_failure_marker(self):
1915
2088
        """FailedSmartServerResponse have 'failed\n' after the version."""
1916
2089
        out_stream = StringIO()
1931
2104
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
1932
2105
                         out_stream.getvalue())
1933
2106
 
1934
 
    def test_query_version(self):
1935
 
        """query_version on a SmartClientProtocolTwo should return a number.
1936
 
        
1937
 
        The protocol provides the query_version because the domain level clients
1938
 
        may all need to be able to probe for capabilities.
1939
 
        """
1940
 
        # What we really want to test here is that SmartClientProtocolTwo calls
1941
 
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1942
 
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
1943
 
        # the error if the response is a non-understood version.
1944
 
        input = StringIO(protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
1945
 
        output = StringIO()
1946
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1947
 
        request = client_medium.get_request()
1948
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1949
 
        self.assertEqual(2, smart_protocol.query_version())
1950
 
 
1951
 
    def test_client_call_empty_response(self):
1952
 
        # protocol.call() can get back an empty tuple as a response. This occurs
1953
 
        # when the parsed line is an empty line, and results in a tuple with
1954
 
        # one element - an empty string.
1955
 
        self.assertServerToClientEncoding(
1956
 
            protocol.RESPONSE_VERSION_TWO + 'success\n\n', ('', ), [(), ('', )])
1957
 
 
1958
 
    def test_client_call_three_element_response(self):
1959
 
        # protocol.call() can get back tuples of other lengths. A three element
1960
 
        # tuple should be unpacked as three strings.
1961
 
        self.assertServerToClientEncoding(
1962
 
            protocol.RESPONSE_VERSION_TWO + 'success\na\x01b\x0134\n',
1963
 
            ('a', 'b', '34'),
1964
 
            [('a', 'b', '34')])
1965
 
 
1966
 
    def test_client_call_with_body_bytes_uploads(self):
1967
 
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
1968
 
        # wire.
1969
 
        expected_bytes = protocol.REQUEST_VERSION_TWO + "foo\n7\nabcdefgdone\n"
1970
 
        input = StringIO("\n")
1971
 
        output = StringIO()
1972
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1973
 
        request = client_medium.get_request()
1974
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1975
 
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1976
 
        self.assertEqual(expected_bytes, output.getvalue())
1977
 
 
1978
 
    def test_client_call_with_body_readv_array(self):
1979
 
        # protocol.call_with_upload should encode the readv array and then
1980
 
        # length-prefix the bytes onto the wire.
1981
 
        expected_bytes = protocol.REQUEST_VERSION_TWO+"foo\n7\n1,2\n5,6done\n"
1982
 
        input = StringIO("\n")
1983
 
        output = StringIO()
1984
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1985
 
        request = client_medium.get_request()
1986
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1987
 
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1988
 
        self.assertEqual(expected_bytes, output.getvalue())
1989
 
 
1990
 
    def test_client_read_response_tuple_sets_response_status(self):
1991
 
        server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
1992
 
        input = StringIO(server_bytes)
1993
 
        output = StringIO()
1994
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1995
 
        request = client_medium.get_request()
1996
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1997
 
        smart_protocol.call('foo')
1998
 
        smart_protocol.read_response_tuple(False)
1999
 
        self.assertEqual(True, smart_protocol.response_status)
2000
 
 
2001
 
    def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
2002
 
        """read_response_tuple raises UnknownSmartMethod if the response is
2003
 
        says the server did not recognise the request.
2004
 
        """
2005
 
        server_bytes = (
2006
 
            protocol.RESPONSE_VERSION_TWO +
2007
 
            "failed\n" +
2008
 
            "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
2009
 
        input = StringIO(server_bytes)
2010
 
        output = StringIO()
2011
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
2012
 
        request = client_medium.get_request()
2013
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2014
 
        smart_protocol.call('foo')
2015
 
        self.assertRaises(
2016
 
            errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
2017
 
        self.assertEqual(False, smart_protocol.response_status)
2018
 
        # The request has been finished.  There is no body to read, and
2019
 
        # attempts to read one will fail.
2020
 
        self.assertRaises(
2021
 
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
2022
 
 
2023
 
    def test_client_read_body_bytes_all(self):
2024
 
        # read_body_bytes should decode the body bytes from the wire into
2025
 
        # a response.
2026
 
        expected_bytes = "1234567"
2027
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
2028
 
                        "success\nok\n7\n1234567done\n")
2029
 
        input = StringIO(server_bytes)
2030
 
        output = StringIO()
2031
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
2032
 
        request = client_medium.get_request()
2033
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2034
 
        smart_protocol.call('foo')
2035
 
        smart_protocol.read_response_tuple(True)
2036
 
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
2037
 
 
2038
 
    def test_client_read_body_bytes_incremental(self):
2039
 
        # test reading a few bytes at a time from the body
2040
 
        # XXX: possibly we should test dribbling the bytes into the stringio
2041
 
        # to make the state machine work harder: however, as we use the
2042
 
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
2043
 
        # that.
2044
 
        expected_bytes = "1234567"
2045
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
2046
 
                        "success\nok\n7\n1234567done\n")
2047
 
        input = StringIO(server_bytes)
2048
 
        output = StringIO()
2049
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
2050
 
        request = client_medium.get_request()
2051
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2052
 
        smart_protocol.call('foo')
2053
 
        smart_protocol.read_response_tuple(True)
2054
 
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
2055
 
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
2056
 
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
2057
 
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
2058
 
 
2059
 
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
2060
 
        # cancelling the expected body needs to finish the request, but not
2061
 
        # read any more bytes.
2062
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
2063
 
                        "success\nok\n7\n1234567done\n")
2064
 
        input = StringIO(server_bytes)
2065
 
        output = StringIO()
2066
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
2067
 
        request = client_medium.get_request()
2068
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2069
 
        smart_protocol.call('foo')
2070
 
        smart_protocol.read_response_tuple(True)
2071
 
        smart_protocol.cancel_read_body()
2072
 
        self.assertEqual(len(protocol.RESPONSE_VERSION_TWO + 'success\nok\n'),
2073
 
                         input.tell())
2074
 
        self.assertRaises(
2075
 
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2107
    def test__send_response_with_body_stream_sets_finished_reading(self):
 
2108
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
2109
            None, lambda x: None)
 
2110
        self.assertEqual(1, smart_protocol.next_read_size())
 
2111
        smart_protocol._send_response(
 
2112
            _mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
 
2113
        self.assertEqual(0, smart_protocol.next_read_size())
2076
2114
 
2077
2115
    def test_streamed_body_bytes(self):
2078
2116
        body_header = 'chunked\n'
2083
2121
                        body_terminator)
2084
2122
        input = StringIO(server_bytes)
2085
2123
        output = StringIO()
2086
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
2124
        client_medium = medium.SmartSimplePipesClientMedium(
 
2125
            input, output, 'base')
2087
2126
        request = client_medium.get_request()
2088
2127
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2089
2128
        smart_protocol.call('foo')
2103
2142
                        "success\nok\n" + body)
2104
2143
        input = StringIO(server_bytes)
2105
2144
        output = StringIO()
2106
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
2145
        client_medium = medium.SmartSimplePipesClientMedium(
 
2146
            input, output, 'base')
2107
2147
        smart_request = client_medium.get_request()
2108
2148
        smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request)
2109
2149
        smart_protocol.call('foo')
2114
2154
        stream = smart_protocol.read_streamed_body()
2115
2155
        self.assertEqual(expected_chunks, list(stream))
2116
2156
 
 
2157
    def test_client_read_response_tuple_sets_response_status(self):
 
2158
        server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
 
2159
        input = StringIO(server_bytes)
 
2160
        output = StringIO()
 
2161
        client_medium = medium.SmartSimplePipesClientMedium(
 
2162
            input, output, 'base')
 
2163
        request = client_medium.get_request()
 
2164
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2165
        smart_protocol.call('foo')
 
2166
        smart_protocol.read_response_tuple(False)
 
2167
        self.assertEqual(True, smart_protocol.response_status)
 
2168
 
 
2169
    def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
 
2170
        """read_response_tuple raises UnknownSmartMethod if the response says
 
2171
        the server did not recognise the request.
 
2172
        """
 
2173
        server_bytes = (
 
2174
            protocol.RESPONSE_VERSION_TWO +
 
2175
            "failed\n" +
 
2176
            "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
 
2177
        input = StringIO(server_bytes)
 
2178
        output = StringIO()
 
2179
        client_medium = medium.SmartSimplePipesClientMedium(
 
2180
            input, output, 'base')
 
2181
        request = client_medium.get_request()
 
2182
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2183
        smart_protocol.call('foo')
 
2184
        self.assertRaises(
 
2185
            errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
 
2186
        # The request has been finished.  There is no body to read, and
 
2187
        # attempts to read one will fail.
 
2188
        self.assertRaises(
 
2189
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2190
 
 
2191
 
 
2192
class TestSmartProtocolTwoSpecifics(
        TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
    """Tests for the aspects of smart protocol version two that only version
    two has.

    Tests involving body streams and success/failure markers therefore
    belong here.
    """

    # Both ends of the conversation use the version two protocol classes.
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
2202
 
 
2203
 
 
2204
class TestVersionOneFeaturesInProtocolThree(
    TestSmartProtocol, CommonSmartProtocolTestMixin):
    """Tests for version one smart protocol features as implemented by version
    three.
    """

    request_encoder = protocol.ProtocolThreeRequester
    response_decoder = protocol.ProtocolThreeDecoder
    # protocol.build_server_protocol_three is a plain function.  Assigning it
    # directly as a class attribute would make Python treat it as a method
    # (and pass self to it), so it is wrapped in staticmethod instead of
    # writing "server_protocol_class = protocol.build_server_protocol_three".
    server_protocol_class = staticmethod(protocol.build_server_protocol_three)

    def setUp(self):
        super(TestVersionOneFeaturesInProtocolThree, self).setUp()
        # Requests and responses both carry the version three marker.
        self.response_marker = self.request_marker = (
            protocol.MESSAGE_VERSION_THREE)

    def test_construct_version_three_server_protocol(self):
        decoder = protocol.ProtocolThreeDecoder(None)
        # A freshly built decoder holds no leftover or buffered bytes and has
        # not dispatched anything yet.
        self.assertEqual('', decoder.unused_data)
        self.assertEqual('', decoder._in_buffer)
        self.assertFalse(decoder._has_dispatched)
        # It starts out asking for four bytes: the length prefix of the
        # headers.
        self.assertEqual(4, decoder.next_read_size())
 
2232
 
 
2233
 
 
2234
class NoOpRequest(_mod_request.SmartServerRequest):
    """A request whose do() does nothing but report success with no args."""

    def do(self):
        empty_args = ()
        return _mod_request.SuccessfulSmartServerResponse(empty_args)


# A one-entry registry mapping the 'ARG' verb to NoOpRequest.
dummy_registry = dict(ARG=NoOpRequest)
 
2240
 
 
2241
 
 
2242
class LoggingMessageHandler(object):
    """Message handler test double that records every callback it receives.

    Each callback appends one tuple to ``event_log``: the first item names
    the event ('headers', 'protocol_error', 'byte', 'bytes', 'structure' or
    'end') and any remaining item is the argument the callback was given.
    """

    def __init__(self):
        self.event_log = []

    def _log(self, *args):
        # args is already a tuple; extend in place so callers holding a
        # reference to event_log see the new entry.
        self.event_log += [args]

    def headers_received(self, headers):
        self._log('headers', headers)

    def protocol_error(self, exception):
        self._log('protocol_error', exception)

    def byte_part_received(self, byte):
        self._log('byte', byte)

    def bytes_part_received(self, bytes):
        self._log('bytes', bytes)

    def structure_part_received(self, structure):
        self._log('structure', structure)

    def end_received(self):
        self._log('end')
 
2267
 
 
2268
 
 
2269
class TestProtocolThree(TestSmartProtocol):
    """Tests for v3 of the server-side protocol."""

    request_encoder = protocol.ProtocolThreeRequester
    response_decoder = protocol.ProtocolThreeDecoder
    server_protocol_class = protocol.ProtocolThreeDecoder

    def test_trivial_request(self):
        """Smoke test for the simplest possible v3 request: empty headers, no
        message parts.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        end = 'e'  # end marker
        request_bytes = headers + end
        smart_protocol = self.server_protocol_class(LoggingMessageHandler())
        smart_protocol.accept_bytes(request_bytes)
        # The decoder wants no more bytes and has no leftover input.
        self.assertEqual(0, smart_protocol.next_read_size())
        self.assertEqual('', smart_protocol.unused_data)

    def make_protocol_expecting_message_part(self):
        """Make a decoder that has already consumed an empty headers section.

        :return: a tuple of (decoder, event_log), where event_log is the
            LoggingMessageHandler's event list, emptied after the headers
            were accepted so that it only shows later message parts.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        message_handler = LoggingMessageHandler()
        smart_protocol = self.server_protocol_class(message_handler)
        smart_protocol.accept_bytes(headers)
        # Clear the event log in place, so the returned list stays the one
        # the handler appends to.
        del message_handler.event_log[:]
        return smart_protocol, message_handler.event_log

    def test_decode_one_byte(self):
        """The protocol can decode a 'one byte' message part."""
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
        smart_protocol.accept_bytes('ox')
        self.assertEqual([('byte', 'x')], event_log)

    def test_decode_bytes(self):
        """The protocol can decode a 'bytes' message part."""
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
        smart_protocol.accept_bytes(
            'b' # message part kind
            '\0\0\0\x07' # length prefix
            'payload' # payload
            )
        self.assertEqual([('bytes', 'payload')], event_log)

    def test_decode_structure(self):
        """The protocol can decode a 'structure' message part."""
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
        smart_protocol.accept_bytes(
            's' # message part kind
            '\0\0\0\x07' # length prefix
            'l3:ARGe' # ['ARG']
            )
        self.assertEqual([('structure', ['ARG'])], event_log)

    def test_decode_multiple_bytes(self):
        """The protocol can decode multiple 'bytes' message parts."""
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
        smart_protocol.accept_bytes(
            'b' # message part kind
            '\0\0\0\x05' # length prefix
            'first' # payload
            'b' # message part kind
            '\0\0\0\x06' # length prefix
            'second' # payload
            )
        self.assertEqual(
            [('bytes', 'first'), ('bytes', 'second')], event_log)
 
2337
 
 
2338
 
 
2339
class TestConventionalResponseHandler(tests.TestCase):

    def test_interrupted_body_stream(self):
        # A body stream can be cut short by an error: the chunks sent before
        # the error flag are yielded, then the error is raised.
        interrupted_body_stream = (
            'oS' # successful response
            's\0\0\0\x02le' # empty args
            'b\0\0\0\x09chunk one' # first chunk
            'b\0\0\0\x09chunk two' # second chunk
            'oE' # error flag
            's\0\0\0\x0el5:error3:abce' # bencoded error
            'e' # message end
            )
        # Use the module-level ``message`` import, as
        # TestClientDecodingProtocolThree does, rather than a local import.
        response_handler = message.ConventionalResponseHandler()
        protocol_decoder = protocol.ProtocolThreeDecoder(response_handler)
        # put decoder in desired state (waiting for message parts)
        protocol_decoder.state_accept = (
            protocol_decoder._state_accept_expecting_message_part)
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            StringIO(interrupted_body_stream), output, 'base')
        medium_request = client_medium.get_request()
        medium_request.finished_writing()
        response_handler.setProtoAndMediumRequest(
            protocol_decoder, medium_request)
        stream = response_handler.read_streamed_body()
        self.assertEqual('chunk one', stream.next())
        self.assertEqual('chunk two', stream.next())
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
        self.assertEqual(('error', 'abc'), exc.error_tuple)
 
2368
 
 
2369
 
 
2370
class TestMessageHandlerErrors(tests.TestCase):
    """Tests for v3 that unrecognised (but well-formed) requests/responses are
    still fully read off the wire, so that subsequent requests/responses on the
    same medium can be decoded.
    """

    def test_non_conventional_request(self):
        """ConventionalRequestHandler (the default message handler on the
        server side) will reject an unconventional message, but still consume
        all the bytes of that message and signal when it has done so.

        This is what allows a server to continue to accept requests after the
        client sends a completely unrecognised request.
        """
        # Define an invalid request (but one that is a well-formed message).
        # This particular invalid request not only lacks the mandatory
        # verb+args tuple, it has a single-byte part, which is forbidden.  In
        # fact it has that part twice, to trigger multiple errors.
        invalid_request = (
            protocol.MESSAGE_VERSION_THREE +  # protocol version marker
            '\0\0\0\x02de' + # empty headers
            'oX' + # a single byte part: 'X'.  ConventionalRequestHandler will
                   # error at this part.
            'oX' + # and again.
            'e' # end of message
            )

        to_server = StringIO(invalid_request)
        from_server = StringIO()
        transport = memory.MemoryTransport('memory:///')
        # Named server_medium so that it does not shadow the module-level
        # ``server`` import inside this method.
        server_medium = medium.SmartServerPipeStreamMedium(
            to_server, from_server, transport)
        proto = server_medium._build_protocol()
        server_medium._serve_one_request(proto)
        # All the bytes have been read from the medium...
        self.assertEqual('', to_server.read())
        # ...and the protocol decoder has consumed all the bytes, and has
        # finished reading.
        self.assertEqual('', proto.unused_data)
        self.assertEqual(0, proto.next_read_size())
 
2411
 
 
2412
 
 
2413
class InstrumentedRequestHandler(object):
    """Test Double of SmartServerRequestHandler.

    Every callback is appended to ``calls`` as a tuple holding the callback's
    name followed by its argument, if it has one.
    """

    def __init__(self):
        self.calls = []

    def _record(self, *call):
        self.calls.append(call)

    def body_chunk_received(self, chunk_bytes):
        self._record('body_chunk_received', chunk_bytes)

    def no_body_received(self):
        self._record('no_body_received')

    def prefixed_body_received(self, body_bytes):
        self._record('prefixed_body_received', body_bytes)

    def end_received(self):
        self._record('end_received')
 
2430
 
 
2431
 
 
2432
class StubRequest(object):
    """Stand-in medium request whose finished_reading() does nothing."""

    def finished_reading(self):
        return None
 
2436
 
 
2437
 
 
2438
class TestClientDecodingProtocolThree(TestSmartProtocol):
    """Tests for v3 of the client-side protocol decoding."""

    def make_logging_response_decoder(self):
        """Make v3 response decoder using a test response handler."""
        response_handler = LoggingMessageHandler()
        decoder = protocol.ProtocolThreeDecoder(response_handler)
        return decoder, response_handler

    def make_conventional_response_decoder(self):
        """Make v3 response decoder using a conventional response handler."""
        response_handler = message.ConventionalResponseHandler()
        decoder = protocol.ProtocolThreeDecoder(response_handler)
        # StubRequest stands in for the medium request; its
        # finished_reading() does nothing.
        response_handler.setProtoAndMediumRequest(decoder, StubRequest())
        return decoder, response_handler

    def test_trivial_response_decoding(self):
        """Smoke test for the simplest possible v3 response: empty headers,
        status byte, empty args, no body.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oS'  # success
        args = 's\0\0\0\x02le'  # length-prefixed, bencoded empty list
        end = 'e'  # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_logging_response_decoder()
        decoder.accept_bytes(message_bytes)
        # The protocol decoder has finished, and consumed all bytes
        self.assertEqual(0, decoder.next_read_size())
        self.assertEqual('', decoder.unused_data)
        # The message handler has been invoked with all the parts of the
        # trivial response: empty headers, status byte, no args, end.
        self.assertEqual(
            [('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
            response_handler.event_log)

    def test_incomplete_message(self):
        """A decoder will keep signalling that it needs more bytes via
        next_read_size() != 0 until it has seen a complete message, regardless
        of which state it is in.
        """
        # Define a simple response that uses all possible message parts.
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oS'  # success
        args = 's\0\0\0\x02le'  # length-prefixed, bencoded empty list
        body = 'b\0\0\0\x04BODY'  # a body: 'BODY'
        end = 'e'  # end marker
        simple_response = headers + response_status + args + body + end
        # Feed the response to the decoder one byte at a time.
        decoder, _ = self.make_logging_response_decoder()
        for byte in simple_response:
            self.assertNotEqual(0, decoder.next_read_size())
            decoder.accept_bytes(byte)
        # Now the response is complete
        self.assertEqual(0, decoder.next_read_size())

    def test_read_response_tuple_raises_UnknownSmartMethod(self):
        """read_response_tuple raises UnknownSmartMethod if the server replied
        with 'UnknownMethod'.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oE'  # error flag
        # args: ('UnknownMethod', 'method-name')
        args = 's\0\0\0\x20l13:UnknownMethod11:method-namee'
        end = 'e'  # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_conventional_response_decoder()
        decoder.accept_bytes(message_bytes)
        error = self.assertRaises(
            errors.UnknownSmartMethod, response_handler.read_response_tuple)
        self.assertEqual('method-name', error.verb)

    def test_read_response_tuple_error(self):
        """If the response has an error, it is raised as an exception."""
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oE'  # error
        args = 's\0\0\0\x1al9:first arg10:second arge'  # two args
        end = 'e'  # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_conventional_response_decoder()
        decoder.accept_bytes(message_bytes)
        error = self.assertRaises(
            errors.ErrorFromSmartServer, response_handler.read_response_tuple)
        self.assertEqual(('first arg', 'second arg'), error.error_tuple)
 
2522
 
 
2523
 
 
2524
class TestClientEncodingProtocolThree(TestSmartProtocol):

    request_encoder = protocol.ProtocolThreeRequester
    response_decoder = protocol.ProtocolThreeDecoder
    server_protocol_class = protocol.ProtocolThreeDecoder

    def make_client_encoder_and_output(self):
        """Return (requester, output) from make_client_protocol_and_output,
        discarding the response handler it also returns.
        """
        requester, _, output = self.make_client_protocol_and_output()
        return requester, output

    def test_call_smoke_test(self):
        """A smoke test for ProtocolThreeRequester.call.

        This test checks that a particular simple invocation of call emits the
        correct bytes for that invocation.
        """
        requester, output = self.make_client_encoder_and_output()
        requester.set_headers({'header name': 'header value'})
        requester.call('one arg')
        self.assertEqual(
            'bzr message 3 (bzr 1.6)\n' # protocol version
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
            's\x00\x00\x00\x0bl7:one arge' # args
            'e', # end
            output.getvalue())

    def test_call_with_body_bytes_smoke_test(self):
        """A smoke test for ProtocolThreeRequester.call_with_body_bytes.

        This test checks that a particular simple invocation of
        call_with_body_bytes emits the correct bytes for that invocation.
        """
        requester, output = self.make_client_encoder_and_output()
        requester.set_headers({'header name': 'header value'})
        requester.call_with_body_bytes(('one arg',), 'body bytes')
        self.assertEqual(
            'bzr message 3 (bzr 1.6)\n' # protocol version
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
            's\x00\x00\x00\x0bl7:one arge' # args
            'b' # there is a prefixed body
            '\x00\x00\x00\nbody bytes' # the prefixed body
            'e', # end
            output.getvalue())

    def test_call_writes_just_once(self):
        """A bodyless request is written to the medium all at once."""
        medium_request = StubMediumRequest()
        encoder = protocol.ProtocolThreeRequester(medium_request)
        encoder.call('arg1', 'arg2', 'arg3')
        self.assertEqual(
            ['accept_bytes', 'finished_writing'], medium_request.calls)

    def test_call_with_body_bytes_writes_just_once(self):
        """A request with body bytes is written to the medium all at once."""
        medium_request = StubMediumRequest()
        encoder = protocol.ProtocolThreeRequester(medium_request)
        encoder.call_with_body_bytes(('arg', 'arg'), 'body bytes')
        self.assertEqual(
            ['accept_bytes', 'finished_writing'], medium_request.calls)
 
2584
 
 
2585
 
 
2586
class StubMediumRequest(object):
    """A stub medium request that records, in order, which of its methods
    are called.

    ``calls`` gets 'accept_bytes' appended for each accept_bytes() call and
    'finished_writing' for each finished_writing() call.  The bytes passed
    to accept_bytes() are discarded.
    """

    def __init__(self):
        self.calls = []

    def accept_bytes(self, bytes):
        self.calls.append('accept_bytes')

    def finished_writing(self):
        self.calls.append('finished_writing')
 
2599
 
 
2600
 
 
2601
class TestResponseEncodingProtocolThree(tests.TestCase):

    def make_response_encoder(self):
        """Return a ProtocolThreeResponder that writes into a fresh StringIO,
        together with that StringIO.
        """
        out_stream = StringIO()
        return protocol.ProtocolThreeResponder(out_stream.write), out_stream

    def test_send_error_unknown_method(self):
        encoder, out_stream = self.make_response_encoder()
        encoder.send_error(errors.UnknownSmartMethod('method name'))
        expected_tail = (
            # error status
            'oE'
            # tuple: 'UnknownMethod', 'method name'
            's\x00\x00\x00\x20l13:UnknownMethod11:method namee'
            # end of message
            'e')
        # Only the tail is compared, because the header varies with
        # bzrlib.__version__.
        self.assertEndsWith(out_stream.getvalue(), expected_tail)
 
2621
 
 
2622
 
 
2623
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
    """Tests for buffering of responses.

    We want to avoid doing many small writes when one would do, to avoid
    unnecessary network overhead.
    """

    def setUp(self):
        # Chain up so the base TestCase fixture is set up as well, as the
        # other TestCase subclasses in this file do.
        super(TestResponseEncoderBufferingProtocolThree, self).setUp()
        # The responder's write function is self.writes.append, so each
        # write becomes one list item and len(self.writes) counts writes.
        self.writes = []
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)

    def assertWriteCount(self, expected_count):
        """Assert that exactly expected_count writes have been made."""
        actual_count = len(self.writes)
        self.assertEqual(
            expected_count, actual_count,
            "Expected %d writes, got %d: %r"
            % (expected_count, actual_count, self.writes))

    def test_send_error_writes_just_once(self):
        """An error response is written to the medium all at once."""
        self.responder.send_error(Exception('An exception string.'))
        self.assertWriteCount(1)

    def test_send_response_writes_just_once(self):
        """A normal response with no body is written to the medium all at once.
        """
        response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
        self.responder.send_response(response)
        self.assertWriteCount(1)

    def test_send_response_with_body_writes_just_once(self):
        """A normal response with a monolithic body is written to the medium
        all at once.
        """
        response = _mod_request.SuccessfulSmartServerResponse(
            ('arg', 'arg'), body='body bytes')
        self.responder.send_response(response)
        self.assertWriteCount(1)

    def test_send_response_with_body_stream_writes_once_per_chunk(self):
        """A normal response with a stream body is written to the medium once
        per chunk.
        """
        # Construct a response with stream with 2 chunks in it.
        response = _mod_request.SuccessfulSmartServerResponse(
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
        self.responder.send_response(response)
        # We will write 3 times: exactly once for each chunk, plus a final
        # write to end the response.
        self.assertWriteCount(3)
 
2671
 
2117
2672
 
2118
2673
class TestSmartClientUnicode(tests.TestCase):
2119
2674
    """_SmartClient tests for unicode arguments.
2134
2689
        """
2135
2690
        input = StringIO("\n")
2136
2691
        output = StringIO()
2137
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
2138
 
        smart_client = client._SmartClient(client_medium, 'ignored base')
 
2692
        client_medium = medium.SmartSimplePipesClientMedium(
 
2693
            input, output, 'ignored base')
 
2694
        smart_client = client._SmartClient(client_medium)
2139
2695
        self.assertRaises(TypeError,
2140
2696
            smart_client.call_with_body_bytes, method, args, body)
2141
2697
        self.assertEqual("", output.getvalue())
2152
2708
        self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
2153
2709
 
2154
2710
 
 
2711
class MockMedium(object):
    """Scripted fake medium for exercising _SmartClient.

    A test registers, in order, the requests the client is expected to send
    and the points where it is expected to disconnect. Each request is
    paired with the bytes to answer it with. Every action the client takes
    is checked against the oldest outstanding expectation, and a mismatch
    raises AssertionError.

    Typical use looks like::

        medium = MockMedium()
        medium.expect_request(...)
        medium.expect_request(...)
        medium.expect_request(...)
    """

    def __init__(self):
        self.base = 'dummy base'
        self._mock_request = _MockMediumRequest(self)
        # Pending expectations, oldest first. Each is either a
        # ('send request' | 'read response' | 'read response (partial)',
        # bytes) tuple or the bare string 'disconnect'.
        self._expected_events = []
        self._protocol_version = None

    def expect_request(self, request_bytes, response_bytes,
                       allow_partial_read=False):
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.

        Only the concatenated bytes matter, not how they are chunked. For
        example, both::

            request.accept_bytes('ab')
            request.accept_bytes('cd')
            request.finished_writing()

        and::

            request.accept_bytes('abcd')
            request.finished_writing()

        satisfy ``medium.expect_request('abcd', ...)``. Likewise, it does not
        matter how many read_bytes/read_line calls are used to consume the
        response. Tests built on this therefore do not break due to
        irrelevant changes in protocol implementations.

        :param allow_partial_read: if True, the response bytes actually read
            are not compared against 'response_bytes', so the client may stop
            early (e.g. by disconnecting) without an assertion being raised.
            Default is False.
        """
        if allow_partial_read:
            read_event_name = 'read response (partial)'
        else:
            read_event_name = 'read response'
        self._expected_events.extend([
            ('send request', request_bytes),
            (read_event_name, response_bytes),
        ])

    def expect_disconnect(self):
        """Expect the client to call ``medium.disconnect()``."""
        self._expected_events.append('disconnect')

    def _assertEvent(self, observed_event):
        """Consume the next expectation, raising AssertionError unless
        observed_event satisfies it.

        If the expectation after it is a read, that response is installed on
        the mock request so the client can read it.

        :seealso: expect_request
        :seealso: expect_disconnect
        """
        if not self._expected_events:
            raise AssertionError(
                'Mock medium observed event %r, but no more events expected'
                % (observed_event,))
        expected_event = self._expected_events.pop(0)
        # Either event may be the bare string 'disconnect'. Indexing it gives
        # 'd', which never equals a read event name, so the checks below
        # still behave correctly for it.
        if expected_event[0] == 'read response (partial)':
            # Only the kind of event is checked; the bytes read may be short.
            matched = observed_event[0] == 'read response'
        else:
            matched = observed_event == expected_event
        if not matched:
            raise AssertionError(
                'Mock medium observed event %r, but expected event %r'
                % (observed_event, expected_event))
        if self._expected_events:
            upcoming_event = self._expected_events[0]
            if upcoming_event[0].startswith('read response'):
                self._mock_request._response = upcoming_event[1]

    def get_request(self):
        return self._mock_request

    def disconnect(self):
        request = self._mock_request
        if request._read_bytes:
            # Account for whatever the client read before hanging up.
            self._assertEvent(('read response', request._read_bytes))
            request._read_bytes = ''
        self._assertEvent('disconnect')
 
2806
 
 
2807
 
 
2808
class _MockMediumRequest(object):
 
2809
    """A mock ClientMediumRequest used by MockMedium."""
 
2810
 
 
2811
    def __init__(self, mock_medium):
 
2812
        self._medium = mock_medium
 
2813
        self._written_bytes = ''
 
2814
        self._read_bytes = ''
 
2815
        self._response = None
 
2816
 
 
2817
    def accept_bytes(self, bytes):
 
2818
        self._written_bytes += bytes
 
2819
 
 
2820
    def finished_writing(self):
 
2821
        self._medium._assertEvent(('send request', self._written_bytes))
 
2822
        self._written_bytes = ''
 
2823
 
 
2824
    def finished_reading(self):
 
2825
        self._medium._assertEvent(('read response', self._read_bytes))
 
2826
        self._read_bytes = ''
 
2827
 
 
2828
    def read_bytes(self, size):
 
2829
        resp = self._response
 
2830
        bytes, resp = resp[:size], resp[size:]
 
2831
        self._response = resp
 
2832
        self._read_bytes += bytes
 
2833
        return bytes
 
2834
 
 
2835
    def read_line(self):
 
2836
        resp = self._response
 
2837
        try:
 
2838
            line, resp = resp.split('\n', 1)
 
2839
            line += '\n'
 
2840
        except ValueError:
 
2841
            line, resp = resp, ''
 
2842
        self._response = resp
 
2843
        self._read_bytes += line
 
2844
        return line
 
2845
 
 
2846
 
 
2847
class Test_SmartClientVersionDetection(tests.TestCase):
    """Tests for _SmartClient's automatic protocol version detection.

    On its first remote call, a _SmartClient retries the request with
    different protocol versions until the server accepts one of them.
    """

    # call('method-name', 'arg 1', 'arg 2') as sent by a client constructed
    # with headers={}, encoded in protocol 3 and in protocol 2.
    _V3_REQUEST = (
        'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
        's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee')
    _V2_REQUEST = 'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n'

    def test_version_three_server(self):
        """A server speaking protocol 3 needs only one request."""
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        prefix = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
        mock_medium.expect_request(
            prefix + 's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            prefix + 's\0\0\0\x13l14:response valueee')
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
        # No expectation of the mock medium was violated, and the client
        # returns the server's response.
        self.assertEqual(('response value',), result)
        self.assertEqual([], mock_medium._expected_events)

    def test_version_two_server(self):
        """Against a protocol 2 server, the client tries version 3 first and
        then falls back to protocol 2.

        The result of that detection is cached, so later requests go straight
        to protocol 2.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        # The v3 attempt gets a v2 error in reply...
        mock_medium.expect_request(
            self._V3_REQUEST, 'bzr response 2\nfailed\n\n')
        # ...so the client must reset the connection, since it cannot assume
        # the server will read any further requests from it...
        mock_medium.expect_disconnect()
        # ...and then it repeats the original request using protocol 2.
        mock_medium.expect_request(
            self._V2_REQUEST, 'bzr response 2\nsuccess\nresponse value\n')
        self.assertEqual(
            ('response value',),
            smart_client.call('method-name', 'arg 1', 'arg 2'))

        # A second call uses protocol 2 directly, without repeating the
        # autodetection.
        mock_medium.expect_request(
            'bzr request 2\nanother-method\n',
            'bzr response 2\nsuccess\nanother response\n')
        self.assertEqual(
            ('another response',), smart_client.call('another-method'))
        self.assertEqual([], mock_medium._expected_events)

    def test_unknown_version(self):
        """If the server speaks no known (or at least supported) protocol
        version, a SmartProtocolError is raised.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        garbage_reply = 'Unknown protocol!'
        # v3 and then v2 are each tried, and each is followed by a reconnect,
        # before the client gives up.
        for request_bytes in (self._V3_REQUEST, self._V2_REQUEST):
            mock_medium.expect_request(request_bytes, garbage_reply)
            mock_medium.expect_disconnect()
        self.assertRaises(
            errors.SmartProtocolError,
            smart_client.call, 'method-name', 'arg 1', 'arg 2')
        self.assertEqual([], mock_medium._expected_events)
 
2927
        
 
2928
 
 
2929
class Test_SmartClient(tests.TestCase):

    def test_call_default_headers(self):
        """By default a _SmartClient carries bzrlib's version as its
        'Software version' header.
        """
        # Only the headers are inspected, so a placeholder medium suffices.
        smart_client = client._SmartClient('dummy medium')
        expected_version = bzrlib.__version__
        self.assertEqual(
            expected_version, smart_client._headers['Software version'])
        # XXX: need a test that smart_client._headers is passed to the request
        # encoder.
 
2940
 
 
2941
 
2155
2942
class LengthPrefixedBodyDecoder(tests.TestCase):
2156
2943
 
2157
2944
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
2476
3263
        self.assertEqual(base_transport._http_transport,
2477
3264
                         new_transport._http_transport)
2478
3265
        self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))
 
3266
        self.assertEqual(
 
3267
            'child_dir/',
 
3268
            new_transport._client.remote_path_from_transport(new_transport))
2479
3269
 
2480
3270
    def test_remote_path_unnormal_base(self):
2481
3271
        # If the transport's base isn't normalised, the _remote_path should
2489
3279
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
2490
3280
        new_transport = base_transport.clone('c')
2491
3281
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
3282
        self.assertEqual(
 
3283
            'c/',
 
3284
            new_transport._client.remote_path_from_transport(new_transport))
2492
3285
 
2493
3286
        
2494
3287
# TODO: Client feature that does get_bundle and then installs that into a