2
1
# -*- coding: utf8 -*-
4
# Copyright (c) 2010 Brian E. Granger
6
# This file is part of pyzmq.
8
# pyzmq is free software; you can redistribute it and/or modify it under
9
# the terms of the Lesser GNU General Public License as published by
10
# the Free Software Foundation; either version 3 of the License, or
11
# (at your option) any later version.
13
# pyzmq is distributed in the hope that it will be useful,
14
# but WITHOUT ANY WARRANTY; without even the implied warranty of
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16
# Lesser GNU General Public License for more details.
18
# You should have received a copy of the Lesser GNU General Public License
19
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2
#-----------------------------------------------------------------------------
3
# Copyright (c) 2010-2012 Brian Granger, Min Ragan-Kelley
5
# This file is part of pyzmq
7
# Distributed under the terms of the New BSD License. The full license is in
8
# the file COPYING.BSD, distributed as part of this software.
9
#-----------------------------------------------------------------------------
22
11
#-----------------------------------------------------------------------------
34
23
from zmq.tests import BaseZMQTestCase, SkipTest
35
from zmq.utils.strtypes import unicode,bytes
24
from zmq.utils.strtypes import unicode, bytes, asbytes, b
25
from zmq.utils.rebuffer import array_from_buffer
37
27
#-----------------------------------------------------------------------------
39
29
#-----------------------------------------------------------------------------
43
class TestMessage(BaseZMQTestCase):
31
# some useful constants:
42
view_rc = grc(x) - rc0
44
class TestFrame(BaseZMQTestCase):
45
46
def test_above_30(self):
46
47
"""Message above 30 bytes are never copied by 0MQ."""
47
48
for i in range(5, 16): # 32, 64,..., 65536
49
50
self.assertEquals(grc(s), 2)
51
52
self.assertEquals(grc(s), 4)
53
54
self.assertEquals(grc(s), 2)
56
57
def test_str(self):
57
"""Test the str representations of the Messages."""
58
"""Test the str representations of the Frames."""
58
59
for i in range(16):
61
self.assertEquals(s, str(m).encode())
62
self.assertEquals(s, asbytes(m))
63
64
def test_bytes(self):
64
"""Test the Message.bytes property."""
65
"""Test the Frame.bytes property."""
65
66
for i in range(1,16):
69
70
self.assertEquals(s, m.bytes)
70
71
# check that it copies
73
74
self.assert_(b is m.bytes)
75
76
def test_unicode(self):
76
"""Test the unicode representations of the Messages."""
77
"""Test the unicode representations of the Frames."""
77
78
s = unicode('asdf')
78
self.assertRaises(TypeError, zmq.Message, s)
79
self.assertRaises(TypeError, zmq.Frame, s)
80
81
if str is not unicode:
81
82
u = u.decode('utf8')
82
83
for i in range(16):
84
m = zmq.Message(s.encode('utf8'))
85
m = zmq.Frame(s.encode('utf8'))
85
86
self.assertEquals(s, unicode(m.bytes,'utf8'))
87
88
def test_len(self):
88
"""Test the len of the Messages."""
89
"""Test the len of the Frames."""
89
90
for i in range(16):
92
93
self.assertEquals(len(s), len(m))
94
95
def test_lifecycle1(self):
95
96
"""Run through a ref counting cycle with a copy."""
100
97
for i in range(5, 16): # 32, 64,..., 65536
103
100
self.assertEquals(grc(s), rc)
106
103
self.assertEquals(grc(s), rc)
107
104
m2 = copy.copy(m)
109
106
self.assertEquals(grc(s), rc)
111
extra = int(isinstance(b,view))
112
# memoryview incs by 2
115
110
self.assertEquals(grc(s), rc)
117
self.assertEquals(s, str(m).encode())
118
self.assertEquals(s, str(m2).encode())
112
self.assertEquals(s, asbytes(str(m)))
113
self.assertEquals(s, asbytes(m2))
119
114
self.assertEquals(s, m.bytes)
120
115
# self.assert_(s is str(m))
121
116
# self.assert_(s is str(m2))
124
119
self.assertEquals(grc(s), rc)
127
122
self.assertEquals(grc(s), rc)
134
129
def test_lifecycle2(self):
135
130
"""Run through a different ref counting cycle with a copy."""
140
131
for i in range(5, 16): # 32, 64,..., 65536
143
134
self.assertEquals(grc(s), rc)
146
137
self.assertEquals(grc(s), rc)
147
138
m2 = copy.copy(m)
149
140
self.assertEquals(grc(s), rc)
151
extra = int(isinstance(b,view))
153
143
self.assertEquals(grc(s), rc)
154
self.assertEquals(s, str(m).encode())
155
self.assertEquals(s, str(m2).encode())
144
self.assertEquals(s, asbytes(str(m)))
145
self.assertEquals(s, asbytes(m2))
156
146
self.assertEquals(s, m2.bytes)
157
147
self.assertEquals(s, m.bytes)
158
148
# self.assert_(s is str(m))
173
163
def test_tracker(self):
174
m = zmq.Message('asdf'.encode(), track=True)
175
self.assertFalse(m.done)
164
m = zmq.Frame(b'asdf', track=True)
165
self.assertFalse(m.tracker.done)
176
166
pm = zmq.MessageTracker(m)
177
167
self.assertFalse(pm.done)
179
169
self.assertTrue(pm.done)
181
171
def test_no_tracker(self):
182
m = zmq.Message('asdf'.encode(), track=False)
183
self.assertRaises(ValueError, getattr, m, 'done')
172
m = zmq.Frame(b'asdf', track=False)
173
self.assertEquals(m.tracker, None)
184
174
m2 = copy.copy(m)
185
self.assertRaises(ValueError, getattr, m2, 'done')
175
self.assertEquals(m2.tracker, None)
186
176
self.assertRaises(ValueError, zmq.MessageTracker, m)
188
178
def test_multi_tracker(self):
189
m = zmq.Message('asdf'.encode(), track=True)
190
m2 = zmq.Message('whoda'.encode(), track=True)
179
m = zmq.Frame(b'asdf', track=True)
180
m2 = zmq.Frame(b'whoda', track=True)
191
181
mt = zmq.MessageTracker(m,m2)
192
self.assertFalse(m.done)
182
self.assertFalse(m.tracker.done)
193
183
self.assertFalse(mt.done)
194
184
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
204
194
def test_buffer_in(self):
205
195
"""test using a buffer as input"""
210
196
if unicode is str:
211
197
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode('utf8')
213
199
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√"
214
m = zmq.Message(view(ins))
200
m = zmq.Frame(view(ins))
216
202
def test_bad_buffer_in(self):
217
203
"""test using a bad object"""
218
self.assertRaises(TypeError, zmq.Message, 5)
219
self.assertRaises(TypeError, zmq.Message, object())
204
self.assertRaises(TypeError, zmq.Frame, 5)
205
self.assertRaises(TypeError, zmq.Frame, object())
221
207
def test_buffer_out(self):
222
208
"""receiving buffered output"""
227
209
if unicode is str:
228
210
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode('utf8')
230
212
ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√"
233
215
self.assertTrue(isinstance(outb, view))
234
216
self.assert_(outb is m.buffer)
265
247
except ImportError:
266
raise SkipTest("NumPy unavailable")
267
shapes = map(numpy.random.randint, [2]*5,[16]*5)
248
raise SkipTest("numpy required")
249
rand = numpy.random.randint
250
shapes = [ rand(2,16) for i in range(5) ]
268
251
for i in range(1,len(shapes)+1):
269
252
shape = shapes[:i]
270
253
A = numpy.random.random(shape)
272
self.assertEquals(A.data, m.buffer)
273
B = numpy.frombuffer(m.buffer,dtype=A.dtype).reshape(A.shape)
255
if view.__name__ == 'buffer':
256
self.assertEquals(A.data, m.buffer)
257
B = numpy.frombuffer(m.buffer,dtype=A.dtype).reshape(A.shape)
259
self.assertEquals(memoryview(A), m.buffer)
260
B = numpy.array(m.buffer,dtype=A.dtype).reshape(A.shape)
274
261
self.assertEquals((A==B).all(), True)
276
263
def test_memoryview(self):
277
"""test messages from memoryview (only valid for python >= 2.7)"""
264
"""test messages from memoryview"""
278
265
major,minor = sys.version_info[:2]
279
266
if not (major >= 3 or (major == 2 and minor >= 7)):
267
raise SkipTest("memoryviews only in python >= 2.7")
282
s = 'carrotjuice'.encode()
283
270
v = memoryview(s)
286
273
s2 = buf.tobytes()
287
274
self.assertEquals(s2,s)
288
275
self.assertEquals(m.bytes,s)
277
def test_noncopying_recv(self):
278
"""check for clobbering message buffers"""
280
sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
283
sb.send(null, copy=False)
284
m = sa.recv(copy=False)
290
ff=b'\xff'*(40 + i*10)
291
sb.send(ff, copy=False)
292
m2 = sa.recv(copy=False)
293
if view.__name__ == 'buffer':
297
self.assertEquals(b, null)
298
self.assertEquals(mb, null)
299
self.assertEquals(m2.bytes, ff)
301
def test_buffer_numpy(self):
302
"""test non-copying numpy array messages"""
306
raise SkipTest("requires numpy")
307
if sys.version_info < (2,7):
308
raise SkipTest("requires new-style buffer interface (py >= 2.7)")
309
rand = numpy.random.randint
310
shapes = [ rand(2,5) for i in range(5) ]
311
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
312
dtypes = [int, float, '>i4', 'B']
313
for i in range(1,len(shapes)+1):
316
A = numpy.ndarray(shape, dtype=dt)
317
while not (A < 1e400).all():
318
# don't let nan sneak in
319
A = numpy.ndarray(shape, dtype=dt)
320
a.send(A, copy=False)
321
msg = b.recv(copy=False)
323
B = array_from_buffer(msg, A.dtype, A.shape)
324
self.assertEquals(A.shape, B.shape)
325
self.assertTrue((A==B).all())
326
A = numpy.ndarray(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
329
A['c'] = 'hello there'
330
a.send(A, copy=False)
331
msg = b.recv(copy=False)
333
B = array_from_buffer(msg, A.dtype, A.shape)
334
self.assertEquals(A.shape, B.shape)
335
self.assertTrue((A==B).all())
337
def test_frame_more(self):
338
"""test Frame.more attribute"""
339
frame = zmq.Frame(b"hello")
340
self.assertFalse(frame.more)
341
sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
342
sa.send_multipart([b'hi', b'there'])
343
frame = self.recv(sb, copy=False)
344
self.assertTrue(frame.more)
345
frame = self.recv(sb, copy=False)
346
self.assertFalse(frame.more)