~ubuntu-branches/ubuntu/raring/mayavi2/raring

« back to all changes in this revision

Viewing changes to mayavi/tools/old_mlab.py

  • Committer: Package Import Robot
  • Author(s): Varun Hiremath
  • Date: 2012-04-23 16:36:45 UTC
  • mfrom: (1.1.11) (2.2.6 sid)
  • Revision ID: package-import@ubuntu.com-20120423163645-wojak95rklqlbi1y
Tags: 4.1.0-1
* New upstream release
* Bump Standards-Version to 3.9.3

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
"""A simple wrapper around tvtk.tools.mlab suitable for MayaVi2!  This
2
 
is meant to be used from the embedded Python interpreter in MayaVi2 or
3
 
from IPython with the "-wthread" switch.
4
 
 
5
 
There are several test functions at the end of this file that are
6
 
illustrative to look at.
7
 
 
8
 
"""
9
 
 
10
 
# Author: Prabhu Ramachandran <prabhu_r@users.sf.net>
11
 
# Copyright (c) 2007, Enthought, Inc.
12
 
# License: BSD Style.
13
 
 
14
 
#TODO:  * Add optional scalars to plot3d
15
 
#       * Make streamline display colors by default
16
 
#       * Investigate why the old surf_regular seemed to give more beautiful
17
 
#         surfaces than the current surf. See for instance the difference
18
 
#         between test_surf_lattice and the old test_surf_regular
19
 
 
20
 
# Standard library imports.
21
 
import scipy
22
 
 
23
 
# Enthought library imports.
24
 
from envisage import get_application
25
 
from tvtk.api import tvtk
26
 
from tvtk.tools import mlab
27
 
from mayavi.modules.axes import Axes
28
 
from traits.api import HasTraits, Instance
29
 
from traitsui.api import View, Item, Group
30
 
 
31
 
# MayaVi related imports.
32
 
from mayavi.services import IMAYAVI
33
 
from mayavi.sources.vtk_data_source import VTKDataSource
34
 
from mayavi.filters.filter_base import FilterBase
35
 
from mayavi.modules.surface import Surface
36
 
from mayavi.modules.vectors import Vectors
37
 
from mayavi.modules.iso_surface import IsoSurface
38
 
from mayavi.modules.streamline import Streamline
39
 
from mayavi.modules.glyph import Glyph
40
 
from mayavi.modules.text import Text
41
 
from mayavi.app import Mayavi
42
 
from mayavi.core.source import Source
43
 
from mayavi.core.module import Module
44
 
from mayavi.core.module_manager import ModuleManager
45
 
from mayavi.sources.array_source import ArraySource
46
 
 
47
 
__all__ = ["scalarscatter", "vectorscatter", "scalarfield",
48
 
    "vectorfield", "isosurface", "vectors", "glyph", "streamline",
49
 
    "quiver3d", "points3d", "surf", "contour_surf", "imshow", "outline",
50
 
    "axes", "figure", "clf", "savefig", "xlabel", "ylabel", "zlabel",
51
 
    "title", "scalarbar", "vectorbar"]
52
 
 
53
 
 
54
 
######################################################################
55
 
# Application and mayavi instances.
56
 
 
57
 
application = get_application()
58
 
mayavi = None
59
 
if application is not None:
60
 
    mayavi = application.get_service(IMAYAVI)
61
 
 
62
 
 
63
 
######################################################################
64
 
# `ImageActor` class
65
 
 
66
 
# This should be added as a new MayaVi module.  It is here for testing
67
 
# and further improvements.
68
 
class ImageActor(Module):
69
 
 
70
 
    # An image actor.
71
 
    actor = Instance(tvtk.ImageActor, allow_none=False)
72
 
 
73
 
    view = View(Group(Item(name='actor', style='custom',
74
 
                           resizable=True),
75
 
                      show_labels=False),
76
 
                width=500,
77
 
                height=500,
78
 
                resizable=True)
79
 
 
80
 
    def setup_pipeline(self):
81
 
        self.actor = tvtk.ImageActor()
82
 
 
83
 
    def update_pipeline(self):
84
 
        """Override this method so that it *updates* the tvtk pipeline
85
 
        when data upstream is known to have changed.
86
 
        """
87
 
        mm = self.module_manager
88
 
        if mm is None:
89
 
            return
90
 
        src = mm.source
91
 
        self.actor.input = src.outputs[0]
92
 
        self.pipeline_changed = True
93
 
 
94
 
    def update_data(self):
95
 
        """Override this method so that it flushes the vtk pipeline if
96
 
        that is necessary.
97
 
        """
98
 
        # Just set data_changed, the component should do the rest.
99
 
        self.data_changed = True
100
 
 
101
 
    def _actor_changed(self, old, new):
102
 
        if old is not None:
103
 
            self.actors.remove(old)
104
 
        self.actors.append(new)
105
 
 
106
 
######################################################################
107
 
# Utility functions.
108
 
def _make_glyph_data(points, vectors=None, scalars=None):
109
 
    """Makes the data for glyphs using mlab.
110
 
    """
111
 
    g = mlab.Glyphs(points, vectors, scalars)
112
 
    return g.poly_data
113
 
 
114
 
def _make_default_figure():
115
 
    """Checks to see if a valid mayavi instance is running.  If not
116
 
    creates a new one.
117
 
    """
118
 
    global mayavi
119
 
    if mayavi is None or application.stopped is not None:
120
 
        fig = figure()
121
 
        mayavi = get_application().get_service(IMAYAVI)
122
 
    return mayavi
123
 
 
124
 
def _add_data(tvtk_data, name=''):
125
 
    """Add a TVTK data object `tvtk_data` to the mayavi pipleine.
126
 
    Give the object a name of `name`.
127
 
    """
128
 
    if isinstance(tvtk_data, tvtk.Object):
129
 
        d = VTKDataSource()
130
 
        d.data = tvtk_data
131
 
    elif isinstance(tvtk_data, Source):
132
 
        d = tvtk_data
133
 
    else:
134
 
        raise TypeError, \
135
 
              "first argument should be either a TVTK object"\
136
 
              " or a mayavi source"
137
 
 
138
 
    if len(name) > 0:
139
 
        d.name = name
140
 
    _make_default_figure()
141
 
    mayavi.add_source(d)
142
 
    return d
143
 
 
144
 
def _traverse(node):
145
 
    """Traverse a tree accessing the nodes children attribute.
146
 
    """
147
 
    try:
148
 
        for leaf in node.children:
149
 
            for leaflet in _traverse(leaf):
150
 
                yield leaflet
151
 
    except AttributeError:
152
 
        pass
153
 
    yield node
154
 
 
155
 
def _find_data(object):
156
 
    """Goes up the vtk pipeline to find the data sources of a given
157
 
    object.
158
 
    """
159
 
    if isinstance(object, ModuleManager):
160
 
        inputs = [object.source]
161
 
    elif hasattr(object, 'module_manager'):
162
 
        inputs = [object.module_manager.source]
163
 
    elif hasattr(object, 'data') or isinstance(object, ArraySource):
164
 
        inputs = [object]
165
 
    else:
166
 
        raise TypeError, 'Cannot find data source for given object'
167
 
    data_sources = []
168
 
    try:
169
 
        while True:
170
 
            input = inputs.pop()
171
 
            if hasattr(input, 'inputs'):
172
 
                inputs += input.inputs
173
 
            elif hasattr(input, 'image_data'):
174
 
                data_sources.append(input.image_data)
175
 
            else:
176
 
                data_sources.append(input.data)
177
 
    except IndexError:
178
 
        pass
179
 
    return data_sources
180
 
 
181
 
def _has_scalar_data(object):
182
 
    """Tests if an object has scalar data.
183
 
    """
184
 
    data_sources = _find_data(object)
185
 
    for source in data_sources:
186
 
        if source.point_data.scalars is not None:
187
 
            return True
188
 
        elif source.cell_data.scalars is not None:
189
 
            return True
190
 
    return False
191
 
 
192
 
def _has_vector_data(object):
193
 
    """Tests if an object has vector data.
194
 
    """
195
 
    data_sources = _find_data(object)
196
 
    for source in data_sources:
197
 
        if source.point_data.vectors is not None:
198
 
            return True
199
 
        elif source.cell_data.vectors is not None:
200
 
            return True
201
 
    return False
202
 
 
203
 
def _has_tensor_data(object):
204
 
    """Tests if an object has tensor data.
205
 
    """
206
 
    data_sources = _find_data(object)
207
 
    for source in data_sources:
208
 
        if source.point_data.tensors is not None:
209
 
            return True
210
 
        elif source.cell_data.tensors is not None:
211
 
            return True
212
 
    return False
213
 
 
214
 
def _find_module_manager(object=None, data_type=None):
215
 
    """If an object is specified, returns its module_manager, elsewhere finds
216
 
    the first module_manager in the scene.
217
 
    """
218
 
    if object is None:
219
 
        for object in _traverse(gcf()):
220
 
            if isinstance(object, ModuleManager):
221
 
                if ((data_type == 'scalar' and not _has_scalar_data(object))
222
 
                  or (data_type == 'vector' and not _has_vector_data(object))
223
 
                  or (data_type == 'tensor' and not _has_tensor_data(object))):
224
 
                    continue
225
 
                return object
226
 
        else:
227
 
            print("No object in the scene has a color map")
228
 
    else:
229
 
        if hasattr(object, 'module_manager'):
230
 
            if ((data_type == 'scalar' and _has_scalar_data(object))
231
 
               or (data_type == 'vector' and _has_vector_data(object))
232
 
               or (data_type == 'tensor' and _has_tensor_data(object))
233
 
                or data_type is None):
234
 
                return object.module_manager
235
 
            else:
236
 
                print("This object has no %s data" % data_type)
237
 
        else:
238
 
            print("This object has no color map")
239
 
    return None
240
 
 
241
 
def _orient_colorbar(colorbar, orientation):
242
 
    """Orients the given colorbar (make it horizontal or vertical).
243
 
    """
244
 
    if orientation == "vertical":
245
 
        colorbar.orientation = "vertical"
246
 
        colorbar.width = 0.1
247
 
        colorbar.height = 0.8
248
 
        colorbar.position = (0.01, 0.15)
249
 
    elif orientation == "horizontal":
250
 
        colorbar.orientation = "horizontal"
251
 
        colorbar.width = 0.8
252
 
        colorbar.height = 0.17
253
 
        colorbar.position = (0.1, 0.01)
254
 
    else:
255
 
        print "Unknown orientation"
256
 
    gcf().render()
257
 
 
258
 
def _typical_distance(data_obj):
259
 
    """ Returns a typical distance in a cloud of points.
260
 
        This is done by taking the size of the bounding box, and dividing it
261
 
        by the cubic root of the number of points.
262
 
    """
263
 
    x_min, x_max, y_min, y_max, z_min, z_max = data_obj.bounds
264
 
    distance = scipy.sqrt(((x_max-x_min)**2 + (y_max-y_min)**2 +
265
 
                           (z_max-z_min)**2)/(4*
266
 
                           data_obj.number_of_points**(0.33)))
267
 
    if distance == 0:
268
 
        return 1
269
 
    else:
270
 
        return 0.4*distance
271
 
 
272
 
######################################################################
273
 
# Data creation
274
 
 
275
 
def scalarscatter(*args, **kwargs):
276
 
    """
277
 
    Creates scattered scalar data.
278
 
 
279
 
    Function signatures
280
 
    -------------------
281
 
 
282
 
        scalarscatter(s, ...)
283
 
        scalarscatter(x, y, z, s, ...)
284
 
        scalarscatter(x, y, z, f, ...)
285
 
 
286
 
    If only 1 array s is passed the x, y and z arrays are assumed to be
287
 
    made from the indices of vectors.
288
 
 
289
 
    If 4 positional arguments are passed the last one must be an array s, or
290
 
    a callable, f, that returns an array.
291
 
 
292
 
    Arguments
293
 
    ---------
294
 
 
295
 
        x -- x coordinates of the points of the mesh (optional).
296
 
 
297
 
        y -- y coordinates of the points of the mesh (optional).
298
 
 
299
 
        z -- z coordinates of the points of the mesh (optional).
300
 
 
301
 
        s -- scalar value
302
 
 
303
 
        f -- callable that is used to build the scalar data (only if 4
304
 
             positional arguments are passed).
305
 
 
306
 
    Keyword arguments
307
 
    -----------------
308
 
 
309
 
        name -- The name of the vtk object created. Default: 'Scattered scalars'
310
 
 
311
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
312
 
                   Default is the x, y, z arrays extent.
313
 
 
314
 
    """
315
 
    if len(args)==1:
316
 
        s = args[0]
317
 
        x, y, z = scipy.indices(s.shape)
318
 
    elif len(args)==4:
319
 
        x, y, z, s = args
320
 
        if callable(s):
321
 
            s = f(x, y, z)
322
 
    else:
323
 
        raise ValueError, "wrong number of arguments"
324
 
 
325
 
    assert ( x.shape == y.shape and
326
 
             y.shape == z.shape and
327
 
             s.shape == z.shape ), "argument shape are not equal"
328
 
 
329
 
    if 'extent' in kwargs:
330
 
        xmin, xmax, ymin, ymax, zmin, zmax = kwargs.pop('extent')
331
 
        x = xmin + x*(xmax - xmin)/float(x.max() - x.min()) -x.min()
332
 
        y = ymin + y*(ymax - ymin)/float(y.max() - y.min()) -y.min()
333
 
        z = zmin + z*(zmax - zmin)/float(z.max() - z.min()) -z.min()
334
 
 
335
 
    points = scipy.c_[x.ravel(), y.ravel(), z.ravel()]
336
 
    scalars = s.ravel()
337
 
    name = kwargs.pop('name', 'Scattered scalars')
338
 
 
339
 
    data = _make_glyph_data(points, None, scalars)
340
 
    data_obj = _add_data(data, name)
341
 
    return data_obj
342
 
 
343
 
def vectorscatter(*args, **kwargs):
344
 
    """
345
 
    Creates scattered vector data.
346
 
 
347
 
    Function signatures
348
 
    -------------------
349
 
 
350
 
        vectorscatter(u, v, w, ...)
351
 
        vectorscatter(x, y, z, u, v, w, ...)
352
 
        vectorscatter(x, y, z, f, ...)
353
 
 
354
 
    If only 3 arrays u, v, w are passed the x, y and z arrays are assumed to be
355
 
    made from the indices of vectors.
356
 
 
357
 
    If 4 positional arguments are passed the last one must be a callable, f,
358
 
    that returns vectors.
359
 
 
360
 
    Arguments
361
 
    ---------
362
 
 
363
 
        x -- x coordinates of the points of the mesh (optional).
364
 
 
365
 
        y -- y coordinates of the points of the mesh (optional).
366
 
 
367
 
        z -- z coordinates of the points of the mesh (optional).
368
 
 
369
 
        u -- x coordinnate of the vector field
370
 
 
371
 
        v -- y coordinnate of the vector field
372
 
 
373
 
        w -- z coordinnate of the vector field
374
 
 
375
 
        f -- callable that is used to build the vector field (only if 4
376
 
             positional arguments are passed).
377
 
 
378
 
    Keyword arguments
379
 
    -----------------
380
 
 
381
 
        name -- The name of the vtk object created. Default: 'Scattered vector'
382
 
 
383
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
384
 
                   Default is the x, y, z arrays extent.
385
 
 
386
 
        scalars -- The scalars associated to the vectors. Defaults to none.
387
 
 
388
 
    """
389
 
    if len(args)==3:
390
 
        u, v, w = args
391
 
        x, y, z = scipy.indices(u.shape)
392
 
    elif len(args)==6:
393
 
        x, y, z, u, v, w = args
394
 
    elif len(args)==4:
395
 
        x, y, z, f = args
396
 
        assert callable(f), "when used with 4 arguments last argument must be callable"
397
 
        u, v, w = f(x, y, z)
398
 
    else:
399
 
        raise ValueError, "wrong number of arguments"
400
 
 
401
 
    assert ( x.shape == y.shape and
402
 
             y.shape == z.shape and
403
 
             z.shape == u.shape and
404
 
             u.shape == v.shape and
405
 
             v.shape == w.shape ), "argument shape are not equal"
406
 
 
407
 
    if 'extent' in kwargs:
408
 
        xmin, xmax, ymin, ymax, zmin, zmax = kwargs.pop('extent')
409
 
        x = xmin + x*(xmax - xmin)/float(x.max() - x.min()) -x.min()
410
 
        y = ymin + y*(ymax - ymin)/float(y.max() - y.min()) -y.min()
411
 
        z = zmin + z*(zmax - zmin)/float(z.max() - z.min()) -z.min()
412
 
 
413
 
    points = scipy.c_[x.ravel(), y.ravel(), z.ravel()]
414
 
    vectors = scipy.c_[u.ravel(), v.ravel(), w.ravel()]
415
 
    if 'scalars' in kwargs:
416
 
        scalars = kwargs['scalars'].ravel()
417
 
    else:
418
 
        scalars = None
419
 
    name = kwargs.pop('name', 'Scattered vectors')
420
 
 
421
 
    data = _make_glyph_data(points, vectors, scalars)
422
 
    data_obj = _add_data(data, name)
423
 
    return data_obj
424
 
 
425
 
def scalarfield(*args, **kwargs):
426
 
    """
427
 
    Creates a scalar field data.
428
 
 
429
 
    Function signatures
430
 
    -------------------
431
 
 
432
 
        scalarfield(s, ...)
433
 
        scalarfield(x, y, z, s, ...)
434
 
        scalarfield(x, y, z, f, ...)
435
 
 
436
 
    If only 1 array s is passed the x, y and z arrays are assumed to
437
 
    be made from the indices of the s array.
438
 
 
439
 
    If the x, y and z arrays are passed they are supposed to have been
440
 
    generated by  `numpy.mgrid`.  The function builds a scalar field assuming
441
 
    the points are regularly spaced.
442
 
 
443
 
    Arguments
444
 
    ---------
445
 
 
446
 
        x -- x coordinates of the points of the mesh (optional).
447
 
 
448
 
        y -- y coordinates of the points of the mesh (optional).
449
 
 
450
 
        z -- z coordinates of the points of the mesh (optional).
451
 
 
452
 
        s -- scalar values.
453
 
 
454
 
        f -- callable that is used to build the scalar field (only if 4
455
 
             positional arguments are passed).
456
 
 
457
 
    Keyword arguments
458
 
    -----------------
459
 
 
460
 
        name -- The name of the vtk object created. Default: 'Scalar field'
461
 
 
462
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
463
 
                   Default is the x, y, z arrays extent.
464
 
    """
465
 
 
466
 
    # Get the keyword args.
467
 
    name = kwargs.get('name', 'Scalar field')
468
 
 
469
 
    if len(args)==1:
470
 
        s = args[0]
471
 
        x, y, z = scipy.indices(s.shape)
472
 
    elif len(args)==4:
473
 
        x, y, z, s = args
474
 
        if callable(s):
475
 
            s = f(x, y, z)
476
 
    else:
477
 
        raise ValueError, "wrong number of arguments"
478
 
 
479
 
    assert ( x.shape == y.shape and
480
 
             y.shape == z.shape and
481
 
             s.shape == z.shape ), "argument shape are not equal"
482
 
 
483
 
    if 'extent' in kwargs:
484
 
        xmin, xmax, ymin, ymax, zmin, zmax = kwargs.pop('extent')
485
 
        x = xmin + x*(xmax - xmin)/float(x.max() - x.min()) -x.min()
486
 
        y = ymin + y*(ymax - ymin)/float(y.max() - y.min()) -y.min()
487
 
        z = zmin + z*(zmax - zmin)/float(z.max() - z.min()) -z.min()
488
 
 
489
 
    points = scipy.c_[x.ravel(), y.ravel(), z.ravel()]
490
 
    dx = x[1, 0, 0] - x[0, 0, 0]
491
 
    dy = y[0, 1, 0] - y[0, 0, 0]
492
 
    dz = z[0, 0, 1] - z[0, 0, 0]
493
 
 
494
 
    data = ArraySource(scalar_data=s,
495
 
                      origin=[points[0, 0], points[0, 1], points[0, 2]],
496
 
                      spacing=[dx, dy, dz])
497
 
    data_obj = _add_data(data, name)
498
 
    return data_obj
499
 
 
500
 
def vectorfield(*args, **kwargs):
501
 
    """
502
 
    Creates a vector field data.
503
 
 
504
 
    Function signatures
505
 
    -------------------
506
 
 
507
 
        vectorfield(u, v, w, ...)
508
 
        vectorfield(x, y, z, u, v, w, ...)
509
 
        vectorfield(x, y, z, f, ...)
510
 
 
511
 
    If only 3 arrays u, v, w are passed the x, y and z arrays are assumed to
512
 
    be made from the indices of the u, v, w arrays.
513
 
 
514
 
    If the x, y and z arrays are passed they are supposed to have been
515
 
    generated by  `numpy.mgrid`.  The function builds a vector field assuming
516
 
    the points are regularly spaced.
517
 
 
518
 
    Arguments
519
 
    ---------
520
 
 
521
 
        x -- x coordinates of the points of the mesh (optional).
522
 
 
523
 
        y -- y coordinates of the points of the mesh (optional).
524
 
 
525
 
        z -- z coordinates of the points of the mesh (optional).
526
 
 
527
 
        u -- x coordinnate of the vector field
528
 
 
529
 
        v -- y coordinnate of the vector field
530
 
 
531
 
        w -- z coordinnate of the vector field
532
 
 
533
 
        f -- callable that is used to build the vector field (only if 4
534
 
             positional arguments are passed).
535
 
 
536
 
    Keyword arguments
537
 
    -----------------
538
 
 
539
 
        name -- The name of the vtk object created. Default: 'Vector field'
540
 
 
541
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
542
 
                   Default is the x, y, z arrays extent.
543
 
 
544
 
        scalars -- The scalars associated to the vectors. Defaults to none.
545
 
 
546
 
        transpose_vectors -- If the additional argument
547
 
                             transpose_vectors is passed, then the
548
 
                             input vectors array is suitably
549
 
                             transposed.  By default transpose_vectors
550
 
                             is True so that the array is in the
551
 
                             correct format that VTK expects.
552
 
                             However, a transposed array is not
553
 
                             contiguous and thus a copy is made, this
554
 
                             also means that any changes to the users
555
 
                             input array will will not be reflected in
556
 
                             the renderered object (provided you know
557
 
                             how to do this).  Thus, sometimes users
558
 
                             might want to provide already transposed
559
 
                             data suitably formatted.  In these cases
560
 
                             one should set transpose_vectors to
561
 
                             False.
562
 
                             Default value: True
563
 
    """
564
 
    # Get the keyword args.
565
 
    transpose_vectors = kwargs.get('transpose_vectors', True)
566
 
 
567
 
    if len(args)==3:
568
 
        u, v, w = args
569
 
        x, y, z = scipy.indices(u.shape)
570
 
    elif len(args)==6:
571
 
        x, y, z, u, v, w = args
572
 
    elif len(args)==4:
573
 
        x, y, z, f = args
574
 
        assert callable(f), "when used with 4 arguments last argument must be callable"
575
 
        u, v, w = f(x, y, z)
576
 
    else:
577
 
        raise ValueError, "wrong number of arguments"
578
 
 
579
 
    assert ( x.shape == y.shape and
580
 
             y.shape == z.shape and
581
 
             z.shape == u.shape and
582
 
             u.shape == v.shape and
583
 
             v.shape == w.shape ), "argument shape are not equal"
584
 
 
585
 
    if 'extent' in kwargs:
586
 
        xmin, xmax, ymin, ymax, zmin, zmax = kwargs.pop('extent')
587
 
        x = xmin + x*(xmax - xmin)/float(x.max() - x.min()) -x.min()
588
 
        y = ymin + y*(ymax - ymin)/float(y.max() - y.min()) -y.min()
589
 
        z = zmin + z*(zmax - zmin)/float(z.max() - z.min()) -z.min()
590
 
 
591
 
    points = scipy.c_[x.ravel(), y.ravel(), z.ravel()]
592
 
    vectors = scipy.concatenate([u[..., scipy.newaxis],
593
 
                                 v[..., scipy.newaxis],
594
 
                                 w[..., scipy.newaxis] ],
595
 
                                 axis=3)
596
 
    if 'scalars' in kwargs:
597
 
        scalars = kwargs['scalars']
598
 
    else:
599
 
        scalars = None
600
 
    name = kwargs.pop('name', 'Vector field')
601
 
    dx = x[1, 0, 0] - x[0, 0, 0]
602
 
    dy = y[0, 1, 0] - y[0, 0, 0]
603
 
    dz = z[0, 0, 1] - z[0, 0, 0]
604
 
 
605
 
    if not transpose_vectors:
606
 
        vectors.shape = vectors.shape[::-1]
607
 
    data = ArraySource(transpose_input_array=transpose_vectors,
608
 
                      vector_data=vectors,
609
 
                      scalar_data=scalars,
610
 
                      origin=[points[0, 0], points[0, 1], points[0, 2]],
611
 
                      spacing=[dx, dy, dz])
612
 
    data_obj = _add_data(data, name)
613
 
    return data_obj
614
 
 
615
 
######################################################################
616
 
# Module creation
617
 
 
618
 
def isosurface(data_obj, name='IsoSurface', transparent=True,
619
 
                    contours=5):
620
 
    """ Applies the Iso-Surface mayavi module to the given VTK data object.
621
 
    """
622
 
    iso = IsoSurface()
623
 
 
624
 
    # Check what type the 'contours' are and do whatever is needed.
625
 
    contour_list = True
626
 
    try:
627
 
        len(contours)
628
 
    except TypeError:
629
 
        contour_list = False
630
 
 
631
 
    if contour_list:
632
 
        iso.contour.contours = contours
633
 
    else:
634
 
        assert type(contours) == int, "The contours argument must be an integer"
635
 
        assert contours > 1, "The contours argument must be positivee"
636
 
        iso.contour.set(auto_contours=True,
637
 
                            number_of_contours=contours)
638
 
 
639
 
    mayavi.add_module(iso, obj=data_obj)
640
 
 
641
 
    if transparent:
642
 
        data_range = iso.module_manager.scalar_lut_manager.data_range
643
 
        iso.module_manager.scalar_lut_manager.lut.alpha_range = \
644
 
                (0.2, 0.8)
645
 
        data_range = ( scipy.mean(data_range)
646
 
                        + 0.4 * ( data_range.max() - data_range.min())
647
 
                               * scipy.array([-1, 1]))
648
 
        iso.scene.render()
649
 
 
650
 
    return iso
651
 
 
652
 
def vectors(data_obj, color=None,  name='Vectors', mode='2d',
653
 
                scale_factor=1.):
654
 
    """ Applies the Vectors mayavi module to the given VTK data object.
655
 
    """
656
 
    v = Vectors(name=name)
657
 
    mayavi.add_module(v, obj=data_obj)
658
 
    mode_map = {'2d': 0, 'arrow': 1, 'cone': 2, 'cylinder': 3,
659
 
                'sphere': 4, 'cube': 5, 'point': 6}
660
 
    if mode == 'point':
661
 
        v.glyph.glyph_source = tvtk.PointSource(radius=0,
662
 
                                                number_of_points=1)
663
 
    else:
664
 
         v.glyph.glyph_source = v.glyph.glyph_list[mode_map[mode]]
665
 
    if color:
666
 
        v.glyph.color_mode = 'no_coloring'
667
 
        v.actor.property.color = color
668
 
    elif _has_scalar_data(data_obj) :
669
 
        v.glyph.color_mode = 'color_by_scalar'
670
 
    else:
671
 
        v.glyph.color_mode = 'color_by_vector'
672
 
    v.glyph.glyph.scale_factor = scale_factor
673
 
    return v
674
 
 
675
 
def glyph(data_obj, color=None, name='Glyph', mode='sphere',
676
 
            scale_factor=1.):
677
 
    """ Applies the Glyph mayavi module to the given VTK data object.
678
 
    """
679
 
    g = Glyph(name=name)
680
 
    mayavi.add_module(g, obj=data_obj)
681
 
    mode_map = {'2d': 0, 'arrow': 1, 'cone': 2, 'cylinder': 3,
682
 
                'sphere': 4, 'cube': 5, 'point': 6}
683
 
    if mode == 'point':
684
 
        g.glyph.glyph_source = tvtk.PointSource(radius=0,
685
 
                                                number_of_points=1)
686
 
    else:
687
 
         g.glyph.glyph_source = g.glyph.glyph_list[mode_map[mode]]
688
 
    if color:
689
 
        g.actor.property.color = color
690
 
    if _has_scalar_data(data_obj) :
691
 
        g.glyph.color_mode = 'color_by_scalar'
692
 
        g.glyph.scale_mode = 'scale_by_scalar'
693
 
    g.glyph.glyph.scale_factor = scale_factor
694
 
    return g
695
 
 
696
 
#FIXME : just started this procedure !! Need to modify the color so that if
697
 
# none it warps a scalar. Need to add a kwarg for the source.
698
 
def streamline(data_obj, color=None,  name='Streamline', ):
699
 
    """ Applies the Streamline mayavi module to the given VTK data object.
700
 
    """
701
 
    st = Streamline()
702
 
    mayavi.add_module(st, obj=data_obj)
703
 
    if color:
704
 
        st.actor.property.color = color
705
 
    elif _has_scalar_data(data_obj) :
706
 
        st.actor.mapper.scalar_visibility = True
707
 
    else:
708
 
        st.actor.mapper.interpolate_scalars_before_mapping = True
709
 
        st.actor.mapper.scalar_visibility = True
710
 
    return st
711
 
 
712
 
######################################################################
713
 
# Helper functions
714
 
 
715
 
def quiver3d(*args, **kwargs):
716
 
    """
717
 
    Plots glyphs (like arrows) indicating the direction of the vectors
718
 
    for a 3D volume of data supplied as arguments.
719
 
 
720
 
    Function signatures
721
 
    -------------------
722
 
 
723
 
        quiver3d(vectordata, ...)
724
 
        quiver3d(u, v, w, ...)
725
 
        quiver3d(x, y, z, u, v, w, ...)
726
 
        quiver3d(x, y, z, f, ...)
727
 
 
728
 
    If only one positional argument is passed, it should be VTK data
729
 
    object with vector data.
730
 
 
731
 
    If only 3 arrays u, v, w are passed the x, y and z arrays are assumed to be
732
 
    made from the indices of vectors.
733
 
 
734
 
    If 4 positional arguments are passed the last one must be a callable, f,
735
 
    that returns vectors.
736
 
 
737
 
    Arguments
738
 
    ---------
739
 
 
740
 
        vectordata -- VTK data object with vector data, such as created
741
 
                      by vectorscatter of vectorfield.
742
 
 
743
 
        x -- x coordinates of the points of the mesh (optional).
744
 
 
745
 
        y -- y coordinates of the points of the mesh (optional).
746
 
 
747
 
        z -- z coordinates of the points of the mesh (optional).
748
 
 
749
 
        u -- x coordinnate of the vector field
750
 
 
751
 
        v -- y coordinnate of the vector field
752
 
 
753
 
        w -- z coordinnate of the vector field
754
 
 
755
 
        f -- callable that is used to build the vector field (only if 4
756
 
             positional arguments are passed).
757
 
 
758
 
    Keyword arguments
759
 
    -----------------
760
 
 
761
 
        name -- The name of the vtk object created. Default: 'Quiver3D'
762
 
 
763
 
        mode -- This should be one of ['2d' (2d arrows),
764
 
                'arrow', 'cone', 'cylinder', 'sphere', 'cube',
765
 
                'point'] and depending on what is passed shows an
766
 
                appropriate glyph.  It defaults to a 3d arrow.
767
 
 
768
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
769
 
                   Default is the x, y, z arrays extent.
770
 
 
771
 
        scalars -- The scalars used to display the color of the glyphs.
772
 
                  Defaults to the norm of the vectors.
773
 
 
774
 
        color -- The color of the glyphs in the absence of scalars.
775
 
                 Default: (1., 0., 0.)
776
 
 
777
 
        autoscale -- Autoscale the size of the glyph.
778
 
                     Default: True
779
 
 
780
 
        scale_factor -- Default 1
781
 
    """
782
 
    if len(args)==1:
783
 
        data_obj = args[0]
784
 
    else:
785
 
        data_kwargs = kwargs.copy()
786
 
        data_kwargs.pop('name','')
787
 
        if len(args)==6:
788
 
            data_obj = vectorscatter(*args, **data_kwargs)
789
 
        else:
790
 
            data_obj = vectorfield(*args, **data_kwargs)
791
 
 
792
 
    if not 'name' in kwargs:
793
 
            kwargs['name'] = 'Quiver3D'
794
 
 
795
 
    if not 'mode' in kwargs:
796
 
            kwargs['mode'] = 'arrow'
797
 
 
798
 
    if not 'autoscale' in kwargs or kwargs['autoscale']:
799
 
        scale_factor = kwargs.get('scale_factor', 1.)
800
 
        kwargs['scale_factor'] = (scale_factor *
801
 
                        _typical_distance(_find_data(data_obj)[0]) )
802
 
    kwargs.pop('autoscale', '')
803
 
 
804
 
    return vectors(data_obj, **kwargs)
805
 
 
806
 
def points3d(*args, **kwargs):
807
 
    """
808
 
    Plots glyphs (like points) at the position of the supplied data.
809
 
 
810
 
    Function signatures
811
 
    -------------------
812
 
 
813
 
        points3d(scalardata, ...)
814
 
        points3d(x, y, z...)
815
 
        points3d(x, y, z, s, ...)
816
 
        points3d(x, y, z, f, ...)
817
 
 
818
 
    If only one positional argument is passed, it should be VTK data
819
 
    object with scalar data.
820
 
 
821
 
    If only 3 arrays x, y, z all the points are drawn with the same size
822
 
    and color
823
 
 
824
 
    If 4 positional arguments are passed the last one can be an array s
825
 
    or a callable f that gives the size and color of the glyph.
826
 
 
827
 
    Arguments
828
 
    ---------
829
 
 
830
 
        scalardata -- VTK data object with scalar data, such as created
831
 
                      by scalarscatter.
832
 
 
833
 
        x -- x coordinates of the points.
834
 
 
835
 
        y -- y coordinates of the points.
836
 
 
837
 
        z -- z coordinates of the points.
838
 
 
839
 
        s -- array giving the color and size of the glyphs (optional).
840
 
 
841
 
        f -- callable that returns the scalar associated with the points
842
 
             as a function of position.
843
 
 
844
 
    Keyword arguments
845
 
    -----------------
846
 
 
847
 
        name -- The name of the vtk object created. Default: 'Points3D'
848
 
 
849
 
        mode -- This should be one of ['2d' (2d arrows),
850
 
                'arrow', 'cone', 'cylinder', 'sphere', 'cube',
851
 
                'point'] and depending on what is passed shows an
852
 
                appropriate glyph.  It defaults to a sphere.
853
 
 
854
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
855
 
                   Default is the x, y, z arrays extent.
856
 
 
857
 
        scalars -- The scalars used to display the color of the glyphs.
858
 
 
859
 
        color -- The color of the glyphs. Overrides the scalar array.
860
 
                 Default: (1., 0., 0.).
861
 
 
862
 
        autoscale -- Autoscale the size of the glyph.
863
 
                     Default: True
864
 
 
865
 
        scale_factor -- Default 1
866
 
    """
867
 
    if len(args)==1:
868
 
        data_obj = args[0]
869
 
    else:
870
 
        data_kwargs = kwargs.copy()
871
 
        data_kwargs.pop('name','')
872
 
        if len(args)==4:
873
 
            x, y, z, s = args
874
 
            if callable(s):
875
 
                s = s(x, y, z)
876
 
        else:
877
 
            x, y, z = args
878
 
            s = scipy.ones(x.shape)
879
 
        data_obj = scalarscatter(x, y, z, s, **data_kwargs)
880
 
 
881
 
    if not 'name' in kwargs:
882
 
            kwargs['name'] = 'Points3D'
883
 
 
884
 
    if not 'mode' in kwargs:
885
 
            kwargs['mode'] = 'sphere'
886
 
 
887
 
    if not 'autoscale' in kwargs or kwargs['autoscale']:
888
 
        scale_factor = kwargs.get('scale_factor', 1.)
889
 
        kwargs['scale_factor'] = (0.6* scale_factor *
890
 
                        _typical_distance(_find_data(data_obj)[0]) )
891
 
    kwargs.pop('autoscale', '')
892
 
 
893
 
    g = glyph(data_obj, **kwargs)
894
 
    if len(args)==3:
895
 
        g.glyph.scale_mode = 'data_scaling_off'
896
 
    if 'color' in kwargs:
897
 
        g.glyph.color_mode = 'no_coloring'
898
 
    return g
899
 
 
900
 
def contour3d(*args, **kwargs):
901
 
    """
902
 
    Plots iso-surfaces for a 3D volume of data suplied as arguments.
903
 
 
904
 
    Function signatures
905
 
    -------------------
906
 
 
907
 
        contour3d(scalars, ...)
908
 
        contour3d(scalarfield, ...)
909
 
 
910
 
    Arguments
911
 
    ---------
912
 
 
913
 
        scalars -- A 3D array giving the value of the scalar on a grid.
914
 
 
915
 
        scalarfield -- VTK data object with scalar field data, such as
916
 
                       created by scalarfield.
917
 
 
918
 
 
919
 
    Keyword arguments
920
 
    -----------------
921
 
 
922
 
        name -- The name of the vtk object created. Default: 'Contour3D'
923
 
 
924
 
        transpose_scalars -- If the additional argument
925
 
                             transpose_scalars is passed, then the
926
 
                             input scalar array is suitably
927
 
                             transposed.  By default transpose_scalars
928
 
                             is True so that the array is in the
929
 
                             correct format that VTK expects.
930
 
                             However, a transposed array is not
931
 
                             contiguous and thus a copy is made, this
932
 
                             also means that any changes to the users
933
 
                             input array will will not be reflected in
934
 
                             the renderered object (provided you know
935
 
                             how to do this).  Thus, sometimes users
936
 
                             might want to provide already transposed
937
 
                             data suitably formatted.  In these cases
938
 
                             one should set transpose_scalars to
939
 
                             False.
940
 
                             Default value: True
941
 
 
942
 
        contours -- Integer/list specifying number/list of
943
 
                    iso-surfaces. Specifying 0 shows no contours.
944
 
                    Specifying a list of values will only give the
945
 
                    requested contours asked for.  Default: 3
946
 
 
947
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
948
 
                   Default is the shape of the scalars
949
 
 
950
 
        transparent -- Boolean to make the opacity of the isosurfaces depend
951
 
                       on the scalar.  Default: True
952
 
    """
953
 
    if len(args)==1:
954
 
        if hasattr(args[0], 'shape'):
955
 
            scalars = args[0]
956
 
            assert len(scalars.shape) == 3, 'Only 3D arrays are supported.'
957
 
            data_kwargs = kwargs.copy()
958
 
            data_kwargs.pop('contours', '')
959
 
            data_kwargs.pop('transparent', '')
960
 
            if not 'name' in kwargs:
961
 
                data_kwargs['name'] = 'Contour3D'
962
 
            data_obj = scalarfield(scalars, **data_kwargs)
963
 
        else:
964
 
            data_obj = args[0]
965
 
    else:
966
 
        raise TypeError, "contour3d takes only one argument"
967
 
 
968
 
    # Remove extra kwargs that are not needed for the iso-surface.
969
 
    kwargs.pop('extent', '')
970
 
    kwargs.pop('name', '')
971
 
 
972
 
    return isosurface(data_obj, **kwargs)
973
 
 
974
 
######################################################################
975
 
# The mlab functionality.
976
 
 
977
 
def plot3d(x, y, z, radius=0.01, use_tubes=True, color=(1., 0., 0.) ,
978
 
          name='Plot3D'):
979
 
    """Draws lines between points.
980
 
 
981
 
    Arguments
982
 
    ---------
983
 
 
984
 
    x -- x coordinates of the points of the line
985
 
 
986
 
    y -- y coordinates of the points of the line
987
 
 
988
 
    z -- z coordinates of the points of the line
989
 
 
990
 
    Keyword arguments
991
 
    -----------------
992
 
 
993
 
    color -- color of the line. Default: (1., 0., 0.)
994
 
 
995
 
    use_tubes -- Enables the drawing of the lines as tubes. Default: True
996
 
 
997
 
    radius -- radius of the tubes created. Default: 0.01
998
 
 
999
 
    name -- The name of the vtk object created. Default: 'Plot3D'
1000
 
 
1001
 
    """
1002
 
    assert ( x.shape == y.shape and
1003
 
             y.shape == z.shape ), "argument shape are not equal"
1004
 
 
1005
 
    points = scipy.c_[x, y, z]
1006
 
    np = len(points) - 1
1007
 
    lines = scipy.zeros((np, 2), 'l')
1008
 
    lines[:,0] = scipy.arange(0, np-0.5, 1, 'l')
1009
 
    lines[:,1] = scipy.arange(1, np+0.5, 1, 'l')
1010
 
    pd = tvtk.PolyData(points=points, lines=lines)
1011
 
    _add_data(pd, name)
1012
 
    if use_tubes:
1013
 
        filter = tvtk.TubeFilter(number_of_sides=6)
1014
 
        filter.radius = radius
1015
 
        f = FilterBase(filter=filter, name='TubeFilter')
1016
 
        mayavi.add_filter(f)
1017
 
    s = Surface()
1018
 
    s.actor.mapper.scalar_visibility = False
1019
 
    mayavi.add_module(s)
1020
 
    s.actor.property.color = color
1021
 
    return s
1022
 
 
1023
 
def surf(*args, **kwargs):
1024
 
    """
1025
 
    Plots a surface using grid spaced data supplied as 2D arrays.
1026
 
 
1027
 
    Function signatures
1028
 
    -------------------
1029
 
 
1030
 
        surf(z, scalars=z, ...)
1031
 
        surf(x, y, z, scalars=z, ...)
1032
 
 
1033
 
    If only one array z is passed the x and y arrays are assumed to be made
1034
 
    of the indices of z.
1035
 
    z is the elevation matrix. If no `scalars` argument is passed the color
1036
 
    of the surface also represents the z matrix. Setting the `scalars` argument
1037
 
    to None prevents this.
1038
 
 
1039
 
    Arguments
1040
 
    ---------
1041
 
 
1042
 
        x -- x coordinates of the points of the mesh (optional).
1043
 
 
1044
 
        y -- y coordinates of the points of the mesh (optional).
1045
 
 
1046
 
        z -- A 2D array giving the elevation of the mesh.
1047
 
             Also will work if z is a callable which supports x and y arrays
1048
 
             as the arguments, but x and y must then be supplied.
1049
 
 
1050
 
    Keyword arguments
1051
 
    -----------------
1052
 
 
1053
 
        extent --  [xmin, xmax, ymin, ymax, zmin, zmax]
1054
 
                   Default is the x, y, z arrays extent.
1055
 
 
1056
 
        scalars -- An array of the same shape as z that gives the color of the
1057
 
                   surface. This can also be a function that takes x and
1058
 
                   y as arguments.
1059
 
 
1060
 
        represention -- can be 'surface', 'wireframe', 'points', or 'mesh'
1061
 
                        Default is 'surface'
1062
 
 
1063
 
        color -- The color of the mesh in the absence of scalars.
1064
 
 
1065
 
        name -- The name of the vtk object created. Default is "Surf"
1066
 
    """
1067
 
    if len(args)==1:
1068
 
        z = args[0]
1069
 
        x, y = scipy.indices(z.shape)
1070
 
    else:
1071
 
        x, y, z = args
1072
 
    if callable(z):
1073
 
        z = z(x, y)
1074
 
    if not 'scalars' in kwargs:
1075
 
        kwargs['scalars'] = z
1076
 
    if callable(kwargs['scalars']):
1077
 
        kwargs['scalars'] = kwargs['scalars'](x, y)
1078
 
    if 'color' in kwargs and kwargs['color']:
1079
 
        kwargs['scalar_visibility'] = False
1080
 
    if 'extent' in kwargs:
1081
 
        xmin, xmax, ymin, ymax, zmin, zmax = kwargs.pop('extent')
1082
 
        x = xmin + x*(xmax - xmin)/float(x.max() - x.min()) -x.min()
1083
 
        y = ymin + y*(ymax - ymin)/float(y.max() - y.min()) -y.min()
1084
 
        z = zmin + z*(zmax - zmin)/float(z.max() - z.min()) -z.min()
1085
 
    return _surf(x, y, z, **kwargs)
1086
 
 
1087
 
def _surf(x, y, z, scalars=None, scalar_visibility=True,
1088
 
          color=(0.5, 1.0, 0.5), name='Surf', representation='surface',
1089
 
          tube_radius=0.05, sphere_radius=0.05, ):
1090
 
    """ Functions that does the work for "surf". It is called with the right
1091
 
        number of arguments after the "surf" fonction does the magic to
1092
 
        translate the user-supplied arguments into something this function
1093
 
        understands. """
1094
 
    triangles, points = mlab.make_triangles_points(x, y, z, scalars)
1095
 
    data = mlab.make_triangle_polydata(triangles, points, scalars)
1096
 
    _add_data(data, name)
1097
 
 
1098
 
    if representation == 'mesh':
1099
 
        # Extract the edges.
1100
 
        ef = tvtk.ExtractEdges()
1101
 
        extract_edges = FilterBase(filter=ef, name='ExtractEdges')
1102
 
        mayavi.add_filter(extract_edges)
1103
 
 
1104
 
        # Now show the lines with tubes.
1105
 
        tf = tvtk.TubeFilter(radius=tube_radius)
1106
 
        tube = FilterBase(filter=tf, name='TubeFilter')
1107
 
        mayavi.add_filter(tube)
1108
 
        s = Surface(name='Tubes')
1109
 
        s.actor.mapper.scalar_visibility = scalar_visibility
1110
 
        mayavi.add_module(s)
1111
 
        s.actor.property.color = color
1112
 
 
1113
 
        # Show the points with glyphs.
1114
 
        g = Glyph(name='Spheres')
1115
 
        g.glyph.glyph_source = g.glyph.glyph_list[4]
1116
 
        g.glyph.glyph_source.radius = sphere_radius
1117
 
        extract_edges.add_child(g)
1118
 
        g.glyph.scale_mode = 'data_scaling_off'
1119
 
        g.actor.mapper.scalar_visibility=scalar_visibility
1120
 
        g.actor.property.color = color
1121
 
        return s, g
1122
 
 
1123
 
    s = Surface()
1124
 
    s.actor.mapper.scalar_visibility = scalar_visibility
1125
 
    mayavi.add_module(s)
1126
 
    s.actor.property.color = color
1127
 
    s.actor.property.representation = representation
1128
 
    return s
1129
 
 
1130
 
def contour_surf(*args, **kwargs):
1131
 
    """ Plots the contours of a surface using grid spaced data supplied as
1132
 
    2D arrays.
1133
 
 
1134
 
    Function signatures::
1135
 
 
1136
 
        contour_surf(z, scalars=z, ...)
1137
 
        contour_surf(surf_object, ...)
1138
 
        contour_surf(x, y, z, scalars=z, ...)
1139
 
 
1140
 
    If only one array z is passed the x and y arrays are assumed to be made
1141
 
    of the indices of z.
1142
 
    z is the elevation matrix. If no `scalars` argument is passed the
1143
 
    contours are contour lines of the elevation, elsewhere they are
1144
 
    contour lines of the scalar array.
1145
 
    A surf object can also be passed in which case the function adds contours
1146
 
    to the existing surf.
1147
 
 
1148
 
    Arguments
1149
 
    ---------
1150
 
 
1151
 
    x
1152
 
        x coordinates of the points of the mesh (optional).
1153
 
    y
1154
 
        y coordinates of the points of the mesh (optional).
1155
 
    z
1156
 
        A 2D array giving the elevation of the mesh.
1157
 
        Also will work if z is a callable which supports x and y arrays
1158
 
        as the arguments, but x and y must then be supplied.
1159
 
 
1160
 
    Keyword arguments
1161
 
    -----------------
1162
 
 
1163
 
    extent :  [xmin, xmax, ymin, ymax, zmin, zmax]
1164
 
        Default is the x, y, z arrays extent.
1165
 
    contours
1166
 
        Integer/list specifying number/list of
1167
 
        iso-surfaces. Specifying 0 shows no contours.
1168
 
        Specifying a list of values will only give the
1169
 
        requested contours asked for.  Default: 10
1170
 
    scalars
1171
 
        An array of the same shape as z that gives the scalar
1172
 
        data to plot the contours of. This can also be a function
1173
 
        that takes x and y as arguments.
1174
 
    color
1175
 
        The color of the contour lines. If None, this is given by
1176
 
        the scalars.
1177
 
    contour_z
1178
 
        If this argument is given the contours are drawn on a
1179
 
        plane at the altitude specified by this argument instead
1180
 
        of on the surface. Currently this cannot be used when a
1181
 
        surf_object is passed as an argument.
1182
 
    name
1183
 
        The name of the vtk object created. Default is "Contour Surf"
1184
 
 
1185
 
    """
1186
 
 
1187
 
    contours = kwargs.pop('contours', 10)
1188
 
    if len(args) == 1 and isinstance(args[0], Surface):
1189
 
        if 'contour_z' in kwargs:
1190
 
            raise TypeError, 'contour_z cannot be used when passing a surf_object'
1191
 
        data_object = _find_module_manager(args[0]).source
1192
 
        mm = ModuleManager()
1193
 
        mayavi.add_module(mm, obj=data_object)
1194
 
        s = Surface(name=kwargs.get('name', 'Contours'))
1195
 
        mm.add_child(s)
1196
 
    else:
1197
 
        if 'contour_z' in kwargs:
1198
 
            if not 'scalars' in kwargs:
1199
 
                kwargs['scalars'] = args[-1]
1200
 
            contour_z = kwargs.pop('contour_z')
1201
 
            args = list(args)
1202
 
            args[-1] = scipy.zeros_like(args[0])
1203
 
        if not 'name' in kwargs:
1204
 
            kwargs['name'] = "Contour Surf"
1205
 
        s = surf(*args, **kwargs)
1206
 
    s.enable_contours = True
1207
 
 
1208
 
    # Check what type the 'contours' are and do whatever is needed.
1209
 
    contour_list = True
1210
 
    try:
1211
 
        len(contours)
1212
 
    except TypeError:
1213
 
        contour_list = False
1214
 
 
1215
 
    if contour_list:
1216
 
        s.contour.contours = contours
1217
 
        s.contour.set(auto_contours=False)
1218
 
    else:
1219
 
        assert type(contours) == int, "The contours argument must be an integer"
1220
 
        assert contours > 1, "The contours argument must be positive"
1221
 
        s.contour.set(auto_contours=True,
1222
 
                            number_of_contours=contours)
1223
 
 
1224
 
    return s
1225
 
 
1226
 
def imshow(arr, extent=None, interpolate=False,
1227
 
           lut_mode='blue-red', file_name='',
1228
 
           name='ImShow'):
1229
 
    """Allows one to view a 2D Numeric array as an image.  This works
1230
 
    best for very large arrays (like 1024x1024 arrays).
1231
 
 
1232
 
    Arguments
1233
 
    ---------
1234
 
 
1235
 
    arr -- Array to be viewed.
1236
 
 
1237
 
    Keyword arguments
1238
 
    -----------------
1239
 
 
1240
 
    scale -- Scale the x, y and z axis as per passed values.
1241
 
             Defaults to [1.0, 1.0, 1.0].
1242
 
 
1243
 
    extent --  [xmin, xmax, ymin, ymax]
1244
 
                   Default is the x, y arrays extent
1245
 
 
1246
 
    interpolate -- Boolean to interpolate the data in the image.
1247
 
    """
1248
 
    # FIXME
1249
 
    assert len(arr.shape) == 2, "Only 2D arrays can be viewed!"
1250
 
 
1251
 
    ny, nx = arr.shape
1252
 
    if extent:
1253
 
        xmin, xmax, ymin, ymax = extent
1254
 
    else:
1255
 
        xmin = 0.
1256
 
        xmax = nx
1257
 
        ymin = 0.
1258
 
        ymax = ny
1259
 
 
1260
 
    xa = scipy.linspace(xmin, xmax, nx, 'f')
1261
 
    ya = scipy.linspace(ymin, ymax, ny, 'f')
1262
 
 
1263
 
    arr_flat = scipy.ravel(arr)
1264
 
    min_val = min(arr_flat)
1265
 
    max_val = max(arr_flat)
1266
 
 
1267
 
    sp = mlab._create_structured_points_direct(xa, ya)
1268
 
 
1269
 
    from mayavi.core.lut_manager import LUTManager
1270
 
    lut = LUTManager(lut_mode=lut_mode, file_name=file_name)
1271
 
    lut.data_range = min_val, max_val
1272
 
    a = lut.lut.map_scalars(arr_flat, 0, 0)
1273
 
    a.name = 'scalars'
1274
 
    sp.point_data.scalars = a
1275
 
    sp.scalar_type = 'unsigned_char'
1276
 
    sp.number_of_scalar_components = 4
1277
 
    d = _add_data(sp, name)
1278
 
 
1279
 
    ia = ImageActor()
1280
 
    ia.actor.interpolate = interpolate
1281
 
    mayavi.add_module(ia)
1282
 
    return ia
1283
 
 
1284
 
######################################################################
1285
 
# Non data-related drawing elements
1286
 
def outline(object=None, color=None, name='Outline'):
1287
 
    """Creates an outline for the current data.
1288
 
 
1289
 
    Keyword arguments
1290
 
    -----------------
1291
 
 
1292
 
        object -- the object for which we create the outline
1293
 
                       (default None).
1294
 
 
1295
 
        color -- The color triplet, eg: ( 1., 0., 0.)
1296
 
    """
1297
 
    from mayavi.modules.outline import Outline
1298
 
    mayavi = _make_default_figure()
1299
 
    scene = gcf()
1300
 
    for obj in _traverse(scene):
1301
 
        if isinstance(obj, Outline) and obj.name == name:
1302
 
            o = obj
1303
 
            break
1304
 
    else:
1305
 
        o = Outline(name=name)
1306
 
        if object is not None:
1307
 
            object.add_child(o)
1308
 
        else:
1309
 
            mayavi.add_module(o)
1310
 
        if color is None:
1311
 
            color = scene.scene.foreground
1312
 
    if not color is None:
1313
 
        o.actor.property.color = color
1314
 
    return o
1315
 
 
1316
 
def axes(color=None, xlabel=None, ylabel=None, zlabel=None,
1317
 
         object=None, name='Axes'):
1318
 
    """Creates an axes for the current data.
1319
 
 
1320
 
    Keyword arguments
1321
 
    -----------------
1322
 
 
1323
 
        color -- The color triplet, eg: (1., 0., 0.)
1324
 
 
1325
 
        xlabel -- the label of the x axis, default: ''
1326
 
 
1327
 
        ylabel -- the label of the y axis, default: ''
1328
 
 
1329
 
        zlabel -- the label of the z axis, default: ''
1330
 
 
1331
 
        object -- the object for which we create the axes.
1332
 
    """
1333
 
    mayavi = _make_default_figure()
1334
 
    scene = gcf()
1335
 
    new = False
1336
 
    if object is not None:
1337
 
        a = Axes(name=name)
1338
 
        object.add_child(a)
1339
 
        new = True
1340
 
    else:
1341
 
        for obj in _traverse(scene):
1342
 
            if isinstance(obj, Axes) and obj.name == name:
1343
 
                a = obj
1344
 
                break
1345
 
        else:
1346
 
            a = Axes(name=name)
1347
 
            mayavi.add_module(a)
1348
 
            new = True
1349
 
    if new:
1350
 
        if color is None:
1351
 
            color = scene.scene.foreground
1352
 
        if xlabel is None:
1353
 
            xlabel = ''
1354
 
        if ylabel is None:
1355
 
            ylabel = ''
1356
 
        if zlabel is None:
1357
 
            zlabel = ''
1358
 
    if color is not None:
1359
 
        a.property.color = color
1360
 
    if xlabel is not None:
1361
 
        a.axes.x_label = xlabel
1362
 
    if ylabel is not None:
1363
 
        a.axes.y_label = ylabel
1364
 
    if zlabel is not None:
1365
 
        a.axes.z_label = zlabel
1366
 
    return a
1367
 
 
1368
 
def figure():
1369
 
    """If you are running from IPython this will start up mayavi for
1370
 
    you!  This returns the current running MayaVi script instance.
1371
 
    """
1372
 
    global mayavi, application
1373
 
    if mayavi is not None and application.stopped is None:
1374
 
        mayavi.new_scene()
1375
 
        return mayavi.engine.current_scene
1376
 
    m = Mayavi()
1377
 
    m.main()
1378
 
    m.script.new_scene()
1379
 
    engine = m.script.engine
1380
 
    mayavi = m.script
1381
 
    application = m.application
1382
 
    return mayavi.engine.current_scene
1383
 
 
1384
 
def gcf():
1385
 
    """Return a handle to the current figure.
1386
 
    """
1387
 
    return mayavi.engine.current_scene
1388
 
 
1389
 
def clf():
1390
 
    """Clear the current figure.
1391
 
    """
1392
 
    try:
1393
 
        scene = gcf()
1394
 
        scene.children[:] = []
1395
 
    except AttributeError:
1396
 
        pass
1397
 
 
1398
 
def savefig(filename, size=None, **kwargs):
1399
 
    """ Save the current scene.
1400
 
        The output format are deduced by the extension to filename.
1401
 
        Possibilities are png, jpg, bmp, tiff, ps, eps, pdf, rib (renderman),
1402
 
        oogl (geomview), iv (OpenInventor), vrml, obj (wavefront)
1403
 
 
1404
 
        If an additional size (2-tuple) argument is passed the window
1405
 
        is resized to the specified size in order to produce a
1406
 
        suitably sized output image.  Please note that when the window
1407
 
        is resized, the window may be obscured by other widgets and
1408
 
        the camera zoom is not reset which is likely to produce an
1409
 
        image that does not reflect what is seen on screen.
1410
 
 
1411
 
        Any extra keyword arguments are passed along to the respective
1412
 
        image format's save method.
1413
 
    """
1414
 
    gcf().scene.save(filename, size=size, **kwargs)
1415
 
 
1416
 
def xlabel(text):
1417
 
    """Creates a set of axes if there isn't already one, and sets the x label
1418
 
    """
1419
 
    return axes(xlabel=text)
1420
 
 
1421
 
def ylabel(text):
1422
 
    """Creates a set of axes if there isn't already one, and sets the y label
1423
 
    """
1424
 
    return axes(ylabel=text)
1425
 
 
1426
 
def zlabel(text):
1427
 
    """ Creates a set of axes if there isn't already one, and sets the z label
1428
 
    """
1429
 
    return axes(zlabel=text)
1430
 
 
1431
 
def title(text=None, color=None, size=None, name='Title'):
1432
 
    """Creates a title for the figure.
1433
 
 
1434
 
    Keyword arguments
1435
 
    -----------------
1436
 
 
1437
 
        text -- The text of the title, default: ''
1438
 
 
1439
 
        color -- The color triplet, eg: ( 1., 0., 0.)
1440
 
 
1441
 
        size -- The size, default: 1
1442
 
    """
1443
 
    scene = gcf()
1444
 
    for object in _traverse(scene):
1445
 
        if isinstance(object, Text) and object.name==name:
1446
 
            t = object
1447
 
            break
1448
 
    else:
1449
 
        t = Text(name=name)
1450
 
        mayavi.add_module(t)
1451
 
        if color is None:
1452
 
            color = scene.scene.foreground
1453
 
        if text is None:
1454
 
            text = 'title'
1455
 
    if color is not None:
1456
 
        t.property.color = color
1457
 
    if text is not None:
1458
 
        t.text = text
1459
 
    if text is not None or size is not None:
1460
 
        t.width = min(0.05*size*len(t.text), 1)
1461
 
        t.x_position = 0.5*(1 - t.width)
1462
 
        t.y_position = 0.8
1463
 
    return t
1464
 
 
1465
 
def text(x=0, y=0, text='', color=None, size=None, name='Text'):
1466
 
    """Adds a text on the figure.
1467
 
 
1468
 
    Keyword arguments
1469
 
    -----------------
1470
 
        x -- x position on the screen of the origin of the text, default: 0
1471
 
 
1472
 
        y -- y position on the screen of the origin of the text, default: 0
1473
 
 
1474
 
        text -- The text, default: ''
1475
 
 
1476
 
        color -- The color triplet, eg: ( 1., 0., 0.)
1477
 
 
1478
 
        size -- The size, default: 1
1479
 
    """
1480
 
    scene = gcf()
1481
 
    t = Text(name=name)
1482
 
    mayavi.add_module(t)
1483
 
    if color is None:
1484
 
        color = scene.scene.foreground
1485
 
    else:
1486
 
        t.property.color = color
1487
 
    t.text = text
1488
 
    t.x_position = x
1489
 
    t.y_position = y
1490
 
    return t
1491
 
 
1492
 
 
1493
 
def scalarbar(object=None, title=None, orientation=None):
1494
 
    """Adds a colorbar for the scalar color mapping of the given object.
1495
 
 
1496
 
    If no object is specified, the first object with scalar data in the scene
1497
 
    is used.
1498
 
 
1499
 
    Keyword arguments
1500
 
    -----------------
1501
 
 
1502
 
        title -- The title string
1503
 
 
1504
 
        orientation -- Can be 'horizontal' or 'vertical'
1505
 
    """
1506
 
    module_manager = _find_module_manager(object=object, data_type="scalar")
1507
 
    if module_manager is None:
1508
 
        return
1509
 
    if not module_manager.scalar_lut_manager.show_scalar_bar:
1510
 
        if title is None:
1511
 
            title = ''
1512
 
        if orientation is None:
1513
 
            orientation = 'horizontal'
1514
 
    colorbar = module_manager.scalar_lut_manager.scalar_bar
1515
 
    if title is not None:
1516
 
        colorbar.title = title
1517
 
    if orientation is not None:
1518
 
        _orient_colorbar(colorbar, orientation)
1519
 
    module_manager.scalar_lut_manager.show_scalar_bar = True
1520
 
    return colorbar
1521
 
 
1522
 
def vectorbar(object=None, title=None, orientation=None):
1523
 
    """Adds a colorbar for the vector color mapping of the given object.
1524
 
 
1525
 
    If no object is specified, the first object with vector data in the scene
1526
 
    is used.
1527
 
 
1528
 
    Keyword arguments
1529
 
    -----------------
1530
 
 
1531
 
        object -- Optional object to get the vector lut from
1532
 
 
1533
 
        title -- The title string
1534
 
 
1535
 
        orientation -- Can be 'horizontal' or 'vertical'
1536
 
    """
1537
 
    module_manager = _find_module_manager(object=object, data_type="vector")
1538
 
    if module_manager is None:
1539
 
        return
1540
 
    if not module_manager.vector_lut_manager.show_scalar_bar:
1541
 
        title = ''
1542
 
        orientation = 'horizontal'
1543
 
    colorbar = module_manager.vector_lut_manager.scalar_bar
1544
 
    if title is not None:
1545
 
        colorbar.title = title
1546
 
    if orientation is not None:
1547
 
        _orient_colorbar(colorbar, orientation)
1548
 
    module_manager.vector_lut_manager.show_scalar_bar = True
1549
 
    return colorbar
1550
 
 
1551
 
def colorbar(object=None, title=None, orientation=None):
1552
 
    """Adds a colorbar for the color mapping of the given object.
1553
 
 
1554
 
    If the object has scalar data, the scalar color mapping is
1555
 
    represented. Elsewhere the vector color mapping is represented, if
1556
 
    available.
1557
 
    If no object is specified, the first object with a color map in the scene
1558
 
    is used.
1559
 
 
1560
 
    Keyword arguments
1561
 
    -----------------
1562
 
 
1563
 
        object -- Optional object to get the vector lut from
1564
 
 
1565
 
        title       -- The title string
1566
 
 
1567
 
        orientation -- Can be 'horizontal' or 'vertical'
1568
 
    """
1569
 
    colorbar = scalarbar(object=object, title=title, orientation=orientation)
1570
 
    if colorbar is None:
1571
 
        colorbar = vectorbar(object=object, title=title, orientation=orientation)
1572
 
    return colorbar
1573
 
 
1574
 
 
1575
 
######################################################################
1576
 
# Test functions.
1577
 
######################################################################
1578
 
def test_plot3d():
1579
 
    """Generates a pretty set of lines."""
1580
 
    n_mer, n_long = 6, 11
1581
 
    pi = scipy.pi
1582
 
    dphi = pi/1000.0
1583
 
    phi = scipy.arange(0.0, 2*pi + 0.5*dphi, dphi, 'd')
1584
 
    mu = phi*n_mer
1585
 
    x = scipy.cos(mu)*(1+scipy.cos(n_long*mu/n_mer)*0.5)
1586
 
    y = scipy.sin(mu)*(1+scipy.cos(n_long*mu/n_mer)*0.5)
1587
 
    z = scipy.sin(n_long*mu/n_mer)*0.5
1588
 
 
1589
 
    l = plot3d(x, y, z, radius=0.05, color=(0.0, 0.0, 0.8))
1590
 
    return l
1591
 
 
1592
 
def test_molecule():
1593
 
    """Generates and shows a Caffeine molecule."""
1594
 
    o = [[30, 62, 19],[8, 21, 10]]
1595
 
    ox, oy, oz = map(scipy.array, zip(*o))
1596
 
    n = [[31, 21, 11], [18, 42, 14], [55, 46, 17], [56, 25, 13]]
1597
 
    nx, ny, nz = map(scipy.array, zip(*n))
1598
 
    c = [[5, 49, 15], [30, 50, 16], [42, 42, 15], [43, 29, 13], [18, 28, 12],
1599
 
         [32, 6, 8], [63, 36, 15], [59, 60, 20]]
1600
 
    cx, cy, cz = map(scipy.array, zip(*c))
1601
 
    h = [[23, 5, 7], [32, 0, 16], [37, 5, 0], [73, 36, 16], [69, 60, 20],
1602
 
         [54, 62, 28], [57, 66, 12], [6, 59, 16], [1, 44, 22], [0, 49, 6]]
1603
 
    hx, hy, hz = map(scipy.array, zip(*h))
1604
 
 
1605
 
    oxygen = points3d(ox, oy, oz, scale_factor=8, autoscale=False,
1606
 
                                        color=(1,0,0), name='Oxygen')
1607
 
    nitrogen = points3d(nx, ny, nz, scale_factor=10, autoscale=False,
1608
 
                                        color=(0,0,1), name='Nitrogen')
1609
 
    carbon = points3d(cx, cy, cz, scale_factor=10, autoscale=False,
1610
 
                                        color=(0,1,0), name='Carbon')
1611
 
    hydrogen = points3d(hx, hy, hz, scale_factor=5, autoscale=False,
1612
 
                                        color=(1,1,1), name='Hydrogen')
1613
 
 
1614
 
def test_surf_lattice():
1615
 
    """Test Surf on regularly spaced co-ordinates like MayaVi."""
1616
 
    def f(x, y):
1617
 
        sin, cos = scipy.sin, scipy.cos
1618
 
        return sin(x+y) + sin(2*x - y) + cos(3*x+4*y)
1619
 
        #return scipy.sin(x*y)/(x*y)
1620
 
 
1621
 
    x, y = scipy.mgrid[-7.:7.05:0.1, -5.:5.05:0.05]
1622
 
    s = surf(x, y, f)
1623
 
    cs = contour_surf(x, y, f, contour_z=0)
1624
 
    return s
1625
 
 
1626
 
def test_simple_surf():
1627
 
    """Test Surf with a simple collection of points."""
1628
 
    x, y = scipy.mgrid[0:3:1,0:3:1]
1629
 
    return surf(x, y, scipy.asarray(x, 'd'))
1630
 
 
1631
 
def test_surf():
1632
 
    """A very pretty picture of spherical harmonics translated from
1633
 
    the octaviz example."""
1634
 
    pi = scipy.pi
1635
 
    cos = scipy.cos
1636
 
    sin = scipy.sin
1637
 
    dphi, dtheta = pi/250.0, pi/250.0
1638
 
    [phi,theta] = scipy.mgrid[0:pi+dphi*1.5:dphi,0:2*pi+dtheta*1.5:dtheta]
1639
 
    m0 = 4; m1 = 3; m2 = 2; m3 = 3; m4 = 6; m5 = 2; m6 = 6; m7 = 4;
1640
 
    r = sin(m0*phi)**m1 + cos(m2*phi)**m3 + sin(m4*theta)**m5 + cos(m6*theta)**m7
1641
 
    x = r*sin(phi)*cos(theta)
1642
 
    y = r*cos(phi)
1643
 
    z = r*sin(phi)*sin(theta);
1644
 
 
1645
 
    return surf(x, y, z)
1646
 
 
1647
 
def test_mesh_sphere():
1648
 
    """Create a simple sphere and test the mesh."""
1649
 
    pi = scipy.pi
1650
 
    cos = scipy.cos
1651
 
    sin = scipy.sin
1652
 
    du, dv = pi/20.0, pi/20.0
1653
 
    phi, theta = scipy.mgrid[0.01:pi+du*1.5:du, 0:2*pi+dv*1.5:dv]
1654
 
    r = 1.0
1655
 
    x = r*sin(phi)*cos(theta)
1656
 
    y = r*sin(phi)*sin(theta)
1657
 
    z = r*cos(phi)
1658
 
    s = surf(x, y, z, representation='mesh',
1659
 
                   tube_radius=0.01, sphere_radius=0.025)
1660
 
 
1661
 
def test_mesh():
1662
 
    """Create a fancy looking mesh (example taken from octaviz)."""
1663
 
    pi = scipy.pi
1664
 
    cos = scipy.cos
1665
 
    sin = scipy.sin
1666
 
    du, dv = pi/20.0, pi/20.0
1667
 
    u, v = scipy.mgrid[0.01:pi+du*1.5:du, 0:2*pi+dv*1.5:dv]
1668
 
    x = (1- cos(u))*cos(u+2*pi/3) * cos(v + 2*pi/3.0)*0.5
1669
 
    y = (1- cos(u))*cos(u+2*pi/3) * cos(v - 2*pi/3.0)*0.5
1670
 
    z = cos(u-2*pi/3.)
1671
 
 
1672
 
    m = surf(x, y, z, scalar_visibility=True, representation='mesh',
1673
 
                   tube_radius=0.0075, sphere_radius=0.02)
1674
 
 
1675
 
def test_imshow():
1676
 
    """Show a large random array."""
1677
 
    z_large = scipy.random.random((1024, 512))
1678
 
    i = imshow(z_large, extent=[0., 1., 0., 1.])
1679
 
 
1680
 
def test_contour3d():
1681
 
    dims = [64, 64, 64]
1682
 
    xmin, xmax, ymin, ymax, zmin, zmax = [-5,5,-5,5,-5,5]
1683
 
    x, y, z = scipy.ogrid[xmin:xmax:dims[0]*1j,
1684
 
                          ymin:ymax:dims[1]*1j,
1685
 
                          zmin:zmax:dims[2]*1j]
1686
 
    x = x.astype('f')
1687
 
    y = y.astype('f')
1688
 
    z = z.astype('f')
1689
 
 
1690
 
    sin = scipy.sin
1691
 
    scalars = x*x*0.5 + y*y + z*z*2.0
1692
 
 
1693
 
    contour3d(scalars, contours=4)
1694
 
 
1695
 
    # Show an outline and zoom appropriately.
1696
 
    outline()
1697
 
    mayavi.engine.current_scene.scene.isometric_view()
1698
 
 
1699
 
def test_quiver3d():
1700
 
    dims = [16, 16, 16]
1701
 
    xmin, xmax, ymin, ymax, zmin, zmax = [-5,5,-5,5,-5,5]
1702
 
    x, y, z = scipy.mgrid[xmin:xmax:dims[0]*1j,
1703
 
                          ymin:ymax:dims[1]*1j,
1704
 
                          zmin:zmax:dims[2]*1j]
1705
 
    x = x.astype('f')
1706
 
    y = y.astype('f')
1707
 
    z = z.astype('f')
1708
 
 
1709
 
    sin = scipy.sin
1710
 
    cos = scipy.cos
1711
 
    u = cos(x)
1712
 
    v = sin(y)
1713
 
    w = sin(x*z)
1714
 
 
1715
 
    # All these work!
1716
 
    #quiver3d(u, v, w)
1717
 
    quiver3d(x, y, z, u, v, w)
1718
 
 
1719
 
    # Show an outline and zoom appropriately.
1720
 
    outline()
1721
 
    mayavi.engine.current_scene.scene.isometric_view()
1722
 
 
1723
 
def test_quiver3d_2d_data():
1724
 
    dims = [32, 32]
1725
 
    xmin, xmax, ymin, ymax = [-5,5,-5,5]
1726
 
    x, y = scipy.mgrid[xmin:xmax:dims[0]*1j,
1727
 
                       ymin:ymax:dims[1]*1j]
1728
 
    x = x.astype('f')
1729
 
    y = y.astype('f')
1730
 
 
1731
 
    sin = scipy.sin
1732
 
    cos = scipy.cos
1733
 
    u = cos(x)
1734
 
    v = sin(y)
1735
 
    w = scipy.zeros_like(x)
1736
 
 
1737
 
    quiver3d(x, y, w, u, v, w)
1738
 
 
1739
 
    # Show an outline and zoom appropriately.
1740
 
    outline()
1741
 
    mayavi.engine.current_scene.scene.isometric_view()