~ubuntu-branches/ubuntu/precise/pyzmq/precise

« back to all changes in this revision

Viewing changes to zmq/tests/test_message.py

  • Committer: Bazaar Package Importer
  • Author(s): Piotr Ożarowski
  • Date: 2011-02-15 09:08:36 UTC
  • mfrom: (2.1.2 experimental)
  • Revision ID: james.westby@ubuntu.com-20110215090836-phh4slym1g6muucn
Tags: 2.0.10.1-2
* Team upload.
* Upload to unstable
* Add Breaks: ${python:Breaks}

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
# -*- coding: utf8 -*-
1
3
#
2
4
#    Copyright (c) 2010 Brian E. Granger
3
5
#
22
24
#-----------------------------------------------------------------------------
23
25
 
24
26
import copy
 
27
import sys
25
28
from sys import getrefcount as grc
26
29
import time
 
30
from pprint import pprint
27
31
from unittest import TestCase
28
32
 
29
33
import zmq
30
 
from zmq.tests import PollZMQTestCase
 
34
from zmq.tests import BaseZMQTestCase, SkipTest
 
35
from zmq.utils.strtypes import unicode,bytes
31
36
 
32
37
#-----------------------------------------------------------------------------
33
38
# Tests
34
39
#-----------------------------------------------------------------------------
35
40
 
36
 
class TestMessage(TestCase):
 
41
x = 'x'.encode()
 
42
 
 
43
class TestMessage(BaseZMQTestCase):
37
44
 
38
45
    def test_above_30(self):
39
46
        """Message above 30 bytes are never copied by 0MQ."""
40
47
        for i in range(5, 16):  # 32, 64,..., 65536
41
 
            s = (2**i)*'x'
 
48
            s = (2**i)*x
42
49
            self.assertEquals(grc(s), 2)
43
50
            m = zmq.Message(s)
44
51
            self.assertEquals(grc(s), 4)
49
56
    def test_str(self):
50
57
        """Test the str representations of the Messages."""
51
58
        for i in range(16):
52
 
            s = (2**i)*'x'
53
 
            m = zmq.Message(s)
54
 
            self.assertEquals(s, str(s))
55
 
            self.assert_(s is str(s))
 
59
            s = (2**i)*x
 
60
            m = zmq.Message(s)
 
61
            self.assertEquals(s, str(m).encode())
 
62
 
 
63
    def test_bytes(self):
 
64
        """Test the Message.bytes property."""
 
65
        for i in range(1,16):
 
66
            s = (2**i)*x
 
67
            m = zmq.Message(s)
 
68
            b = m.bytes
 
69
            self.assertEquals(s, m.bytes)
 
70
            # check that it copies
 
71
            self.assert_(b is not s)
 
72
            # check that it copies only once
 
73
            self.assert_(b is m.bytes)
 
74
 
 
75
    def test_unicode(self):
 
76
        """Test the unicode representations of the Messages."""
 
77
        s = unicode('asdf')
 
78
        self.assertRaises(TypeError, zmq.Message, s)
 
79
        u = '§'
 
80
        if str is not unicode:
 
81
            u = u.decode('utf8')
 
82
        for i in range(16):
 
83
            s = (2**i)*u
 
84
            m = zmq.Message(s.encode('utf8'))
 
85
            self.assertEquals(s, unicode(m.bytes,'utf8'))
56
86
 
57
87
    def test_len(self):
58
88
        """Test the len of the Messages."""
59
89
        for i in range(16):
60
 
            s = (2**i)*'x'
 
90
            s = (2**i)*x
61
91
            m = zmq.Message(s)
62
 
            self.assertEquals(len(s), len(s))
 
92
            self.assertEquals(len(s), len(m))
63
93
 
64
94
    def test_lifecycle1(self):
65
95
        """Run through a ref counting cycle with a copy."""
 
96
        try:
 
97
            view = memoryview
 
98
        except NameError:
 
99
            view = type(None)
66
100
        for i in range(5, 16):  # 32, 64,..., 65536
67
 
            s = (2**i)*'x'
68
 
            self.assertEquals(grc(s), 2)
 
101
            s = (2**i)*x
 
102
            rc = 2
 
103
            self.assertEquals(grc(s), rc)
69
104
            m = zmq.Message(s)
70
 
            self.assertEquals(grc(s), 4)
 
105
            rc += 2
 
106
            self.assertEquals(grc(s), rc)
71
107
            m2 = copy.copy(m)
72
 
            self.assertEquals(grc(s), 5)
73
 
            self.assertEquals(s, str(m))
74
 
            self.assertEquals(s, str(m2))
75
 
            self.assert_(s is str(m))
76
 
            self.assert_(s is str(m2))
 
108
            rc += 1
 
109
            self.assertEquals(grc(s), rc)
 
110
            b = m2.buffer
 
111
            extra = int(isinstance(b,view))
 
112
            # memoryview incs by 2
 
113
            # buffer by 1
 
114
            rc += 1+extra
 
115
            self.assertEquals(grc(s), rc)
 
116
 
 
117
            self.assertEquals(s, str(m).encode())
 
118
            self.assertEquals(s, str(m2).encode())
 
119
            self.assertEquals(s, m.bytes)
 
120
            # self.assert_(s is str(m))
 
121
            # self.assert_(s is str(m2))
77
122
            del m2
78
 
            self.assertEquals(grc(s), 4)
 
123
            rc -= 1
 
124
            self.assertEquals(grc(s), rc)
 
125
            rc -= 1+extra
 
126
            del b
 
127
            self.assertEquals(grc(s), rc)
79
128
            del m
80
 
            self.assertEquals(grc(s), 2)
 
129
            rc -= 2
 
130
            self.assertEquals(grc(s), rc)
 
131
            self.assertEquals(rc, 2)
81
132
            del s
82
133
 
83
134
    def test_lifecycle2(self):
84
135
        """Run through a different ref counting cycle with a copy."""
 
136
        try:
 
137
            view = memoryview
 
138
        except NameError:
 
139
            view = type(None)
85
140
        for i in range(5, 16):  # 32, 64,..., 65536
86
 
            s = (2**i)*'x'
87
 
            self.assertEquals(grc(s), 2)
 
141
            s = (2**i)*x
 
142
            rc = 2
 
143
            self.assertEquals(grc(s), rc)
88
144
            m = zmq.Message(s)
89
 
            self.assertEquals(grc(s), 4)
 
145
            rc += 2
 
146
            self.assertEquals(grc(s), rc)
90
147
            m2 = copy.copy(m)
91
 
            self.assertEquals(grc(s), 5)
92
 
            self.assertEquals(s, str(m))
93
 
            self.assertEquals(s, str(m2))
94
 
            self.assert_(s is str(m))
95
 
            self.assert_(s is str(m2))
 
148
            rc += 1
 
149
            self.assertEquals(grc(s), rc)
 
150
            b = m.buffer
 
151
            extra = int(isinstance(b,view))
 
152
            rc += 1+extra
 
153
            self.assertEquals(grc(s), rc)
 
154
            self.assertEquals(s, str(m).encode())
 
155
            self.assertEquals(s, str(m2).encode())
 
156
            self.assertEquals(s, m2.bytes)
 
157
            self.assertEquals(s, m.bytes)
 
158
            # self.assert_(s is str(m))
 
159
            # self.assert_(s is str(m2))
 
160
            del b
 
161
            self.assertEquals(grc(s), rc)
96
162
            del m
97
 
            self.assertEquals(grc(s), 4)
 
163
            # m.buffer is kept until m is del'd
 
164
            rc -= 1+extra
 
165
            rc -= 1
 
166
            self.assertEquals(grc(s), rc)
98
167
            del m2
99
 
            self.assertEquals(grc(s), 2)
 
168
            rc -= 2
 
169
            self.assertEquals(grc(s), rc)
 
170
            self.assertEquals(rc, 2)
100
171
            del s
 
172
    
 
173
    def test_tracker(self):
 
174
        m = zmq.Message('asdf'.encode(), track=True)
 
175
        self.assertFalse(m.done)
 
176
        pm = zmq.MessageTracker(m)
 
177
        self.assertFalse(pm.done)
 
178
        del m
 
179
        self.assertTrue(pm.done)
 
180
    
 
181
    def test_no_tracker(self):
 
182
        m = zmq.Message('asdf'.encode(), track=False)
 
183
        self.assertRaises(ValueError, getattr, m, 'done')
 
184
        m2 = copy.copy(m)
 
185
        self.assertRaises(ValueError, getattr, m2, 'done')
 
186
        self.assertRaises(ValueError, zmq.MessageTracker, m)
 
187
    
 
188
    def test_multi_tracker(self):
 
189
        m = zmq.Message('asdf'.encode(), track=True)
 
190
        m2 = zmq.Message('whoda'.encode(), track=True)
 
191
        mt = zmq.MessageTracker(m,m2)
 
192
        self.assertFalse(m.done)
 
193
        self.assertFalse(mt.done)
 
194
        self.assertRaises(zmq.NotDone, mt.wait, 0.1)
 
195
        del m
 
196
        time.sleep(0.1)
 
197
        self.assertRaises(zmq.NotDone, mt.wait, 0.1)
 
198
        self.assertFalse(mt.done)
 
199
        del m2
 
200
        self.assertTrue(mt.wait() is None)
 
201
        self.assertTrue(mt.done)
 
202
        
 
203
    
 
204
    def test_buffer_in(self):
 
205
        """test using a buffer as input"""
 
206
        try:
 
207
            view = memoryview
 
208
        except NameError:
 
209
            view = buffer
 
210
        if unicode is str:
 
211
            ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode('utf8')
 
212
        else:
 
213
            ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√"
 
214
        m = zmq.Message(view(ins))
 
215
    
 
216
    def test_bad_buffer_in(self):
 
217
        """test using a bad object"""
 
218
        self.assertRaises(TypeError, zmq.Message, 5)
 
219
        self.assertRaises(TypeError, zmq.Message, object())
 
220
        
 
221
    def test_buffer_out(self):
 
222
        """receiving buffered output"""
 
223
        try:
 
224
            view = memoryview
 
225
        except NameError:
 
226
            view = buffer
 
227
        if unicode is str:
 
228
            ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode('utf8')
 
229
        else:
 
230
            ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√"
 
231
        m = zmq.Message(ins)
 
232
        outb = m.buffer
 
233
        self.assertTrue(isinstance(outb, view))
 
234
        self.assert_(outb is m.buffer)
 
235
        self.assert_(m.buffer is m.buffer)
 
236
    
 
237
    def test_multisend(self):
 
238
        """ensure that a message remains intact after multiple sends"""
 
239
        a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
 
240
        s = "message".encode()
 
241
        m = zmq.Message(s)
 
242
        self.assertEquals(s, m.bytes)
 
243
        
 
244
        a.send(m, copy=False)
 
245
        time.sleep(0.1)
 
246
        self.assertEquals(s, m.bytes)
 
247
        a.send(m, copy=False)
 
248
        time.sleep(0.1)
 
249
        self.assertEquals(s, m.bytes)
 
250
        a.send(m, copy=True)
 
251
        time.sleep(0.1)
 
252
        self.assertEquals(s, m.bytes)
 
253
        a.send(m, copy=True)
 
254
        time.sleep(0.1)
 
255
        self.assertEquals(s, m.bytes)
 
256
        for i in range(4):
 
257
            r = b.recv()
 
258
            self.assertEquals(s,r)
 
259
        self.assertEquals(s, m.bytes)
 
260
    
 
261
    def test_buffer_numpy(self):
 
262
        """test non-copying numpy array messages"""
 
263
        try:
 
264
            import numpy
 
265
        except ImportError:
 
266
            raise SkipTest("NumPy unavailable")
 
267
        shapes = map(numpy.random.randint, [2]*5,[16]*5)
 
268
        for i in range(1,len(shapes)+1):
 
269
            shape = shapes[:i]
 
270
            A = numpy.random.random(shape)
 
271
            m = zmq.Message(A)
 
272
            self.assertEquals(A.data, m.buffer)
 
273
            B = numpy.frombuffer(m.buffer,dtype=A.dtype).reshape(A.shape)
 
274
            self.assertEquals((A==B).all(), True)
 
275
    
 
276
    def test_memoryview(self):
 
277
        """test messages from memoryview (only valid for python >= 2.7)"""
 
278
        major,minor = sys.version_info[:2]
 
279
        if not (major >= 3 or (major == 2 and minor >= 7)):
 
280
            raise SkipTest
 
281
 
 
282
        s = 'carrotjuice'.encode()
 
283
        v = memoryview(s)
 
284
        m = zmq.Message(s)
 
285
        buf = m.buffer
 
286
        s2 = buf.tobytes()
 
287
        self.assertEquals(s2,s)
 
288
        self.assertEquals(m.bytes,s)
 
289
        
101
290