~lifeless/bzr/index.range_map

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Robert Collins
  • Date: 2008-06-19 01:17:19 UTC
  • mfrom: (3218.1.277 +trunk)
  • Revision ID: robertc@robertcollins.net-20080619011719-1c4g4uxzzhdls2wf
Merge bzr.dev.

Show diffs side-by-side

added added

removed removed

Lines of Context:
21
21
import os
22
22
import socket
23
23
import threading
24
 
import urllib2
25
24
 
 
25
import bzrlib
26
26
from bzrlib import (
27
27
        bzrdir,
28
28
        errors,
33
33
from bzrlib.smart import (
34
34
        client,
35
35
        medium,
 
36
        message,
36
37
        protocol,
37
 
        request,
38
38
        request as _mod_request,
39
39
        server,
40
40
        vfs,
41
41
)
42
 
from bzrlib.tests.http_utils import (
43
 
        HTTPServerWithSmarts,
44
 
        SmartRequestHandler,
45
 
        )
46
42
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
47
43
from bzrlib.transport import (
48
44
        get_transport,
119
115
        sock.bind(('127.0.0.1', 0))
120
116
        sock.listen(1)
121
117
        port = sock.getsockname()[1]
122
 
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port)
 
118
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base')
123
119
        return sock, client_medium
124
120
 
125
121
    def receive_bytes_on_server(self, sock, bytes):
137
133
        t.start()
138
134
        return t
139
135
    
140
 
    def test_construct_smart_stream_medium_client(self):
141
 
        # make a new instance of the common base for Stream-like Mediums.
142
 
        # this just ensures that the constructor stays parameter-free which
143
 
        # is important for reuse : some subclasses will dynamically connect,
144
 
        # others are always on, etc.
145
 
        client_medium = medium.SmartClientStreamMedium()
146
 
 
147
 
    def test_construct_smart_client_medium(self):
148
 
        # the base client medium takes no parameters
149
 
        client_medium = medium.SmartClientMedium()
150
 
    
151
136
    def test_construct_smart_simple_pipes_client_medium(self):
152
137
        # the SimplePipes client medium takes two pipes:
153
138
        # readable pipe, writeable pipe.
154
139
        # Constructing one should just save these and do nothing.
155
140
        # We test this by passing in None.
156
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
141
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
157
142
        
158
143
    def test_simple_pipes_client_request_type(self):
159
144
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
160
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
145
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
161
146
        request = client_medium.get_request()
162
147
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
163
148
 
169
154
        # classes - as the sibling classes share this logic, they do not have
170
155
        # explicit tests for this.
171
156
        output = StringIO()
172
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
157
        client_medium = medium.SmartSimplePipesClientMedium(
 
158
            None, output, 'base')
173
159
        request = client_medium.get_request()
174
160
        request.finished_writing()
175
161
        request.finished_reading()
180
166
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
181
167
        # accept_bytes writes to the writeable pipe.
182
168
        output = StringIO()
183
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
169
        client_medium = medium.SmartSimplePipesClientMedium(
 
170
            None, output, 'base')
184
171
        client_medium._accept_bytes('abc')
185
172
        self.assertEqual('abc', output.getvalue())
186
173
    
188
175
        # calling disconnect does nothing.
189
176
        input = StringIO()
190
177
        output = StringIO()
191
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
178
        client_medium = medium.SmartSimplePipesClientMedium(
 
179
            input, output, 'base')
192
180
        # send some bytes to ensure disconnecting after activity still does not
193
181
        # close.
194
182
        client_medium._accept_bytes('abc')
201
189
        # accept_bytes writes to.
202
190
        input = StringIO()
203
191
        output = StringIO()
204
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
192
        client_medium = medium.SmartSimplePipesClientMedium(
 
193
            input, output, 'base')
205
194
        client_medium._accept_bytes('abc')
206
195
        client_medium.disconnect()
207
196
        client_medium._accept_bytes('abc')
212
201
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
213
202
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
214
203
        # does nothing.
215
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
204
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
216
205
        client_medium.disconnect()
217
206
 
218
207
    def test_simple_pipes_client_can_always_read(self):
219
208
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
220
209
        # always tries to read from the underlying pipe.
221
210
        input = StringIO('abcdef')
222
 
        client_medium = medium.SmartSimplePipesClientMedium(input, None)
 
211
        client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
223
212
        self.assertEqual('abc', client_medium.read_bytes(3))
224
213
        client_medium.disconnect()
225
214
        self.assertEqual('def', client_medium.read_bytes(3))
234
223
        flush_calls = []
235
224
        def logging_flush(): flush_calls.append('flush')
236
225
        output.flush = logging_flush
237
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
226
        client_medium = medium.SmartSimplePipesClientMedium(
 
227
            input, output, 'base')
238
228
        # this call is here to ensure we only flush once, not on every
239
229
        # _accept_bytes call.
240
230
        client_medium._accept_bytes('abc')
254
244
        # having vendor be invalid means that if it tries to connect via the
255
245
        # vendor it will blow up.
256
246
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
257
 
            username=None, password=None, vendor="not a vendor",
 
247
            username=None, password=None, base='base', vendor="not a vendor",
258
248
            bzr_remote_path='bzr')
259
249
        sock.close()
260
250
 
264
254
        output = StringIO()
265
255
        vendor = StringIOSSHVendor(StringIO(), output)
266
256
        client_medium = medium.SmartSSHClientMedium(
267
 
            'a hostname', 'a port', 'a username', 'a password', vendor, 'bzr')
 
257
            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
 
258
            'bzr')
268
259
        client_medium._accept_bytes('abc')
269
260
        self.assertEqual('abc', output.getvalue())
270
261
        self.assertEqual([('connect_ssh', 'a username', 'a password',
285
276
        client_medium = self.callDeprecated(
286
277
            ['bzr_remote_path is required as of bzr 0.92'],
287
278
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
288
 
            'a password', vendor)
 
279
            'a password', 'base', vendor)
289
280
        client_medium._accept_bytes('abc')
290
281
        self.assertEqual('abc', output.getvalue())
291
282
        self.assertEqual([('connect_ssh', 'a username', 'a password',
299
290
        output = StringIO()
300
291
        vendor = StringIOSSHVendor(StringIO(), output)
301
292
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
302
 
            'a username', 'a password', vendor, bzr_remote_path='fugly')
 
293
            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
303
294
        client_medium._accept_bytes('abc')
304
295
        self.assertEqual('abc', output.getvalue())
305
296
        self.assertEqual([('connect_ssh', 'a username', 'a password',
313
304
        input = StringIO()
314
305
        output = StringIO()
315
306
        vendor = StringIOSSHVendor(input, output)
316
 
        client_medium = medium.SmartSSHClientMedium('a hostname',
317
 
                                                    vendor=vendor,
318
 
                                                    bzr_remote_path='bzr')
 
307
        client_medium = medium.SmartSSHClientMedium(
 
308
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
319
309
        client_medium._accept_bytes('abc')
320
310
        client_medium.disconnect()
321
311
        self.assertTrue(input.closed)
335
325
        input = StringIO()
336
326
        output = StringIO()
337
327
        vendor = StringIOSSHVendor(input, output)
338
 
        client_medium = medium.SmartSSHClientMedium('a hostname',
339
 
            vendor=vendor, bzr_remote_path='bzr')
 
328
        client_medium = medium.SmartSSHClientMedium(
 
329
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
340
330
        client_medium._accept_bytes('abc')
341
331
        client_medium.disconnect()
342
332
        # the disconnect has closed output, so we need a new output for the
364
354
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
365
355
        # Doing a disconnect on a new (and thus unconnected) SSH medium
366
356
        # does not fail.  It's ok to disconnect an unconnected medium.
367
 
        client_medium = medium.SmartSSHClientMedium(None,
368
 
                                                    bzr_remote_path='bzr')
 
357
        client_medium = medium.SmartSSHClientMedium(
 
358
            None, base='base', bzr_remote_path='bzr')
369
359
        client_medium.disconnect()
370
360
 
371
361
    def test_ssh_client_raises_on_read_when_not_connected(self):
372
362
        # Doing a read on a new (and thus unconnected) SSH medium raises
373
363
        # MediumNotConnected.
374
 
        client_medium = medium.SmartSSHClientMedium(None,
375
 
                                                    bzr_remote_path='bzr')
 
364
        client_medium = medium.SmartSSHClientMedium(
 
365
            None, base='base', bzr_remote_path='bzr')
376
366
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
377
367
                          0)
378
368
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
389
379
        def logging_flush(): flush_calls.append('flush')
390
380
        output.flush = logging_flush
391
381
        vendor = StringIOSSHVendor(input, output)
392
 
        client_medium = medium.SmartSSHClientMedium('a hostname',
393
 
                                                    vendor=vendor,
394
 
                                                    bzr_remote_path='bzr')
 
382
        client_medium = medium.SmartSSHClientMedium(
 
383
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
395
384
        # this call is here to ensure we only flush once, not on every
396
385
        # _accept_bytes call.
397
386
        client_medium._accept_bytes('abc')
405
394
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
406
395
        sock.bind(('127.0.0.1', 0))
407
396
        unopened_port = sock.getsockname()[1]
408
 
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', unopened_port)
 
397
        client_medium = medium.SmartTCPClientMedium(
 
398
            '127.0.0.1', unopened_port, 'base')
409
399
        sock.close()
410
400
 
411
401
    def test_tcp_client_connects_on_first_use(self):
434
424
        # now disconnect again: this should not do anything, if disconnection
435
425
        # really did disconnect.
436
426
        medium.disconnect()
 
427
 
437
428
    
438
429
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
439
430
        # Doing a disconnect on a new (and thus unconnected) TCP medium
440
431
        # does not fail.  It's ok to disconnect an unconnected medium.
441
 
        client_medium = medium.SmartTCPClientMedium(None, None)
 
432
        client_medium = medium.SmartTCPClientMedium(None, None, None)
442
433
        client_medium.disconnect()
443
434
 
444
435
    def test_tcp_client_raises_on_read_when_not_connected(self):
445
436
        # Doing a read on a new (and thus unconnected) TCP medium raises
446
437
        # MediumNotConnected.
447
 
        client_medium = medium.SmartTCPClientMedium(None, None)
 
438
        client_medium = medium.SmartTCPClientMedium(None, None, None)
448
439
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
449
440
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
450
441
 
470
461
    def test_tcp_client_host_unknown_connection_error(self):
471
462
        self.requireFeature(InvalidHostnameFeature)
472
463
        client_medium = medium.SmartTCPClientMedium(
473
 
            'non_existent.invalid', 4155)
 
464
            'non_existent.invalid', 4155, 'base')
474
465
        self.assertRaises(
475
466
            errors.ConnectionError, client_medium._ensure_connection)
476
467
 
488
479
        # WritingCompleted to prevent bad assumptions on stream environments
489
480
        # breaking the needs of message-based environments.
490
481
        output = StringIO()
491
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
482
        client_medium = medium.SmartSimplePipesClientMedium(
 
483
            None, output, 'base')
492
484
        request = medium.SmartClientStreamMediumRequest(client_medium)
493
485
        request.finished_writing()
494
486
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
499
491
        # and checking that the pipes get the data.
500
492
        input = StringIO()
501
493
        output = StringIO()
502
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
494
        client_medium = medium.SmartSimplePipesClientMedium(
 
495
            input, output, 'base')
503
496
        request = medium.SmartClientStreamMediumRequest(client_medium)
504
497
        request.accept_bytes('123')
505
498
        request.finished_writing()
511
504
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
512
505
        # the current request to the new SmartClientStreamMediumRequest
513
506
        output = StringIO()
514
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
507
        client_medium = medium.SmartSimplePipesClientMedium(
 
508
            None, output, 'base')
515
509
        request = medium.SmartClientStreamMediumRequest(client_medium)
516
510
        self.assertIs(client_medium._current_request, request)
517
511
 
519
513
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
520
514
        # a non-None _current_request raises TooManyConcurrentRequests.
521
515
        output = StringIO()
522
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
516
        client_medium = medium.SmartSimplePipesClientMedium(
 
517
            None, output, 'base')
523
518
        client_medium._current_request = "a"
524
519
        self.assertRaises(errors.TooManyConcurrentRequests,
525
520
            medium.SmartClientStreamMediumRequest, client_medium)
528
523
        # calling finished_reading clears the current request from the requests
529
524
        # medium
530
525
        output = StringIO()
531
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
526
        client_medium = medium.SmartSimplePipesClientMedium(
 
527
            None, output, 'base')
532
528
        request = medium.SmartClientStreamMediumRequest(client_medium)
533
529
        request.finished_writing()
534
530
        request.finished_reading()
537
533
    def test_finished_read_before_finished_write_errors(self):
538
534
        # calling finished_reading before calling finished_writing triggers a
539
535
        # WritingNotComplete error.
540
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
536
        client_medium = medium.SmartSimplePipesClientMedium(
 
537
            None, None, 'base')
541
538
        request = medium.SmartClientStreamMediumRequest(client_medium)
542
539
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
543
540
        
550
547
        # smoke tests.
551
548
        input = StringIO('321')
552
549
        output = StringIO()
553
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
550
        client_medium = medium.SmartSimplePipesClientMedium(
 
551
            input, output, 'base')
554
552
        request = medium.SmartClientStreamMediumRequest(client_medium)
555
553
        request.finished_writing()
556
554
        self.assertEqual('321', request.read_bytes(3))
563
561
        # WritingNotComplete error because the Smart protocol is designed to be
564
562
        # compatible with strict message based protocols like HTTP where the
565
563
        # request cannot be submitted until the writing has completed.
566
 
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
564
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
567
565
        request = medium.SmartClientStreamMediumRequest(client_medium)
568
566
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
569
567
 
572
570
        # ReadingCompleted to prevent bad assumptions on stream environments
573
571
        # breaking the needs of message-based environments.
574
572
        output = StringIO()
575
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
573
        client_medium = medium.SmartSimplePipesClientMedium(
 
574
            None, output, 'base')
576
575
        request = medium.SmartClientStreamMediumRequest(client_medium)
577
576
        request.finished_writing()
578
577
        request.finished_reading()
610
609
        self.accepted_bytes = ''
611
610
        self._finished_reading = False
612
611
        self.expected_bytes = expected_bytes
613
 
        self.excess_buffer = ''
 
612
        self.unused_data = ''
614
613
 
615
614
    def accept_bytes(self, bytes):
616
615
        self.accepted_bytes += bytes
617
616
        if self.accepted_bytes.startswith(self.expected_bytes):
618
617
            self._finished_reading = True
619
 
            self.excess_buffer = self.accepted_bytes[len(self.expected_bytes):]
 
618
            self.unused_data = self.accepted_bytes[len(self.expected_bytes):]
620
619
 
621
620
    def next_read_size(self):
622
621
        if self._finished_reading:
733
732
        server._serve_one_request(SampleRequest('x'))
734
733
        self.assertTrue(server.finished)
735
734
        
 
735
    def test_socket_stream_incomplete_request(self):
 
736
        """The medium should still construct the right protocol version even if
 
737
        the initial read only reads part of the request.
 
738
 
 
739
        Specifically, it should correctly read the protocol version line even
 
740
        if the partial read doesn't end in a newline.  An older, naive
 
741
        implementation of _get_line in the server used to have a bug in that
 
742
        case.
 
743
        """
 
744
        incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
 
745
        rest_of_request_bytes = 'lo\n'
 
746
        expected_response = (
 
747
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
 
748
        server_sock, client_sock = self.portable_socket_pair()
 
749
        server = medium.SmartServerSocketStreamMedium(
 
750
            server_sock, None)
 
751
        client_sock.sendall(incomplete_request_bytes)
 
752
        server_protocol = server._build_protocol()
 
753
        client_sock.sendall(rest_of_request_bytes)
 
754
        server._serve_one_request(server_protocol)
 
755
        server_sock.close()
 
756
        self.assertEqual(expected_response, client_sock.recv(50),
 
757
                         "Not a version 2 response to 'hello' request.")
 
758
        self.assertEqual('', client_sock.recv(1))
 
759
 
 
760
    def test_pipe_stream_incomplete_request(self):
 
761
        """The medium should still construct the right protocol version even if
 
762
        the initial read only reads part of the request.
 
763
 
 
764
        Specifically, it should correctly read the protocol version line even
 
765
        if the partial read doesn't end in a newline.  An older, naive
 
766
        implementation of _get_line in the server used to have a bug in that
 
767
        case.
 
768
        """
 
769
        incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
 
770
        rest_of_request_bytes = 'lo\n'
 
771
        expected_response = (
 
772
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
 
773
        # Make a pair of pipes, to and from the server
 
774
        to_server, to_server_w = os.pipe()
 
775
        from_server_r, from_server = os.pipe()
 
776
        to_server = os.fdopen(to_server, 'r', 0)
 
777
        to_server_w = os.fdopen(to_server_w, 'w', 0)
 
778
        from_server_r = os.fdopen(from_server_r, 'r', 0)
 
779
        from_server = os.fdopen(from_server, 'w', 0)
 
780
        server = medium.SmartServerPipeStreamMedium(
 
781
            to_server, from_server, None)
 
782
        # Like test_socket_stream_incomplete_request, write an incomplete
 
783
        # request (that does not end in '\n') and build a protocol from it.
 
784
        to_server_w.write(incomplete_request_bytes)
 
785
        server_protocol = server._build_protocol()
 
786
        # Send the rest of the request, and finish serving it.
 
787
        to_server_w.write(rest_of_request_bytes)
 
788
        server._serve_one_request(server_protocol)
 
789
        to_server_w.close()
 
790
        from_server.close()
 
791
        self.assertEqual(expected_response, from_server_r.read(),
 
792
                         "Not a version 2 response to 'hello' request.")
 
793
        self.assertEqual('', from_server_r.read(1))
 
794
        from_server_r.close()
 
795
        to_server.close()
 
796
 
736
797
    def test_pipe_like_stream_with_two_requests(self):
737
798
        # If two requests are read in one go, then two calls to
738
799
        # _serve_one_request should still process both of them as if they had
882
943
        # A request that starts with "bzr request 2\n" is version two.
883
944
        server_protocol = self.build_protocol_socket('bzr request 2\n')
884
945
        self.assertProtocolTwo(server_protocol)
 
946
 
 
947
 
 
948
class TestGetProtocolFactoryForBytes(tests.TestCase):
 
949
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
 
950
    should use to decode a given request.  Any bytes not part of the version
 
951
    marker string (and thus part of the actual request) are returned alongside
 
952
    the protocol factory.
 
953
    """
 
954
 
 
955
    def test_version_three(self):
 
956
        result = medium._get_protocol_factory_for_bytes(
 
957
            'bzr message 3 (bzr 1.6)\nextra bytes')
 
958
        protocol_factory, remainder = result
 
959
        self.assertEqual(
 
960
            protocol.build_server_protocol_three, protocol_factory)
 
961
        self.assertEqual('extra bytes', remainder)
 
962
        
 
963
    def test_version_two(self):
 
964
        result = medium._get_protocol_factory_for_bytes(
 
965
            'bzr request 2\nextra bytes')
 
966
        protocol_factory, remainder = result
 
967
        self.assertEqual(
 
968
            protocol.SmartServerRequestProtocolTwo, protocol_factory)
 
969
        self.assertEqual('extra bytes', remainder)
 
970
        
 
971
    def test_version_one(self):
 
972
        """Version one requests have no version markers."""
 
973
        result = medium._get_protocol_factory_for_bytes('anything\n')
 
974
        protocol_factory, remainder = result
 
975
        self.assertEqual(
 
976
            protocol.SmartServerRequestProtocolOne, protocol_factory)
 
977
        self.assertEqual('anything\n', remainder)
885
978
        
886
979
 
887
980
class TestSmartTCPServer(tests.TestCase):
896
989
            def get_bytes(self, path):
897
990
                raise Exception("some random exception from inside server")
898
991
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
899
 
        smart_server.start_background_thread()
 
992
        smart_server.start_background_thread('-' + self.id())
900
993
        try:
901
994
            transport = remote.RemoteTCPTransport(smart_server.get_url())
902
995
            try:
932
1025
            self.real_backing_transport = self.backing_transport
933
1026
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
934
1027
        self.server = server.SmartTCPServer(self.backing_transport)
935
 
        self.server.start_background_thread()
 
1028
        self.server.start_background_thread('-' + self.id())
936
1029
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
937
1030
        self.addCleanup(self.tearDownServer)
938
1031
 
1068
1161
    def test_server_started_hook_memory(self):
1069
1162
        """The server_started hook fires when the server is started."""
1070
1163
        self.hook_calls = []
1071
 
        server.SmartTCPServer.hooks.install_hook('server_started',
1072
 
            self.capture_server_call)
 
1164
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1165
            self.capture_server_call, None)
1073
1166
        self.setUpServer()
1074
1167
        # at this point, the server will be starting a thread up.
1075
1168
        # there is no indicator at the moment, so bodge it by doing a request.
1082
1175
    def test_server_started_hook_file(self):
1083
1176
        """The server_started hook fires when the server is started."""
1084
1177
        self.hook_calls = []
1085
 
        server.SmartTCPServer.hooks.install_hook('server_started',
1086
 
            self.capture_server_call)
 
1178
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1179
            self.capture_server_call, None)
1087
1180
        self.setUpServer(backing_transport=get_transport("."))
1088
1181
        # at this point, the server will be starting a thread up.
1089
1182
        # there is no indicator at the moment, so bodge it by doing a request.
1098
1191
    def test_server_stopped_hook_simple_memory(self):
1099
1192
        """The server_stopped hook fires when the server is stopped."""
1100
1193
        self.hook_calls = []
1101
 
        server.SmartTCPServer.hooks.install_hook('server_stopped',
1102
 
            self.capture_server_call)
 
1194
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1195
            self.capture_server_call, None)
1103
1196
        self.setUpServer()
1104
1197
        result = [([self.backing_transport.base], self.transport.base)]
1105
1198
        # check the stopping message isn't emitted up front.
1115
1208
    def test_server_stopped_hook_simple_file(self):
1116
1209
        """The server_stopped hook fires when the server is stopped."""
1117
1210
        self.hook_calls = []
1118
 
        server.SmartTCPServer.hooks.install_hook('server_stopped',
1119
 
            self.capture_server_call)
 
1211
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1212
            self.capture_server_call, None)
1120
1213
        self.setUpServer(backing_transport=get_transport("."))
1121
1214
        result = [(
1122
1215
            [self.backing_transport.base, self.backing_transport.external_url()]
1140
1233
    and the request dispatching.
1141
1234
 
1142
1235
    Note: these tests are rudimentary versions of the command object tests in
1143
 
    test_remote.py.
 
1236
    test_smart.py.
1144
1237
    """
1145
1238
        
1146
1239
    def test_hello(self):
1147
 
        cmd = request.HelloRequest(None)
 
1240
        cmd = _mod_request.HelloRequest(None, '/')
1148
1241
        response = cmd.execute()
1149
1242
        self.assertEqual(('ok', '2'), response.args)
1150
1243
        self.assertEqual(None, response.body)
1156
1249
        wt.add('hello')
1157
1250
        rev_id = wt.commit('add hello')
1158
1251
        
1159
 
        cmd = request.GetBundleRequest(self.get_transport())
 
1252
        cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
1160
1253
        response = cmd.execute('.', rev_id)
1161
1254
        bundle = serializer.read_bundle(StringIO(response.body))
1162
1255
        self.assertEqual((), response.args)
1171
1264
 
1172
1265
    def build_handler(self, transport):
1173
1266
        """Returns a handler for the commands in protocol version one."""
1174
 
        return request.SmartServerRequestHandler(transport,
1175
 
                                                 request.request_handlers)
 
1267
        return _mod_request.SmartServerRequestHandler(
 
1268
            transport, _mod_request.request_handlers, '/')
1176
1269
 
1177
1270
    def test_construct_request_handler(self):
1178
1271
        """Constructing a request handler should be easy and set defaults."""
1179
 
        handler = request.SmartServerRequestHandler(None, None)
 
1272
        handler = _mod_request.SmartServerRequestHandler(None, commands=None,
 
1273
                root_client_path='/')
1180
1274
        self.assertFalse(handler.finished_reading)
1181
1275
 
1182
1276
    def test_hello(self):
1188
1282
    def test_disable_vfs_handler_classes_via_environment(self):
1189
1283
        # VFS handler classes will raise an error from "execute" if
1190
1284
        # BZR_NO_SMART_VFS is set.
1191
 
        handler = vfs.HasRequest(None)
 
1285
        handler = vfs.HasRequest(None, '/')
1192
1286
        # set environment variable after construction to make sure it's
1193
1287
        # examined.
1194
1288
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
1253
1347
        handler.accept_body('100,1')
1254
1348
        handler.end_of_body()
1255
1349
        self.assertTrue(handler.finished_reading)
1256
 
        self.assertEqual(('ShortReadvError', 'a-file', '100', '1', '0'),
 
1350
        self.assertEqual(('ShortReadvError', './a-file', '100', '1', '0'),
1257
1351
            handler.response.args)
1258
1352
        self.assertEqual(None, handler.response.body)
1259
1353
 
1278
1372
        
1279
1373
    def test_use_connection_factory(self):
1280
1374
        # We want to be able to pass a client as a parameter to RemoteTransport.
1281
 
        input = StringIO("ok\n3\nbardone\n")
 
1375
        input = StringIO('ok\n3\nbardone\n')
1282
1376
        output = StringIO()
1283
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1377
        client_medium = medium.SmartSimplePipesClientMedium(
 
1378
            input, output, 'base')
1284
1379
        transport = remote.RemoteTransport(
1285
1380
            'bzr://localhost/', medium=client_medium)
 
1381
        # Disable version detection.
 
1382
        client_medium._protocol_version = 1
1286
1383
 
1287
1384
        # We want to make sure the client is used when the first remote
1288
1385
        # method is called.  No data should have been sent, or read.
1289
1386
        self.assertEqual(0, input.tell())
1290
1387
        self.assertEqual('', output.getvalue())
1291
1388
 
1292
 
        # Now call a method that should result in a single request : as the
 
1389
        # Now call a method that should result in one request: as the
1293
1390
        # transport makes its own protocol instances, we check on the wire.
1294
1391
        # XXX: TODO: give the transport a protocol factory, which can make
1295
1392
        # an instrumented protocol for us.
1300
1397
 
1301
1398
    def test__translate_error_readonly(self):
1302
1399
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1303
 
        client_medium = medium.SmartClientMedium()
 
1400
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
1304
1401
        transport = remote.RemoteTransport(
1305
1402
            'bzr://localhost/', medium=client_medium)
1306
1403
        self.assertRaises(errors.TransportNotPossible,
1307
1404
            transport._translate_error, ("ReadOnlyError", ))
1308
1405
 
1309
1406
 
1310
 
class InstrumentedServerProtocol(medium.SmartServerStreamMedium):
1311
 
    """A smart server which is backed by memory and saves its write requests."""
1312
 
 
1313
 
    def __init__(self, write_output_list):
1314
 
        medium.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
1315
 
        self._write_output_list = write_output_list
1316
 
 
1317
 
 
1318
1407
class TestSmartProtocol(tests.TestCase):
1319
1408
    """Base class for smart protocol tests.
1320
1409
 
1330
1419
    Subclasses can override client_protocol_class and server_protocol_class.
1331
1420
    """
1332
1421
 
 
1422
    request_encoder = None
 
1423
    response_decoder = None
 
1424
    server_protocol_class = None
1333
1425
    client_protocol_class = None
1334
 
    server_protocol_class = None
 
1426
 
 
1427
    def make_client_protocol_and_output(self, input_bytes=None):
 
1428
        """
 
1429
        :returns: a Request
 
1430
        """
 
1431
        # This is very similar to
 
1432
        # bzrlib.smart.client._SmartClient._build_client_protocol
 
1433
        # XXX: make this use _SmartClient!
 
1434
        if input_bytes is None:
 
1435
            input = StringIO()
 
1436
        else:
 
1437
            input = StringIO(input_bytes)
 
1438
        output = StringIO()
 
1439
        client_medium = medium.SmartSimplePipesClientMedium(
 
1440
            input, output, 'base')
 
1441
        request = client_medium.get_request()
 
1442
        if self.client_protocol_class is not None:
 
1443
            client_protocol = self.client_protocol_class(request)
 
1444
            return client_protocol, client_protocol, output
 
1445
        else:
 
1446
            self.assertNotEqual(None, self.request_encoder)
 
1447
            self.assertNotEqual(None, self.response_decoder)
 
1448
            requester = self.request_encoder(request)
 
1449
            response_handler = message.ConventionalResponseHandler()
 
1450
            response_protocol = self.response_decoder(
 
1451
                response_handler, expect_version_marker=True)
 
1452
            response_handler.setProtoAndMediumRequest(
 
1453
                response_protocol, request)
 
1454
            return requester, response_handler, output
 
1455
 
 
1456
    def make_client_protocol(self, input_bytes=None):
 
1457
        result = self.make_client_protocol_and_output(input_bytes=input_bytes)
 
1458
        requester, response_handler, output = result
 
1459
        return requester, response_handler
 
1460
 
 
1461
    def make_server_protocol(self):
 
1462
        out_stream = StringIO()
 
1463
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1464
        return smart_protocol, out_stream
1335
1465
 
1336
1466
    def setUp(self):
1337
1467
        super(TestSmartProtocol, self).setUp()
1338
 
        # XXX: self.server_to_client doesn't seem to be used.  If so,
1339
 
        # InstrumentedServerProtocol is redundant too.
1340
 
        self.server_to_client = []
1341
 
        self.to_server = StringIO()
1342
 
        self.to_client = StringIO()
1343
 
        self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
1344
 
            self.to_server)
1345
 
        self.client_protocol = self.client_protocol_class(self.client_medium)
1346
 
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1347
 
        self.smart_server_request = request.SmartServerRequestHandler(
1348
 
            None, request.request_handlers)
 
1468
        self.response_marker = getattr(
 
1469
            self.client_protocol_class, 'response_marker', None)
 
1470
        self.request_marker = getattr(
 
1471
            self.client_protocol_class, 'request_marker', None)
1349
1472
 
1350
1473
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1351
 
        client):
 
1474
        requester):
1352
1475
        """Check that smart (de)serialises offsets as expected.
1353
1476
        
1354
1477
        We check both serialisation and deserialisation at the same time
1360
1483
        """
1361
1484
        # XXX: '_deserialise_offsets' should be a method of the
1362
1485
        # SmartServerRequestProtocol in future.
1363
 
        readv_cmd = vfs.ReadvRequest(None)
 
1486
        readv_cmd = vfs.ReadvRequest(None, '/')
1364
1487
        offsets = readv_cmd._deserialise_offsets(expected_serialised)
1365
1488
        self.assertEqual(expected_offsets, offsets)
1366
 
        serialised = client._serialise_offsets(offsets)
 
1489
        serialised = requester._serialise_offsets(offsets)
1367
1490
        self.assertEqual(expected_serialised, serialised)
1368
1491
 
1369
1492
    def build_protocol_waiting_for_body(self):
1370
 
        out_stream = StringIO()
1371
 
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1372
 
        smart_protocol.has_dispatched = True
1373
 
        smart_protocol.request = self.smart_server_request
 
1493
        smart_protocol, out_stream = self.make_server_protocol()
 
1494
        smart_protocol._has_dispatched = True
 
1495
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
 
1496
            None, _mod_request.request_handlers, '/')
1374
1497
        class FakeCommand(object):
1375
1498
            def do_body(cmd, body_bytes):
1376
1499
                self.end_received = True
1377
1500
                self.assertEqual('abcdefg', body_bytes)
1378
 
                return request.SuccessfulSmartServerResponse(('ok', ))
 
1501
                return _mod_request.SuccessfulSmartServerResponse(('ok', ))
1379
1502
        smart_protocol.request._command = FakeCommand()
1380
1503
        # Call accept_bytes to make sure that internal state like _body_decoder
1381
1504
        # is initialised.  This test should probably be given a clearer
1392
1515
        # check the encoding of the server for all input_tuples matches
1393
1516
        # expected bytes
1394
1517
        for input_tuple in input_tuples:
1395
 
            server_output = StringIO()
1396
 
            server_protocol = self.server_protocol_class(
1397
 
                None, server_output.write)
 
1518
            server_protocol, server_output = self.make_server_protocol()
1398
1519
            server_protocol._send_response(
1399
1520
                _mod_request.SuccessfulSmartServerResponse(input_tuple))
1400
1521
            self.assertEqual(expected_bytes, server_output.getvalue())
1401
1522
        # check the decoding of the client smart_protocol from expected_bytes:
1402
 
        input = StringIO(expected_bytes)
1403
 
        output = StringIO()
1404
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1405
 
        request = client_medium.get_request()
1406
 
        smart_protocol = self.client_protocol_class(request)
1407
 
        smart_protocol.call('foo')
1408
 
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
 
1523
        requester, response_handler = self.make_client_protocol(expected_bytes)
 
1524
        requester.call('foo')
 
1525
        self.assertEqual(expected_tuple, response_handler.read_response_tuple())
1409
1526
 
1410
1527
 
1411
1528
class CommonSmartProtocolTestMixin(object):
1412
1529
 
1413
 
    def test_errors_are_logged(self):
1414
 
        """If an error occurs during testing, it is logged to the test log."""
1415
 
        out_stream = StringIO()
1416
 
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1417
 
        # This triggers a "bad request" error.
1418
 
        smart_protocol.accept_bytes('abc\n')
1419
 
        test_log = self._get_log(keep_log_file=True)
1420
 
        self.assertContainsRe(test_log, 'Traceback')
1421
 
        self.assertContainsRe(test_log, 'SmartProtocolError')
1422
 
 
1423
1530
    def test_connection_closed_reporting(self):
1424
 
        input = StringIO()
1425
 
        output = StringIO()
1426
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1427
 
        request = client_medium.get_request()
1428
 
        smart_protocol = self.client_protocol_class(request)
1429
 
        smart_protocol.call('hello')
1430
 
        ex = self.assertRaises(errors.ConnectionReset, 
1431
 
            smart_protocol.read_response_tuple)
 
1531
        requester, response_handler = self.make_client_protocol()
 
1532
        requester.call('hello')
 
1533
        ex = self.assertRaises(errors.ConnectionReset,
 
1534
            response_handler.read_response_tuple)
1432
1535
        self.assertEqual("Connection closed: "
1433
1536
            "please check connectivity and permissions "
1434
1537
            "(and try -Dhpss if further diagnosis is required)", str(ex))
1435
1538
 
1436
 
 
1437
 
class TestSmartProtocolOne(TestSmartProtocol, CommonSmartProtocolTestMixin):
1438
 
    """Tests for the smart protocol version one."""
 
1539
    def test_server_offset_serialisation(self):
 
1540
        """The Smart protocol serialises offsets as a comma and \n string.
 
1541
 
 
1542
        We check a number of boundary cases are as expected: empty, one offset,
 
1543
        one with the order of reads not increasing (an out of order read), and
 
1544
        one that should coalesce.
 
1545
        """
 
1546
        requester, response_handler = self.make_client_protocol()
 
1547
        self.assertOffsetSerialisation([], '', requester)
 
1548
        self.assertOffsetSerialisation([(1,2)], '1,2', requester)
 
1549
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
 
1550
            requester)
 
1551
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
 
1552
            '1,2\n3,4\n100,200', requester)
 
1553
 
 
1554
 
 
1555
class TestVersionOneFeaturesInProtocolOne(
 
1556
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1557
    """Tests for version one smart protocol features as implemeted by version
 
1558
    one."""
1439
1559
 
1440
1560
    client_protocol_class = protocol.SmartClientRequestProtocolOne
1441
1561
    server_protocol_class = protocol.SmartServerRequestProtocolOne
1442
1562
 
1443
1563
    def test_construct_version_one_server_protocol(self):
1444
1564
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
1445
 
        self.assertEqual('', smart_protocol.excess_buffer)
 
1565
        self.assertEqual('', smart_protocol.unused_data)
1446
1566
        self.assertEqual('', smart_protocol.in_buffer)
1447
 
        self.assertFalse(smart_protocol.has_dispatched)
 
1567
        self.assertFalse(smart_protocol._has_dispatched)
1448
1568
        self.assertEqual(1, smart_protocol.next_read_size())
1449
1569
 
1450
1570
    def test_construct_version_one_client_protocol(self):
1451
1571
        # we can construct a client protocol from a client medium request
1452
1572
        output = StringIO()
1453
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1573
        client_medium = medium.SmartSimplePipesClientMedium(
 
1574
            None, output, 'base')
1454
1575
        request = client_medium.get_request()
1455
1576
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
1456
1577
 
1457
 
    def test_server_offset_serialisation(self):
1458
 
        """The Smart protocol serialises offsets as a comma and \n string.
1459
 
 
1460
 
        We check a number of boundary cases are as expected: empty, one offset,
1461
 
        one with the order of reads not increasing (an out of order read), and
1462
 
        one that should coalesce.
1463
 
        """
1464
 
        self.assertOffsetSerialisation([], '', self.client_protocol)
1465
 
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1466
 
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1467
 
            self.client_protocol)
1468
 
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1469
 
            '1,2\n3,4\n100,200', self.client_protocol)
1470
 
 
1471
1578
    def test_accept_bytes_of_bad_request_to_protocol(self):
1472
1579
        out_stream = StringIO()
1473
1580
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1478
1585
        self.assertEqual(
1479
1586
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
1480
1587
            out_stream.getvalue())
1481
 
        self.assertTrue(smart_protocol.has_dispatched)
 
1588
        self.assertTrue(smart_protocol._has_dispatched)
1482
1589
        self.assertEqual(0, smart_protocol.next_read_size())
1483
1590
 
1484
1591
    def test_accept_body_bytes_to_protocol(self):
1501
1608
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1502
1609
        self.assertEqual(0, smart_protocol.next_read_size())
1503
1610
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1504
 
        self.assertEqual('', smart_protocol.excess_buffer)
 
1611
        self.assertEqual('', smart_protocol.unused_data)
1505
1612
        self.assertEqual('', smart_protocol.in_buffer)
1506
1613
 
1507
1614
    def test_accept_excess_bytes_are_preserved(self):
1510
1617
            None, out_stream.write)
1511
1618
        smart_protocol.accept_bytes('hello\nhello\n')
1512
1619
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1513
 
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1620
        self.assertEqual("hello\n", smart_protocol.unused_data)
1514
1621
        self.assertEqual("", smart_protocol.in_buffer)
1515
1622
 
1516
1623
    def test_accept_excess_bytes_after_body(self):
1517
1624
        protocol = self.build_protocol_waiting_for_body()
1518
1625
        protocol.accept_bytes('7\nabcdefgdone\nX')
1519
1626
        self.assertTrue(self.end_received)
1520
 
        self.assertEqual("X", protocol.excess_buffer)
 
1627
        self.assertEqual("X", protocol.unused_data)
1521
1628
        self.assertEqual("", protocol.in_buffer)
1522
1629
        protocol.accept_bytes('Y')
1523
 
        self.assertEqual("XY", protocol.excess_buffer)
 
1630
        self.assertEqual("XY", protocol.unused_data)
1524
1631
        self.assertEqual("", protocol.in_buffer)
1525
1632
 
1526
1633
    def test_accept_excess_bytes_after_dispatch(self):
1530
1637
        smart_protocol.accept_bytes('hello\n')
1531
1638
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1532
1639
        smart_protocol.accept_bytes('hel')
1533
 
        self.assertEqual("hel", smart_protocol.excess_buffer)
 
1640
        self.assertEqual("hel", smart_protocol.unused_data)
1534
1641
        smart_protocol.accept_bytes('lo\n')
1535
 
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1642
        self.assertEqual("hello\n", smart_protocol.unused_data)
1536
1643
        self.assertEqual("", smart_protocol.in_buffer)
1537
1644
 
1538
1645
    def test__send_response_sets_finished_reading(self):
1540
1647
            None, lambda x: None)
1541
1648
        self.assertEqual(1, smart_protocol.next_read_size())
1542
1649
        smart_protocol._send_response(
1543
 
            request.SuccessfulSmartServerResponse(('x',)))
 
1650
            _mod_request.SuccessfulSmartServerResponse(('x',)))
1544
1651
        self.assertEqual(0, smart_protocol.next_read_size())
1545
1652
 
1546
1653
    def test__send_response_errors_with_base_response(self):
1548
1655
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1549
1656
            None, lambda x: None)
1550
1657
        self.assertRaises(AttributeError, smart_protocol._send_response,
1551
 
            request.SmartServerResponse(('x',)))
 
1658
            _mod_request.SmartServerResponse(('x',)))
1552
1659
 
1553
1660
    def test_query_version(self):
1554
1661
        """query_version on a SmartClientProtocolOne should return a number.
1562
1669
        # the error if the response is a non-understood version.
1563
1670
        input = StringIO('ok\x012\n')
1564
1671
        output = StringIO()
1565
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1672
        client_medium = medium.SmartSimplePipesClientMedium(
 
1673
            input, output, 'base')
1566
1674
        request = client_medium.get_request()
1567
1675
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1568
1676
        self.assertEqual(2, smart_protocol.query_version())
1585
1693
        expected_bytes = "foo\n7\nabcdefgdone\n"
1586
1694
        input = StringIO("\n")
1587
1695
        output = StringIO()
1588
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1696
        client_medium = medium.SmartSimplePipesClientMedium(
 
1697
            input, output, 'base')
1589
1698
        request = client_medium.get_request()
1590
1699
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1591
1700
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1597
1706
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
1598
1707
        input = StringIO("\n")
1599
1708
        output = StringIO()
1600
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1709
        client_medium = medium.SmartSimplePipesClientMedium(
 
1710
            input, output, 'base')
1601
1711
        request = client_medium.get_request()
1602
1712
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1603
1713
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1604
1714
        self.assertEqual(expected_bytes, output.getvalue())
1605
1715
 
 
1716
    def _test_client_read_response_tuple_raises_UnknownSmartMethod(self,
 
1717
            server_bytes):
 
1718
        input = StringIO(server_bytes)
 
1719
        output = StringIO()
 
1720
        client_medium = medium.SmartSimplePipesClientMedium(
 
1721
            input, output, 'base')
 
1722
        request = client_medium.get_request()
 
1723
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1724
        smart_protocol.call('foo')
 
1725
        self.assertRaises(
 
1726
            errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
 
1727
        # The request has been finished.  There is no body to read, and
 
1728
        # attempts to read one will fail.
 
1729
        self.assertRaises(
 
1730
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
1731
 
 
1732
    def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
 
1733
        """read_response_tuple raises UnknownSmartMethod if the response says
 
1734
        the server did not recognise the request.
 
1735
        """
 
1736
        server_bytes = (
 
1737
            "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
 
1738
        self._test_client_read_response_tuple_raises_UnknownSmartMethod(
 
1739
            server_bytes)
 
1740
 
 
1741
    def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self):
 
1742
        """read_response_tuple also raises UnknownSmartMethod if the response
 
1743
        from a bzr 0.11 says the server did not recognise the request.
 
1744
 
 
1745
        (bzr 0.11 sends a slightly different error message to later versions.)
 
1746
        """
 
1747
        server_bytes = (
 
1748
            "error\x01Generic bzr smart protocol error: bad request u'foo'\n")
 
1749
        self._test_client_read_response_tuple_raises_UnknownSmartMethod(
 
1750
            server_bytes)
 
1751
 
1606
1752
    def test_client_read_body_bytes_all(self):
1607
1753
        # read_body_bytes should decode the body bytes from the wire into
1608
1754
        # a response.
1610
1756
        server_bytes = "ok\n7\n1234567done\n"
1611
1757
        input = StringIO(server_bytes)
1612
1758
        output = StringIO()
1613
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1759
        client_medium = medium.SmartSimplePipesClientMedium(
 
1760
            input, output, 'base')
1614
1761
        request = client_medium.get_request()
1615
1762
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1616
1763
        smart_protocol.call('foo')
1627
1774
        server_bytes = "ok\n7\n1234567done\n"
1628
1775
        input = StringIO(server_bytes)
1629
1776
        output = StringIO()
1630
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1777
        client_medium = medium.SmartSimplePipesClientMedium(
 
1778
            input, output, 'base')
1631
1779
        request = client_medium.get_request()
1632
1780
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1633
1781
        smart_protocol.call('foo')
1644
1792
        server_bytes = "ok\n7\n1234567done\n"
1645
1793
        input = StringIO(server_bytes)
1646
1794
        output = StringIO()
1647
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1795
        client_medium = medium.SmartSimplePipesClientMedium(
 
1796
            input, output, 'base')
1648
1797
        request = client_medium.get_request()
1649
1798
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1650
1799
        smart_protocol.call('foo')
1654
1803
        self.assertRaises(
1655
1804
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1656
1805
 
1657
 
 
1658
 
class TestSmartProtocolTwo(TestSmartProtocol, CommonSmartProtocolTestMixin):
1659
 
    """Tests for the smart protocol version two.
1660
 
 
1661
 
    This test case is mostly the same as TestSmartProtocolOne.
 
1806
    def test_client_read_body_bytes_interrupted_connection(self):
 
1807
        server_bytes = "ok\n999\nincomplete body"
 
1808
        input = StringIO(server_bytes)
 
1809
        output = StringIO()
 
1810
        client_medium = medium.SmartSimplePipesClientMedium(
 
1811
            input, output, 'base')
 
1812
        request = client_medium.get_request()
 
1813
        smart_protocol = self.client_protocol_class(request)
 
1814
        smart_protocol.call('foo')
 
1815
        smart_protocol.read_response_tuple(True)
 
1816
        self.assertRaises(
 
1817
            errors.ConnectionReset, smart_protocol.read_body_bytes)
 
1818
 
 
1819
 
 
1820
class TestVersionOneFeaturesInProtocolTwo(
 
1821
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1822
    """Tests for version one smart protocol features as implemeted by version
 
1823
    two.
1662
1824
    """
1663
1825
 
1664
1826
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
1666
1828
 
1667
1829
    def test_construct_version_two_server_protocol(self):
1668
1830
        smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
1669
 
        self.assertEqual('', smart_protocol.excess_buffer)
 
1831
        self.assertEqual('', smart_protocol.unused_data)
1670
1832
        self.assertEqual('', smart_protocol.in_buffer)
1671
 
        self.assertFalse(smart_protocol.has_dispatched)
 
1833
        self.assertFalse(smart_protocol._has_dispatched)
1672
1834
        self.assertEqual(1, smart_protocol.next_read_size())
1673
1835
 
1674
1836
    def test_construct_version_two_client_protocol(self):
1675
1837
        # we can construct a client protocol from a client medium request
1676
1838
        output = StringIO()
1677
 
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1839
        client_medium = medium.SmartSimplePipesClientMedium(
 
1840
            None, output, 'base')
1678
1841
        request = client_medium.get_request()
1679
1842
        client_protocol = protocol.SmartClientRequestProtocolTwo(request)
1680
1843
 
1681
 
    def test_server_offset_serialisation(self):
1682
 
        """The Smart protocol serialises offsets as a comma and \n string.
1683
 
 
1684
 
        We check a number of boundary cases are as expected: empty, one offset,
1685
 
        one with the order of reads not increasing (an out of order read), and
1686
 
        one that should coalesce.
 
1844
    def test_accept_bytes_of_bad_request_to_protocol(self):
 
1845
        out_stream = StringIO()
 
1846
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1847
        smart_protocol.accept_bytes('abc')
 
1848
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1849
        smart_protocol.accept_bytes('\n')
 
1850
        self.assertEqual(
 
1851
            self.response_marker +
 
1852
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1853
            out_stream.getvalue())
 
1854
        self.assertTrue(smart_protocol._has_dispatched)
 
1855
        self.assertEqual(0, smart_protocol.next_read_size())
 
1856
 
 
1857
    def test_accept_body_bytes_to_protocol(self):
 
1858
        protocol = self.build_protocol_waiting_for_body()
 
1859
        self.assertEqual(6, protocol.next_read_size())
 
1860
        protocol.accept_bytes('7\nabc')
 
1861
        self.assertEqual(9, protocol.next_read_size())
 
1862
        protocol.accept_bytes('defgd')
 
1863
        protocol.accept_bytes('one\n')
 
1864
        self.assertEqual(0, protocol.next_read_size())
 
1865
        self.assertTrue(self.end_received)
 
1866
 
 
1867
    def test_accept_request_and_body_all_at_once(self):
 
1868
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1869
        mem_transport = memory.MemoryTransport()
 
1870
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1871
        out_stream = StringIO()
 
1872
        smart_protocol = self.server_protocol_class(
 
1873
            mem_transport, out_stream.write)
 
1874
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1875
        self.assertEqual(0, smart_protocol.next_read_size())
 
1876
        self.assertEqual(self.response_marker +
 
1877
                         'success\nreadv\n3\ndefdone\n',
 
1878
                         out_stream.getvalue())
 
1879
        self.assertEqual('', smart_protocol.unused_data)
 
1880
        self.assertEqual('', smart_protocol.in_buffer)
 
1881
 
 
1882
    def test_accept_excess_bytes_are_preserved(self):
 
1883
        out_stream = StringIO()
 
1884
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1885
        smart_protocol.accept_bytes('hello\nhello\n')
 
1886
        self.assertEqual(self.response_marker + "success\nok\x012\n",
 
1887
                         out_stream.getvalue())
 
1888
        self.assertEqual("hello\n", smart_protocol.unused_data)
 
1889
        self.assertEqual("", smart_protocol.in_buffer)
 
1890
 
 
1891
    def test_accept_excess_bytes_after_body(self):
 
1892
        # The excess bytes look like the start of another request.
 
1893
        server_protocol = self.build_protocol_waiting_for_body()
 
1894
        server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
 
1895
        self.assertTrue(self.end_received)
 
1896
        self.assertEqual(self.response_marker,
 
1897
                         server_protocol.unused_data)
 
1898
        self.assertEqual("", server_protocol.in_buffer)
 
1899
        server_protocol.accept_bytes('Y')
 
1900
        self.assertEqual(self.response_marker + "Y",
 
1901
                         server_protocol.unused_data)
 
1902
        self.assertEqual("", server_protocol.in_buffer)
 
1903
 
 
1904
    def test_accept_excess_bytes_after_dispatch(self):
 
1905
        out_stream = StringIO()
 
1906
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1907
        smart_protocol.accept_bytes('hello\n')
 
1908
        self.assertEqual(self.response_marker + "success\nok\x012\n",
 
1909
                         out_stream.getvalue())
 
1910
        smart_protocol.accept_bytes(self.request_marker + 'hel')
 
1911
        self.assertEqual(self.request_marker + "hel",
 
1912
                         smart_protocol.unused_data)
 
1913
        smart_protocol.accept_bytes('lo\n')
 
1914
        self.assertEqual(self.request_marker + "hello\n",
 
1915
                         smart_protocol.unused_data)
 
1916
        self.assertEqual("", smart_protocol.in_buffer)
 
1917
 
 
1918
    def test__send_response_sets_finished_reading(self):
 
1919
        smart_protocol = self.server_protocol_class(None, lambda x: None)
 
1920
        self.assertEqual(1, smart_protocol.next_read_size())
 
1921
        smart_protocol._send_response(
 
1922
            _mod_request.SuccessfulSmartServerResponse(('x',)))
 
1923
        self.assertEqual(0, smart_protocol.next_read_size())
 
1924
 
 
1925
    def test__send_response_errors_with_base_response(self):
 
1926
        """Ensure that only the Successful/Failed subclasses are used."""
 
1927
        smart_protocol = self.server_protocol_class(None, lambda x: None)
 
1928
        self.assertRaises(AttributeError, smart_protocol._send_response,
 
1929
            _mod_request.SmartServerResponse(('x',)))
 
1930
 
 
1931
    def test_query_version(self):
 
1932
        """query_version on a SmartClientProtocolTwo should return a number.
 
1933
        
 
1934
        The protocol provides the query_version because the domain level clients
 
1935
        may all need to be able to probe for capabilities.
1687
1936
        """
1688
 
        self.assertOffsetSerialisation([], '', self.client_protocol)
1689
 
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1690
 
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1691
 
            self.client_protocol)
1692
 
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1693
 
            '1,2\n3,4\n100,200', self.client_protocol)
 
1937
        # What we really want to test here is that SmartClientProtocolTwo calls
 
1938
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1939
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1940
        # the error if the response is a non-understood version.
 
1941
        input = StringIO(self.response_marker + 'success\nok\x012\n')
 
1942
        output = StringIO()
 
1943
        client_medium = medium.SmartSimplePipesClientMedium(
 
1944
            input, output, 'base')
 
1945
        request = client_medium.get_request()
 
1946
        smart_protocol = self.client_protocol_class(request)
 
1947
        self.assertEqual(2, smart_protocol.query_version())
 
1948
 
 
1949
    def test_client_call_empty_response(self):
 
1950
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1951
        # when the parsed line is an empty line, and results in a tuple with
 
1952
        # one element - an empty string.
 
1953
        self.assertServerToClientEncoding(
 
1954
            self.response_marker + 'success\n\n', ('', ), [(), ('', )])
 
1955
 
 
1956
    def test_client_call_three_element_response(self):
 
1957
        # protocol.call() can get back tuples of other lengths. A three element
 
1958
        # tuple should be unpacked as three strings.
 
1959
        self.assertServerToClientEncoding(
 
1960
            self.response_marker + 'success\na\x01b\x0134\n',
 
1961
            ('a', 'b', '34'),
 
1962
            [('a', 'b', '34')])
 
1963
 
 
1964
    def test_client_call_with_body_bytes_uploads(self):
 
1965
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
 
1966
        # wire.
 
1967
        expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
 
1968
        input = StringIO("\n")
 
1969
        output = StringIO()
 
1970
        client_medium = medium.SmartSimplePipesClientMedium(
 
1971
            input, output, 'base')
 
1972
        request = client_medium.get_request()
 
1973
        smart_protocol = self.client_protocol_class(request)
 
1974
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1975
        self.assertEqual(expected_bytes, output.getvalue())
 
1976
 
 
1977
    def test_client_call_with_body_readv_array(self):
 
1978
        # protocol.call_with_upload should encode the readv array and then
 
1979
        # length-prefix the bytes onto the wire.
 
1980
        expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
 
1981
        input = StringIO("\n")
 
1982
        output = StringIO()
 
1983
        client_medium = medium.SmartSimplePipesClientMedium(
 
1984
            input, output, 'base')
 
1985
        request = client_medium.get_request()
 
1986
        smart_protocol = self.client_protocol_class(request)
 
1987
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1988
        self.assertEqual(expected_bytes, output.getvalue())
 
1989
 
 
1990
    def test_client_read_body_bytes_all(self):
 
1991
        # read_body_bytes should decode the body bytes from the wire into
 
1992
        # a response.
 
1993
        expected_bytes = "1234567"
 
1994
        server_bytes = (self.response_marker +
 
1995
                        "success\nok\n7\n1234567done\n")
 
1996
        input = StringIO(server_bytes)
 
1997
        output = StringIO()
 
1998
        client_medium = medium.SmartSimplePipesClientMedium(
 
1999
            input, output, 'base')
 
2000
        request = client_medium.get_request()
 
2001
        smart_protocol = self.client_protocol_class(request)
 
2002
        smart_protocol.call('foo')
 
2003
        smart_protocol.read_response_tuple(True)
 
2004
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
 
2005
 
 
2006
    def test_client_read_body_bytes_incremental(self):
 
2007
        # test reading a few bytes at a time from the body
 
2008
        # XXX: possibly we should test dribbling the bytes into the stringio
 
2009
        # to make the state machine work harder: however, as we use the
 
2010
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
2011
        # that.
 
2012
        expected_bytes = "1234567"
 
2013
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
 
2014
        input = StringIO(server_bytes)
 
2015
        output = StringIO()
 
2016
        client_medium = medium.SmartSimplePipesClientMedium(
 
2017
            input, output, 'base')
 
2018
        request = client_medium.get_request()
 
2019
        smart_protocol = self.client_protocol_class(request)
 
2020
        smart_protocol.call('foo')
 
2021
        smart_protocol.read_response_tuple(True)
 
2022
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
2023
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
2024
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
2025
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
 
2026
 
 
2027
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
2028
        # cancelling the expected body needs to finish the request, but not
 
2029
        # read any more bytes.
 
2030
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
 
2031
        input = StringIO(server_bytes)
 
2032
        output = StringIO()
 
2033
        client_medium = medium.SmartSimplePipesClientMedium(
 
2034
            input, output, 'base')
 
2035
        request = client_medium.get_request()
 
2036
        smart_protocol = self.client_protocol_class(request)
 
2037
        smart_protocol.call('foo')
 
2038
        smart_protocol.read_response_tuple(True)
 
2039
        smart_protocol.cancel_read_body()
 
2040
        self.assertEqual(len(self.response_marker + 'success\nok\n'),
 
2041
                         input.tell())
 
2042
        self.assertRaises(
 
2043
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2044
 
 
2045
    def test_client_read_body_bytes_interrupted_connection(self):
 
2046
        server_bytes = (self.response_marker +
 
2047
                        "success\nok\n999\nincomplete body")
 
2048
        input = StringIO(server_bytes)
 
2049
        output = StringIO()
 
2050
        client_medium = medium.SmartSimplePipesClientMedium(
 
2051
            input, output, 'base')
 
2052
        request = client_medium.get_request()
 
2053
        smart_protocol = self.client_protocol_class(request)
 
2054
        smart_protocol.call('foo')
 
2055
        smart_protocol.read_response_tuple(True)
 
2056
        self.assertRaises(
 
2057
            errors.ConnectionReset, smart_protocol.read_body_bytes)
 
2058
 
 
2059
 
 
2060
class TestSmartProtocolTwoSpecificsMixin(object):
1694
2061
 
1695
2062
    def assertBodyStreamSerialisation(self, expected_serialisation,
1696
2063
                                      body_stream):
1735
2102
 
1736
2103
    def test_body_stream_error_serialistion(self):
1737
2104
        stream = ['first chunk',
1738
 
                  request.FailedSmartServerResponse(
 
2105
                  _mod_request.FailedSmartServerResponse(
1739
2106
                      ('FailureName', 'failure arg'))]
1740
2107
        expected_bytes = (
1741
2108
            'chunked\n' + 'b\nfirst chunk' +
1744
2111
        self.assertBodyStreamSerialisation(expected_bytes, stream)
1745
2112
        self.assertBodyStreamRoundTrips(stream)
1746
2113
 
1747
 
    def test_accept_bytes_of_bad_request_to_protocol(self):
1748
 
        out_stream = StringIO()
1749
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1750
 
            None, out_stream.write)
1751
 
        smart_protocol.accept_bytes('abc')
1752
 
        self.assertEqual('abc', smart_protocol.in_buffer)
1753
 
        smart_protocol.accept_bytes('\n')
1754
 
        self.assertEqual(
1755
 
            protocol.RESPONSE_VERSION_TWO +
1756
 
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
1757
 
            out_stream.getvalue())
1758
 
        self.assertTrue(smart_protocol.has_dispatched)
1759
 
        self.assertEqual(0, smart_protocol.next_read_size())
1760
 
 
1761
 
    def test_accept_body_bytes_to_protocol(self):
1762
 
        protocol = self.build_protocol_waiting_for_body()
1763
 
        self.assertEqual(6, protocol.next_read_size())
1764
 
        protocol.accept_bytes('7\nabc')
1765
 
        self.assertEqual(9, protocol.next_read_size())
1766
 
        protocol.accept_bytes('defgd')
1767
 
        protocol.accept_bytes('one\n')
1768
 
        self.assertEqual(0, protocol.next_read_size())
1769
 
        self.assertTrue(self.end_received)
1770
 
 
1771
 
    def test_accept_request_and_body_all_at_once(self):
1772
 
        self._captureVar('BZR_NO_SMART_VFS', None)
1773
 
        mem_transport = memory.MemoryTransport()
1774
 
        mem_transport.put_bytes('foo', 'abcdefghij')
1775
 
        out_stream = StringIO()
1776
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(mem_transport,
1777
 
                out_stream.write)
1778
 
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1779
 
        self.assertEqual(0, smart_protocol.next_read_size())
1780
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO +
1781
 
                         'success\nreadv\n3\ndefdone\n',
1782
 
                         out_stream.getvalue())
1783
 
        self.assertEqual('', smart_protocol.excess_buffer)
1784
 
        self.assertEqual('', smart_protocol.in_buffer)
1785
 
 
1786
 
    def test_accept_excess_bytes_are_preserved(self):
1787
 
        out_stream = StringIO()
1788
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1789
 
            None, out_stream.write)
1790
 
        smart_protocol.accept_bytes('hello\nhello\n')
1791
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
1792
 
                         out_stream.getvalue())
1793
 
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
1794
 
        self.assertEqual("", smart_protocol.in_buffer)
1795
 
 
1796
 
    def test_accept_excess_bytes_after_body(self):
1797
 
        # The excess bytes look like the start of another request.
1798
 
        server_protocol = self.build_protocol_waiting_for_body()
1799
 
        server_protocol.accept_bytes(
1800
 
            '7\nabcdefgdone\n' + protocol.RESPONSE_VERSION_TWO)
1801
 
        self.assertTrue(self.end_received)
1802
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO,
1803
 
                         server_protocol.excess_buffer)
1804
 
        self.assertEqual("", server_protocol.in_buffer)
1805
 
        server_protocol.accept_bytes('Y')
1806
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "Y",
1807
 
                         server_protocol.excess_buffer)
1808
 
        self.assertEqual("", server_protocol.in_buffer)
1809
 
 
1810
 
    def test_accept_excess_bytes_after_dispatch(self):
1811
 
        out_stream = StringIO()
1812
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1813
 
            None, out_stream.write)
1814
 
        smart_protocol.accept_bytes('hello\n')
1815
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
1816
 
                         out_stream.getvalue())
1817
 
        smart_protocol.accept_bytes(protocol.REQUEST_VERSION_TWO + 'hel')
1818
 
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hel",
1819
 
                         smart_protocol.excess_buffer)
1820
 
        smart_protocol.accept_bytes('lo\n')
1821
 
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hello\n",
1822
 
                         smart_protocol.excess_buffer)
1823
 
        self.assertEqual("", smart_protocol.in_buffer)
1824
 
 
1825
 
    def test__send_response_sets_finished_reading(self):
1826
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1827
 
            None, lambda x: None)
1828
 
        self.assertEqual(1, smart_protocol.next_read_size())
1829
 
        smart_protocol._send_response(
1830
 
            request.SuccessfulSmartServerResponse(('x',)))
1831
 
        self.assertEqual(0, smart_protocol.next_read_size())
1832
 
 
1833
 
    def test__send_response_with_body_stream_sets_finished_reading(self):
1834
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1835
 
            None, lambda x: None)
1836
 
        self.assertEqual(1, smart_protocol.next_read_size())
1837
 
        smart_protocol._send_response(
1838
 
            request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
1839
 
        self.assertEqual(0, smart_protocol.next_read_size())
1840
 
 
1841
 
    def test__send_response_errors_with_base_response(self):
1842
 
        """Ensure that only the Successful/Failed subclasses are used."""
1843
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1844
 
            None, lambda x: None)
1845
 
        self.assertRaises(AttributeError, smart_protocol._send_response,
1846
 
            request.SmartServerResponse(('x',)))
1847
 
 
1848
2114
    def test__send_response_includes_failure_marker(self):
1849
2115
        """FailedSmartServerResponse have 'failed\n' after the version."""
1850
2116
        out_stream = StringIO()
1851
2117
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1852
2118
            None, out_stream.write)
1853
2119
        smart_protocol._send_response(
1854
 
            request.FailedSmartServerResponse(('x',)))
 
2120
            _mod_request.FailedSmartServerResponse(('x',)))
1855
2121
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'failed\nx\n',
1856
2122
                         out_stream.getvalue())
1857
2123
 
1861
2127
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1862
2128
            None, out_stream.write)
1863
2129
        smart_protocol._send_response(
1864
 
            request.SuccessfulSmartServerResponse(('x',)))
 
2130
            _mod_request.SuccessfulSmartServerResponse(('x',)))
1865
2131
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
1866
2132
                         out_stream.getvalue())
1867
2133
 
1868
 
    def test_query_version(self):
1869
 
        """query_version on a SmartClientProtocolTwo should return a number.
1870
 
        
1871
 
        The protocol provides the query_version because the domain level clients
1872
 
        may all need to be able to probe for capabilities.
1873
 
        """
1874
 
        # What we really want to test here is that SmartClientProtocolTwo calls
1875
 
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1876
 
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
1877
 
        # the error if the response is a non-understood version.
1878
 
        input = StringIO(protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
1879
 
        output = StringIO()
1880
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1881
 
        request = client_medium.get_request()
1882
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1883
 
        self.assertEqual(2, smart_protocol.query_version())
1884
 
 
1885
 
    def test_client_call_empty_response(self):
1886
 
        # protocol.call() can get back an empty tuple as a response. This occurs
1887
 
        # when the parsed line is an empty line, and results in a tuple with
1888
 
        # one element - an empty string.
1889
 
        self.assertServerToClientEncoding(
1890
 
            protocol.RESPONSE_VERSION_TWO + 'success\n\n', ('', ), [(), ('', )])
1891
 
 
1892
 
    def test_client_call_three_element_response(self):
1893
 
        # protocol.call() can get back tuples of other lengths. A three element
1894
 
        # tuple should be unpacked as three strings.
1895
 
        self.assertServerToClientEncoding(
1896
 
            protocol.RESPONSE_VERSION_TWO + 'success\na\x01b\x0134\n',
1897
 
            ('a', 'b', '34'),
1898
 
            [('a', 'b', '34')])
1899
 
 
1900
 
    def test_client_call_with_body_bytes_uploads(self):
1901
 
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
1902
 
        # wire.
1903
 
        expected_bytes = protocol.REQUEST_VERSION_TWO + "foo\n7\nabcdefgdone\n"
1904
 
        input = StringIO("\n")
1905
 
        output = StringIO()
1906
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1907
 
        request = client_medium.get_request()
1908
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1909
 
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1910
 
        self.assertEqual(expected_bytes, output.getvalue())
1911
 
 
1912
 
    def test_client_call_with_body_readv_array(self):
1913
 
        # protocol.call_with_upload should encode the readv array and then
1914
 
        # length-prefix the bytes onto the wire.
1915
 
        expected_bytes = protocol.REQUEST_VERSION_TWO+"foo\n7\n1,2\n5,6done\n"
1916
 
        input = StringIO("\n")
1917
 
        output = StringIO()
1918
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1919
 
        request = client_medium.get_request()
1920
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1921
 
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1922
 
        self.assertEqual(expected_bytes, output.getvalue())
1923
 
 
1924
 
    def test_client_read_response_tuple_sets_response_status(self):
1925
 
        server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
1926
 
        input = StringIO(server_bytes)
1927
 
        output = StringIO()
1928
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1929
 
        request = client_medium.get_request()
1930
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1931
 
        smart_protocol.call('foo')
1932
 
        smart_protocol.read_response_tuple(False)
1933
 
        self.assertEqual(True, smart_protocol.response_status)
1934
 
 
1935
 
    def test_client_read_body_bytes_all(self):
1936
 
        # read_body_bytes should decode the body bytes from the wire into
1937
 
        # a response.
1938
 
        expected_bytes = "1234567"
1939
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
1940
 
                        "success\nok\n7\n1234567done\n")
1941
 
        input = StringIO(server_bytes)
1942
 
        output = StringIO()
1943
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1944
 
        request = client_medium.get_request()
1945
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1946
 
        smart_protocol.call('foo')
1947
 
        smart_protocol.read_response_tuple(True)
1948
 
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1949
 
 
1950
 
    def test_client_read_body_bytes_incremental(self):
1951
 
        # test reading a few bytes at a time from the body
1952
 
        # XXX: possibly we should test dribbling the bytes into the stringio
1953
 
        # to make the state machine work harder: however, as we use the
1954
 
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
1955
 
        # that.
1956
 
        expected_bytes = "1234567"
1957
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
1958
 
                        "success\nok\n7\n1234567done\n")
1959
 
        input = StringIO(server_bytes)
1960
 
        output = StringIO()
1961
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1962
 
        request = client_medium.get_request()
1963
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1964
 
        smart_protocol.call('foo')
1965
 
        smart_protocol.read_response_tuple(True)
1966
 
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
1967
 
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
1968
 
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
1969
 
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
1970
 
 
1971
 
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1972
 
        # cancelling the expected body needs to finish the request, but not
1973
 
        # read any more bytes.
1974
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
1975
 
                        "success\nok\n7\n1234567done\n")
1976
 
        input = StringIO(server_bytes)
1977
 
        output = StringIO()
1978
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1979
 
        request = client_medium.get_request()
1980
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1981
 
        smart_protocol.call('foo')
1982
 
        smart_protocol.read_response_tuple(True)
1983
 
        smart_protocol.cancel_read_body()
1984
 
        self.assertEqual(len(protocol.RESPONSE_VERSION_TWO + 'success\nok\n'),
1985
 
                         input.tell())
1986
 
        self.assertRaises(
1987
 
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2134
    def test__send_response_with_body_stream_sets_finished_reading(self):
 
2135
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
2136
            None, lambda x: None)
 
2137
        self.assertEqual(1, smart_protocol.next_read_size())
 
2138
        smart_protocol._send_response(
 
2139
            _mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
 
2140
        self.assertEqual(0, smart_protocol.next_read_size())
1988
2141
 
1989
2142
    def test_streamed_body_bytes(self):
1990
2143
        body_header = 'chunked\n'
1995
2148
                        body_terminator)
1996
2149
        input = StringIO(server_bytes)
1997
2150
        output = StringIO()
1998
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
2151
        client_medium = medium.SmartSimplePipesClientMedium(
 
2152
            input, output, 'base')
1999
2153
        request = client_medium.get_request()
2000
2154
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2001
2155
        smart_protocol.call('foo')
2015
2169
                        "success\nok\n" + body)
2016
2170
        input = StringIO(server_bytes)
2017
2171
        output = StringIO()
2018
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
2172
        client_medium = medium.SmartSimplePipesClientMedium(
 
2173
            input, output, 'base')
2019
2174
        smart_request = client_medium.get_request()
2020
2175
        smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request)
2021
2176
        smart_protocol.call('foo')
2022
2177
        smart_protocol.read_response_tuple(True)
2023
2178
        expected_chunks = [
2024
2179
            'aaaa',
2025
 
            request.FailedSmartServerResponse(('error arg1', 'arg2'))]
 
2180
            _mod_request.FailedSmartServerResponse(('error arg1', 'arg2'))]
2026
2181
        stream = smart_protocol.read_streamed_body()
2027
2182
        self.assertEqual(expected_chunks, list(stream))
2028
2183
 
 
2184
    def test_streamed_body_bytes_interrupted_connection(self):
 
2185
        body_header = 'chunked\n'
 
2186
        incomplete_body_chunk = "9999\nincomplete chunk"
 
2187
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
2188
                        "success\nok\n" + body_header + incomplete_body_chunk)
 
2189
        input = StringIO(server_bytes)
 
2190
        output = StringIO()
 
2191
        client_medium = medium.SmartSimplePipesClientMedium(
 
2192
            input, output, 'base')
 
2193
        request = client_medium.get_request()
 
2194
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2195
        smart_protocol.call('foo')
 
2196
        smart_protocol.read_response_tuple(True)
 
2197
        stream = smart_protocol.read_streamed_body()
 
2198
        self.assertRaises(errors.ConnectionReset, stream.next)
 
2199
 
 
2200
    def test_client_read_response_tuple_sets_response_status(self):
 
2201
        server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
 
2202
        input = StringIO(server_bytes)
 
2203
        output = StringIO()
 
2204
        client_medium = medium.SmartSimplePipesClientMedium(
 
2205
            input, output, 'base')
 
2206
        request = client_medium.get_request()
 
2207
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2208
        smart_protocol.call('foo')
 
2209
        smart_protocol.read_response_tuple(False)
 
2210
        self.assertEqual(True, smart_protocol.response_status)
 
2211
 
 
2212
    def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
 
2213
        """read_response_tuple raises UnknownSmartMethod if the response says
 
2214
        the server did not recognise the request.
 
2215
        """
 
2216
        server_bytes = (
 
2217
            protocol.RESPONSE_VERSION_TWO +
 
2218
            "failed\n" +
 
2219
            "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
 
2220
        input = StringIO(server_bytes)
 
2221
        output = StringIO()
 
2222
        client_medium = medium.SmartSimplePipesClientMedium(
 
2223
            input, output, 'base')
 
2224
        request = client_medium.get_request()
 
2225
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2226
        smart_protocol.call('foo')
 
2227
        self.assertRaises(
 
2228
            errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
 
2229
        # The request has been finished.  There is no body to read, and
 
2230
        # attempts to read one will fail.
 
2231
        self.assertRaises(
 
2232
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2233
 
 
2234
 
 
2235
class TestSmartProtocolTwoSpecifics(
 
2236
        TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
 
2237
    """Tests for aspects of smart protocol version two that are unique to
 
2238
    version two.
 
2239
 
 
2240
    Thus tests involving body streams and success/failure markers belong here.
 
2241
    """
 
2242
 
 
2243
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
2244
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
2245
 
 
2246
 
 
2247
class TestVersionOneFeaturesInProtocolThree(
 
2248
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
2249
    """Tests for version one smart protocol features as implemented by version
 
2250
    three.
 
2251
    """
 
2252
 
 
2253
    request_encoder = protocol.ProtocolThreeRequester
 
2254
    response_decoder = protocol.ProtocolThreeDecoder
 
2255
    # build_server_protocol_three is a function, so we can't set it as a class
 
2256
    # attribute directly, because then Python will assume it is actually a
 
2257
    # method.  So we make server_protocol_class be a static method, rather than
 
2258
    # simply doing:
 
2259
    # "server_protocol_class = protocol.build_server_protocol_three".
 
2260
    server_protocol_class = staticmethod(protocol.build_server_protocol_three)
 
2261
 
 
2262
    def setUp(self):
 
2263
        super(TestVersionOneFeaturesInProtocolThree, self).setUp()
 
2264
        self.response_marker = protocol.MESSAGE_VERSION_THREE
 
2265
        self.request_marker = protocol.MESSAGE_VERSION_THREE
 
2266
 
 
2267
    def test_construct_version_three_server_protocol(self):
 
2268
        smart_protocol = protocol.ProtocolThreeDecoder(None)
 
2269
        self.assertEqual('', smart_protocol.unused_data)
 
2270
        self.assertEqual('', smart_protocol._in_buffer)
 
2271
        self.assertFalse(smart_protocol._has_dispatched)
 
2272
        # The protocol starts by expecting four bytes, a length prefix for the
 
2273
        # headers.
 
2274
        self.assertEqual(4, smart_protocol.next_read_size())
 
2275
 
 
2276
 
 
2277
class NoOpRequest(_mod_request.SmartServerRequest):
 
2278
 
 
2279
    def do(self):
 
2280
        return _mod_request.SuccessfulSmartServerResponse(())
 
2281
 
 
2282
dummy_registry = {'ARG': NoOpRequest}
 
2283
 
 
2284
 
 
2285
class LoggingMessageHandler(object):
 
2286
 
 
2287
    def __init__(self):
 
2288
        self.event_log = []
 
2289
 
 
2290
    def _log(self, *args):
 
2291
        self.event_log.append(args)
 
2292
 
 
2293
    def headers_received(self, headers):
 
2294
        self._log('headers', headers)
 
2295
 
 
2296
    def protocol_error(self, exception):
 
2297
        self._log('protocol_error', exception)
 
2298
 
 
2299
    def byte_part_received(self, byte):
 
2300
        self._log('byte', byte)
 
2301
 
 
2302
    def bytes_part_received(self, bytes):
 
2303
        self._log('bytes', bytes)
 
2304
 
 
2305
    def structure_part_received(self, structure):
 
2306
        self._log('structure', structure)
 
2307
 
 
2308
    def end_received(self):
 
2309
        self._log('end')
 
2310
 
 
2311
 
 
2312
class TestProtocolThree(TestSmartProtocol):
 
2313
    """Tests for v3 of the server-side protocol."""
 
2314
 
 
2315
    request_encoder = protocol.ProtocolThreeRequester
 
2316
    response_decoder = protocol.ProtocolThreeDecoder
 
2317
    server_protocol_class = protocol.ProtocolThreeDecoder
 
2318
 
 
2319
    def test_trivial_request(self):
 
2320
        """Smoke test for the simplest possible v3 request: empty headers, no
 
2321
        message parts.
 
2322
        """
 
2323
        output = StringIO()
 
2324
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2325
        end = 'e'
 
2326
        request_bytes = headers + end
 
2327
        smart_protocol = self.server_protocol_class(LoggingMessageHandler())
 
2328
        smart_protocol.accept_bytes(request_bytes)
 
2329
        self.assertEqual(0, smart_protocol.next_read_size())
 
2330
        self.assertEqual('', smart_protocol.unused_data)
 
2331
 
 
2332
    def make_protocol_expecting_message_part(self):
 
2333
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2334
        message_handler = LoggingMessageHandler()
 
2335
        smart_protocol = self.server_protocol_class(message_handler)
 
2336
        smart_protocol.accept_bytes(headers)
 
2337
        # Clear the event log
 
2338
        del message_handler.event_log[:]
 
2339
        return smart_protocol, message_handler.event_log
 
2340
 
 
2341
    def test_decode_one_byte(self):
 
2342
        """The protocol can decode a 'one byte' message part."""
 
2343
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2344
        smart_protocol.accept_bytes('ox')
 
2345
        self.assertEqual([('byte', 'x')], event_log)
 
2346
 
 
2347
    def test_decode_bytes(self):
 
2348
        """The protocol can decode a 'bytes' message part."""
 
2349
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2350
        smart_protocol.accept_bytes(
 
2351
            'b' # message part kind
 
2352
            '\0\0\0\x07' # length prefix
 
2353
            'payload' # payload
 
2354
            )
 
2355
        self.assertEqual([('bytes', 'payload')], event_log)
 
2356
 
 
2357
    def test_decode_structure(self):
 
2358
        """The protocol can decode a 'structure' message part."""
 
2359
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2360
        smart_protocol.accept_bytes(
 
2361
            's' # message part kind
 
2362
            '\0\0\0\x07' # length prefix
 
2363
            'l3:ARGe' # ['ARG']
 
2364
            )
 
2365
        self.assertEqual([('structure', ['ARG'])], event_log)
 
2366
 
 
2367
    def test_decode_multiple_bytes(self):
 
2368
        """The protocol can decode a multiple 'bytes' message parts."""
 
2369
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2370
        smart_protocol.accept_bytes(
 
2371
            'b' # message part kind
 
2372
            '\0\0\0\x05' # length prefix
 
2373
            'first' # payload
 
2374
            'b' # message part kind
 
2375
            '\0\0\0\x06'
 
2376
            'second'
 
2377
            )
 
2378
        self.assertEqual(
 
2379
            [('bytes', 'first'), ('bytes', 'second')], event_log)
 
2380
 
 
2381
 
 
2382
class TestConventionalResponseHandler(tests.TestCase):
    """Tests for message.ConventionalResponseHandler reading a response body
    that is interrupted part-way through.
    """

    def make_response_handler(self, response_bytes):
        """Return a ConventionalResponseHandler that reads response_bytes.

        The handler's ProtocolThreeDecoder is put directly into the state
        where it is waiting for message parts, so response_bytes must begin
        with a message part (no protocol version marker and no headers).
        """
        response_handler = message.ConventionalResponseHandler()
        protocol_decoder = protocol.ProtocolThreeDecoder(response_handler)
        # put decoder in desired state (waiting for message parts)
        protocol_decoder.state_accept = (
            protocol_decoder._state_accept_expecting_message_part)
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            StringIO(response_bytes), output, 'base')
        medium_request = client_medium.get_request()
        # Nothing needs to be sent; mark writing as finished before the
        # response is read.
        medium_request.finished_writing()
        response_handler.setProtoAndMediumRequest(
            protocol_decoder, medium_request)
        return response_handler

    def test_body_stream_interrupted_by_error(self):
        """An error part in the middle of a streamed body is raised as
        ErrorFromSmartServer once the chunks preceding it have been yielded.
        """
        interrupted_body_stream = (
            'oS' # successful response
            's\0\0\0\x02le' # empty args
            'b\0\0\0\x09chunk one' # first chunk
            'b\0\0\0\x09chunk two' # second chunk
            'oE' # error flag
            's\0\0\0\x0el5:error3:abce' # bencoded error
            'e' # message end
            )
        response_handler = self.make_response_handler(interrupted_body_stream)
        stream = response_handler.read_streamed_body()
        self.assertEqual('chunk one', stream.next())
        self.assertEqual('chunk two', stream.next())
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
        self.assertEqual(('error', 'abc'), exc.error_tuple)

    def test_body_stream_interrupted_by_connection_lost(self):
        """A streamed body chunk cut short by the end of the input raises
        ConnectionReset.
        """
        interrupted_body_stream = (
            'oS' # successful response
            's\0\0\0\x02le' # empty args
            'b\0\0\xff\xffincomplete chunk')
        response_handler = self.make_response_handler(interrupted_body_stream)
        stream = response_handler.read_streamed_body()
        self.assertRaises(errors.ConnectionReset, stream.next)

    def test_read_body_bytes_interrupted_by_connection_lost(self):
        """read_body_bytes raises ConnectionReset if the body chunk is cut
        short by the end of the input.
        """
        interrupted_body_stream = (
            'oS' # successful response
            's\0\0\0\x02le' # empty args
            'b\0\0\xff\xffincomplete chunk')
        response_handler = self.make_response_handler(interrupted_body_stream)
        self.assertRaises(
            errors.ConnectionReset, response_handler.read_body_bytes)
 
2433
 
 
2434
 
 
2435
class TestMessageHandlerErrors(tests.TestCase):
    """Tests for v3 that unrecognised (but well-formed) requests/responses are
    still fully read off the wire, so that subsequent requests/responses on the
    same medium can be decoded.
    """

    def test_non_conventional_request(self):
        """ConventionalRequestHandler (the default message handler on the
        server side) will reject an unconventional message, but still consume
        all the bytes of that message and signal when it has done so.

        This is what allows a server to continue to accept requests after the
        client sends a completely unrecognised request.
        """
        # Define an invalid request (but one that is a well-formed message).
        # This particular invalid request not only lacks the mandatory
        # verb+args tuple, it has a single-byte part, which is forbidden.  In
        # fact it has that part twice, to trigger multiple errors.
        invalid_request = (
            protocol.MESSAGE_VERSION_THREE +  # protocol version marker
            '\0\0\0\x02de' + # empty headers
            'oX' + # a single byte part: 'X'.  ConventionalRequestHandler will
                   # error at this part.
            'oX' + # and again.
            'e' # end of message
            )

        to_server = StringIO(invalid_request)
        from_server = StringIO()
        backing_transport = memory.MemoryTransport('memory:///')
        # Local names chosen so they do not shadow the imported 'server'
        # module.
        server_medium = medium.SmartServerPipeStreamMedium(
            to_server, from_server, backing_transport)
        proto = server_medium._build_protocol()
        server_medium._serve_one_request(proto)
        # All the bytes have been read from the medium...
        self.assertEqual('', to_server.read())
        # ...and the protocol decoder has consumed all the bytes, and has
        # finished reading.
        self.assertEqual('', proto.unused_data)
        self.assertEqual(0, proto.next_read_size())
 
2476
 
 
2477
 
 
2478
class InstrumentedRequestHandler(object):
    """Test Double of SmartServerRequestHandler.

    Each callback is recorded, in call order, in ``self.calls`` as a tuple
    whose first element is the callback's name and whose remaining elements
    are its arguments.
    """

    def __init__(self):
        self.calls = []

    def _record(self, *call):
        # *call is already a tuple: (name, arg, ...) or just (name,).
        self.calls.append(call)

    def body_chunk_received(self, chunk_bytes):
        self._record('body_chunk_received', chunk_bytes)

    def no_body_received(self):
        self._record('no_body_received')

    def prefixed_body_received(self, body_bytes):
        self._record('prefixed_body_received', body_bytes)

    def end_received(self):
        self._record('end_received')
 
2495
 
 
2496
 
 
2497
class StubRequest(object):
    """Minimal stand-in for a medium request whose finished_reading does
    nothing.
    """

    def finished_reading(self):
        return None
 
2501
 
 
2502
 
 
2503
class TestClientDecodingProtocolThree(TestSmartProtocol):
    """Tests for v3 of the client-side protocol decoding."""

    def make_logging_response_decoder(self):
        """Make v3 response decoder using a test response handler."""
        response_handler = LoggingMessageHandler()
        decoder = protocol.ProtocolThreeDecoder(response_handler)
        return decoder, response_handler

    def make_conventional_response_decoder(self):
        """Make v3 response decoder using a conventional response handler."""
        response_handler = message.ConventionalResponseHandler()
        decoder = protocol.ProtocolThreeDecoder(response_handler)
        response_handler.setProtoAndMediumRequest(decoder, StubRequest())
        return decoder, response_handler

    def test_trivial_response_decoding(self):
        """Smoke test for the simplest possible v3 response: empty headers,
        status byte, empty args, no body.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oS' # success
        args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
        end = 'e' # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_logging_response_decoder()
        decoder.accept_bytes(message_bytes)
        # The protocol decoder has finished, and consumed all bytes
        self.assertEqual(0, decoder.next_read_size())
        self.assertEqual('', decoder.unused_data)
        # The message handler has been invoked with all the parts of the
        # trivial response: empty headers, status byte, no args, end.
        self.assertEqual(
            [('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
            response_handler.event_log)

    def test_incomplete_message(self):
        """A decoder will keep signalling that it needs more bytes via
        next_read_size() != 0 until it has seen a complete message, regardless
        of which state it is in.
        """
        # Define a simple response that uses all possible message parts.
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oS' # success
        args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
        body = 'b\0\0\0\x04BODY' # a body: 'BODY'
        end = 'e' # end marker
        simple_response = headers + response_status + args + body + end
        # Feed the response to the decoder one byte at a time.
        decoder, response_handler = self.make_logging_response_decoder()
        for byte in simple_response:
            self.assertNotEqual(0, decoder.next_read_size())
            decoder.accept_bytes(byte)
        # Now the response is complete
        self.assertEqual(0, decoder.next_read_size())

    def test_read_response_tuple_raises_UnknownSmartMethod(self):
        """read_response_tuple raises UnknownSmartMethod if the server replied
        with 'UnknownMethod'.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oE' # error flag
        # args: ('UnknownMethod', 'method-name')
        args = 's\0\0\0\x20l13:UnknownMethod11:method-namee'
        end = 'e' # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_conventional_response_decoder()
        decoder.accept_bytes(message_bytes)
        error = self.assertRaises(
            errors.UnknownSmartMethod, response_handler.read_response_tuple)
        self.assertEqual('method-name', error.verb)

    def test_read_response_tuple_error(self):
        """If the response has an error, it is raised as an exception."""
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oE' # error
        args = 's\0\0\0\x1al9:first arg10:second arge' # two args
        end = 'e' # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_conventional_response_decoder()
        decoder.accept_bytes(message_bytes)
        error = self.assertRaises(
            errors.ErrorFromSmartServer, response_handler.read_response_tuple)
        self.assertEqual(('first arg', 'second arg'), error.error_tuple)
 
2587
 
 
2588
 
 
2589
class TestClientEncodingProtocolThree(TestSmartProtocol):
    """Tests for the bytes ProtocolThreeRequester emits for requests."""

    request_encoder = protocol.ProtocolThreeRequester
    response_decoder = protocol.ProtocolThreeDecoder
    server_protocol_class = protocol.ProtocolThreeDecoder

    def make_client_encoder_and_output(self):
        """Return (requester, output) from make_client_protocol_and_output,
        discarding the response handler it also returns.
        """
        requester, _response_handler, output = (
            self.make_client_protocol_and_output())
        return requester, output

    def test_call_smoke_test(self):
        """A smoke test for ProtocolThreeRequester.call.

        This test checks that a particular simple invocation of call emits the
        correct bytes for that invocation.
        """
        requester, output = self.make_client_encoder_and_output()
        requester.set_headers({'header name': 'header value'})
        requester.call('one arg')
        self.assertEqual(
            'bzr message 3 (bzr 1.6)\n' # protocol version
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
            's\x00\x00\x00\x0bl7:one arge' # args
            'e', # end
            output.getvalue())

    def test_call_with_body_bytes_smoke_test(self):
        """A smoke test for ProtocolThreeRequester.call_with_body_bytes.

        This test checks that a particular simple invocation of
        call_with_body_bytes emits the correct bytes for that invocation.
        """
        requester, output = self.make_client_encoder_and_output()
        requester.set_headers({'header name': 'header value'})
        requester.call_with_body_bytes(('one arg',), 'body bytes')
        self.assertEqual(
            'bzr message 3 (bzr 1.6)\n' # protocol version
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
            's\x00\x00\x00\x0bl7:one arge' # args
            'b' # there is a prefixed body
            '\x00\x00\x00\nbody bytes' # the prefixed body
            'e', # end
            output.getvalue())

    def test_call_writes_just_once(self):
        """A bodyless request is written to the medium all at once."""
        medium_request = StubMediumRequest()
        encoder = protocol.ProtocolThreeRequester(medium_request)
        encoder.call('arg1', 'arg2', 'arg3')
        self.assertEqual(
            ['accept_bytes', 'finished_writing'], medium_request.calls)

    def test_call_with_body_bytes_writes_just_once(self):
        """A request with body bytes is written to the medium all at once."""
        medium_request = StubMediumRequest()
        encoder = protocol.ProtocolThreeRequester(medium_request)
        encoder.call_with_body_bytes(('arg', 'arg'), 'body bytes')
        self.assertEqual(
            ['accept_bytes', 'finished_writing'], medium_request.calls)
 
2649
 
 
2650
 
 
2651
class StubMediumRequest(object):
    """A stub medium request that records, in order, the name of each
    accept_bytes and finished_writing call made on it.

    The bytes passed to accept_bytes are not kept; only the calls are.
    """

    def __init__(self):
        self.calls = []
        self._medium = 'dummy medium'

    def _note(self, method_name):
        self.calls.append(method_name)

    def accept_bytes(self, bytes):
        self._note('accept_bytes')

    def finished_writing(self):
        self._note('finished_writing')
 
2665
 
 
2666
 
 
2667
class TestResponseEncodingProtocolThree(tests.TestCase):
    """Tests for the bytes ProtocolThreeResponder writes."""

    def make_response_encoder(self):
        """Return (responder, stream) where the responder writes to stream."""
        out_stream = StringIO()
        return protocol.ProtocolThreeResponder(out_stream.write), out_stream

    def test_send_error_unknown_method(self):
        encoder, out_stream = self.make_response_encoder()
        encoder.send_error(errors.UnknownSmartMethod('method name'))
        expected_tail = (
            'oE'  # error status
            # tuple: 'UnknownMethod', 'method name'
            's\x00\x00\x00\x20l13:UnknownMethod11:method namee'
            'e'  # end of message
            )
        # Only the tail is compared: the header varies by
        # bzrlib.__version__.
        self.assertEndsWith(out_stream.getvalue(), expected_tail)
 
2687
 
 
2688
 
 
2689
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
    """Tests for buffering of responses.

    We want to avoid doing many small writes when one would do, to avoid
    unnecessary network overhead.
    """

    def setUp(self):
        # Run the base class setUp first; overriding setUp without calling it
        # bypasses whatever per-test setup tests.TestCase performs.
        tests.TestCase.setUp(self)
        self.writes = []
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)

    def assertWriteCount(self, expected_count):
        """Fail unless exactly expected_count writes have been made."""
        self.assertEqual(
            expected_count, len(self.writes),
            "Wrong number of writes (expected %d): %r"
            % (expected_count, self.writes))

    def test_send_error_writes_just_once(self):
        """An error response is written to the medium all at once."""
        self.responder.send_error(Exception('An exception string.'))
        self.assertWriteCount(1)

    def test_send_response_writes_just_once(self):
        """A normal response with no body is written to the medium all at once.
        """
        response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
        self.responder.send_response(response)
        self.assertWriteCount(1)

    def test_send_response_with_body_writes_just_once(self):
        """A normal response with a monolithic body is written to the medium
        all at once.
        """
        response = _mod_request.SuccessfulSmartServerResponse(
            ('arg', 'arg'), body='body bytes')
        self.responder.send_response(response)
        self.assertWriteCount(1)

    def test_send_response_with_body_stream_writes_once_per_chunk(self):
        """A normal response with a stream body is written to the medium once
        per chunk, plus once more to end the response.
        """
        # Construct a response with stream with 2 chunks in it.
        response = _mod_request.SuccessfulSmartServerResponse(
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
        self.responder.send_response(response)
        # We will write 3 times: exactly once for each chunk, plus a final
        # write to end the response.
        self.assertWriteCount(3)
 
2737
 
2029
2738
 
2030
2739
class TestSmartClientUnicode(tests.TestCase):
2031
2740
    """_SmartClient tests for unicode arguments.
2046
2755
        """
2047
2756
        input = StringIO("\n")
2048
2757
        output = StringIO()
2049
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
2758
        client_medium = medium.SmartSimplePipesClientMedium(
 
2759
            input, output, 'ignored base')
2050
2760
        smart_client = client._SmartClient(client_medium)
2051
2761
        self.assertRaises(TypeError,
2052
2762
            smart_client.call_with_body_bytes, method, args, body)
2064
2774
        self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
2065
2775
 
2066
2776
 
 
2777
class MockMedium(object):
    """A mock medium that can be used to test _SmartClient.

    It can be given a series of requests to expect (and responses it should
    return for them).  It can also be told when the client is expected to
    disconnect a medium.  Expectations must be satisfied in the order they are
    given, or else an AssertionError will be raised.

    Typical use looks like::

        medium = MockMedium()
        medium.expect_request(...)
        medium.expect_request(...)
        medium.expect_request(...)
    """

    def __init__(self):
        self.base = 'dummy base'
        self._mock_request = _MockMediumRequest(self)
        # Pending expectations, consumed from the front.  Entries are
        # ('send request', bytes), ('read response', bytes),
        # ('read response (partial)', bytes) or the string 'disconnect'.
        self._expected_events = []
        self._protocol_version = None

    def expect_request(self, request_bytes, response_bytes,
                       allow_partial_read=False):
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.

        No assumption is made about how many times accept_bytes should be
        called to send the request.  Similarly, no assumption is made about how
        many times read_bytes/read_line are called by protocol code to read a
        response.  e.g.::

            request.accept_bytes('ab')
            request.accept_bytes('cd')
            request.finished_writing()

        and::

            request.accept_bytes('abcd')
            request.finished_writing()

        Will both satisfy ``medium.expect_request('abcd', ...)``.  Thus tests
        using this should not break due to irrelevant changes in protocol
        implementations.

        :param allow_partial_read: if True, no assertion is raised if a
            response is not fully read.  Setting this is useful when the client
            is expected to disconnect without needing to read the complete
            response.  Default is False.
        """
        if allow_partial_read:
            read_event_name = 'read response (partial)'
        else:
            read_event_name = 'read response'
        self._expected_events.extend([
            ('send request', request_bytes),
            (read_event_name, response_bytes),
            ])

    def expect_disconnect(self):
        """Expect the client to call ``medium.disconnect()``."""
        self._expected_events.append('disconnect')

    def _assertEvent(self, observed_event):
        """Raise AssertionError unless observed_event matches the next expected
        event.

        A 'read response (partial)' expectation is met by any 'read response'
        event, whatever bytes it carries.  After a successful match, if the
        next expectation is a read, its canned response is handed to the mock
        request so that it can be read.

        :seealso: expect_request
        :seealso: expect_disconnect
        """
        if not self._expected_events:
            raise AssertionError(
                'Mock medium observed event %r, but no more events expected'
                % (observed_event,))
        expected_event = self._expected_events.pop(0)
        if expected_event[0] == 'read response (partial)':
            matched = observed_event[0] == 'read response'
        else:
            matched = observed_event == expected_event
        if not matched:
            raise AssertionError(
                'Mock medium observed event %r, but expected event %r'
                % (observed_event, expected_event))
        if not self._expected_events:
            return
        upcoming = self._expected_events[0]
        if upcoming[0].startswith('read response'):
            self._mock_request._response = upcoming[1]

    def get_request(self):
        return self._mock_request

    def disconnect(self):
        """Check that a disconnect was expected.

        Any bytes read but not yet reported are checked as a 'read response'
        event first.
        """
        unreported = self._mock_request._read_bytes
        if unreported:
            self._assertEvent(('read response', unreported))
            self._mock_request._read_bytes = ''
        self._assertEvent('disconnect')
 
2872
 
 
2873
 
 
2874
class _MockMediumRequest(object):
 
2875
    """A mock ClientMediumRequest used by MockMedium."""
 
2876
 
 
2877
    def __init__(self, mock_medium):
 
2878
        self._medium = mock_medium
 
2879
        self._written_bytes = ''
 
2880
        self._read_bytes = ''
 
2881
        self._response = None
 
2882
 
 
2883
    def accept_bytes(self, bytes):
 
2884
        self._written_bytes += bytes
 
2885
 
 
2886
    def finished_writing(self):
 
2887
        self._medium._assertEvent(('send request', self._written_bytes))
 
2888
        self._written_bytes = ''
 
2889
 
 
2890
    def finished_reading(self):
 
2891
        self._medium._assertEvent(('read response', self._read_bytes))
 
2892
        self._read_bytes = ''
 
2893
 
 
2894
    def read_bytes(self, size):
 
2895
        resp = self._response
 
2896
        bytes, resp = resp[:size], resp[size:]
 
2897
        self._response = resp
 
2898
        self._read_bytes += bytes
 
2899
        return bytes
 
2900
 
 
2901
    def read_line(self):
 
2902
        resp = self._response
 
2903
        try:
 
2904
            line, resp = resp.split('\n', 1)
 
2905
            line += '\n'
 
2906
        except ValueError:
 
2907
            line, resp = resp, ''
 
2908
        self._response = resp
 
2909
        self._read_bytes += line
 
2910
        return line
 
2911
 
 
2912
 
 
2913
class Test_SmartClientVersionDetection(tests.TestCase):
    """Tests for _SmartClient's automatic protocol version detection.

    On the first remote call, _SmartClient will keep retrying the request with
    different protocol versions until it finds one that works.

    The mock is bound to ``mock_medium`` rather than ``medium`` so that the
    imported bzrlib.smart.medium module is not shadowed.
    """

    def test_version_three_server(self):
        """With a protocol 3 server, only one request is needed."""
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
        mock_medium.expect_request(
            message_start +
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            message_start + 's\0\0\0\x13l14:response valueee')
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
        # The call succeeded without raising any exceptions from the mock
        # medium, and the smart_client returns the response from the server.
        self.assertEqual(('response value',), result)
        self.assertEqual([], mock_medium._expected_events)

    def test_version_two_server(self):
        """If the server only speaks protocol 2, the client will first try
        version 3, then fallback to protocol 2.

        Further, _SmartClient caches the detection, so future requests will all
        use protocol 2 immediately.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        # First the client should send a v3 request, but the server will reply
        # with a v2 error.
        mock_medium.expect_request(
            'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            'bzr response 2\nfailed\n\n')
        # So then the client should disconnect to reset the connection, because
        # the client needs to assume the server cannot read any further
        # requests off the original connection.
        mock_medium.expect_disconnect()
        # The client should then retry the original request in v2
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
            'bzr response 2\nsuccess\nresponse value\n')
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
        # The smart_client object will return the result of the successful
        # query.
        self.assertEqual(('response value',), result)

        # Now try another request, and this time the client will just use
        # protocol 2.  (i.e. the autodetection won't be repeated)
        mock_medium.expect_request(
            'bzr request 2\nanother-method\n',
            'bzr response 2\nsuccess\nanother response\n')
        result = smart_client.call('another-method')
        self.assertEqual(('another response',), result)
        self.assertEqual([], mock_medium._expected_events)

    def test_unknown_version(self):
        """If the server does not use any known (or at least supported)
        protocol version, a SmartProtocolError is raised.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        unknown_protocol_bytes = 'Unknown protocol!'
        # The client will try v3 and v2 before eventually giving up.
        mock_medium.expect_request(
            'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            unknown_protocol_bytes)
        mock_medium.expect_disconnect()
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
            unknown_protocol_bytes)
        mock_medium.expect_disconnect()
        self.assertRaises(
            errors.SmartProtocolError,
            smart_client.call, 'method-name', 'arg 1', 'arg 2')
        self.assertEqual([], mock_medium._expected_events)

    def test_first_response_is_error(self):
        """If the server replies with an error, then the version detection
        should be complete.

        This test is very similar to test_version_two_server, but catches a bug
        we had in the case where the first reply was an error response.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
        # Issue a request that gets an error reply in a non-default protocol
        # version.
        mock_medium.expect_request(
            message_start +
            's\x00\x00\x00\x10l11:method-nameee',
            'bzr response 2\nfailed\n\n')
        mock_medium.expect_disconnect()
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\n',
            'bzr response 2\nfailed\nFooBarError\n')
        err = self.assertRaises(
            errors.ErrorFromSmartServer,
            smart_client.call, 'method-name')
        self.assertEqual(('FooBarError',), err.error_tuple)
        # Now the medium should have remembered the protocol version, so
        # subsequent requests will use the remembered version immediately.
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\n',
            'bzr response 2\nsuccess\nresponse value\n')
        result = smart_client.call('method-name')
        self.assertEqual(('response value',), result)
        self.assertEqual([], mock_medium._expected_events)
 
3026
 
 
3027
 
 
3028
class Test_SmartClient(tests.TestCase):

    def test_call_default_headers(self):
        """By default, a _SmartClient's headers include a 'Software version'
        entry equal to bzrlib.__version__.
        """
        smart_client = client._SmartClient('dummy medium')
        default_headers = smart_client._headers
        self.assertEqual(
            bzrlib.__version__, default_headers['Software version'])
        # XXX: need a test that smart_client._headers is passed to the request
        # encoder.
 
3039
 
 
3040
 
2067
3041
class LengthPrefixedBodyDecoder(tests.TestCase):
2068
3042
 
2069
3043
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
2283
3257
        decoder.accept_bytes(chunk_one + error_signal + error_chunks + finish)
2284
3258
        self.assertTrue(decoder.finished_reading)
2285
3259
        self.assertEqual('first chunk', decoder.read_next_chunk())
2286
 
        expected_failure = request.FailedSmartServerResponse(
 
3260
        expected_failure = _mod_request.FailedSmartServerResponse(
2287
3261
            ('part1', 'part2'))
2288
3262
        self.assertEqual(expected_failure, decoder.read_next_chunk())
2289
3263
 
2299
3273
class TestSuccessfulSmartServerResponse(tests.TestCase):
2300
3274
 
2301
3275
    def test_construct_no_body(self):
2302
 
        response = request.SuccessfulSmartServerResponse(('foo', 'bar'))
 
3276
        response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'))
2303
3277
        self.assertEqual(('foo', 'bar'), response.args)
2304
3278
        self.assertEqual(None, response.body)
2305
3279
 
2306
3280
    def test_construct_with_body(self):
2307
 
        response = request.SuccessfulSmartServerResponse(
2308
 
            ('foo', 'bar'), 'bytes')
 
3281
        response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'),
 
3282
                                                              'bytes')
2309
3283
        self.assertEqual(('foo', 'bar'), response.args)
2310
3284
        self.assertEqual('bytes', response.body)
2311
3285
        # repr(response) doesn't trigger exceptions.
2313
3287
 
2314
3288
    def test_construct_with_body_stream(self):
2315
3289
        bytes_iterable = ['abc']
2316
 
        response = request.SuccessfulSmartServerResponse(
 
3290
        response = _mod_request.SuccessfulSmartServerResponse(
2317
3291
            ('foo', 'bar'), body_stream=bytes_iterable)
2318
3292
        self.assertEqual(('foo', 'bar'), response.args)
2319
3293
        self.assertEqual(bytes_iterable, response.body_stream)
2322
3296
        """'body' and 'body_stream' are mutually exclusive."""
2323
3297
        self.assertRaises(
2324
3298
            errors.BzrError,
2325
 
            request.SuccessfulSmartServerResponse, (), 'body', ['stream'])
 
3299
            _mod_request.SuccessfulSmartServerResponse, (), 'body', ['stream'])
2326
3300
 
2327
3301
    def test_is_successful(self):
2328
3302
        """is_successful should return True for SuccessfulSmartServerResponse."""
2329
 
        response = request.SuccessfulSmartServerResponse(('error',))
 
3303
        response = _mod_request.SuccessfulSmartServerResponse(('error',))
2330
3304
        self.assertEqual(True, response.is_successful())
2331
3305
 
2332
3306
 
2333
3307
class TestFailedSmartServerResponse(tests.TestCase):
2334
3308
 
2335
3309
    def test_construct(self):
2336
 
        response = request.FailedSmartServerResponse(('foo', 'bar'))
 
3310
        response = _mod_request.FailedSmartServerResponse(('foo', 'bar'))
2337
3311
        self.assertEqual(('foo', 'bar'), response.args)
2338
3312
        self.assertEqual(None, response.body)
2339
 
        response = request.FailedSmartServerResponse(('foo', 'bar'), 'bytes')
 
3313
        response = _mod_request.FailedSmartServerResponse(('foo', 'bar'), 'bytes')
2340
3314
        self.assertEqual(('foo', 'bar'), response.args)
2341
3315
        self.assertEqual('bytes', response.body)
2342
3316
        # repr(response) doesn't trigger exceptions.
2344
3318
 
2345
3319
    def test_is_successful(self):
2346
3320
        """is_successful should return False for FailedSmartServerResponse."""
2347
 
        response = request.FailedSmartServerResponse(('error',))
 
3321
        response = _mod_request.FailedSmartServerResponse(('error',))
2348
3322
        self.assertEqual(False, response.is_successful())
2349
3323
 
2350
3324
 
2388
3362
        self.assertEqual(base_transport._http_transport,
2389
3363
                         new_transport._http_transport)
2390
3364
        self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))
 
3365
        self.assertEqual(
 
3366
            'child_dir/',
 
3367
            new_transport._client.remote_path_from_transport(new_transport))
2391
3368
 
2392
3369
    def test_remote_path_unnormal_base(self):
2393
3370
        # If the transport's base isn't normalised, the _remote_path should
2401
3378
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
2402
3379
        new_transport = base_transport.clone('c')
2403
3380
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
3381
        self.assertEqual(
 
3382
            'c/',
 
3383
            new_transport._client.remote_path_from_transport(new_transport))
2404
3384
 
2405
3385
        
2406
3386
# TODO: Client feature that does get_bundle and then installs that into a