2
Colormap of a scalar value field with cross sections that can be animated
4
A complex example showing interaction between a Traits-based interactive model,
5
interactive plot parameters, and multiple Chaco components.
7
Renders a colormapped image of a scalar value field, and a cross section
8
chosen by a line interactor.
10
Animation must be disabled (unchecked) before the model can be edited.
13
# Standard library imports
14
from optparse import OptionParser
18
# Major library imports
19
from numpy import array, linspace, meshgrid, nanmin, nanmax, pi, errstate
21
# Enthought library imports
22
from chaco.api import ArrayPlotData, ColorBar, ContourLinePlot, \
23
ContourPolyPlot, DataRange1D, VPlotContainer, \
24
DataRange2D, GridMapper, GridDataSource, \
25
HPlotContainer, ImageData, LinearMapper, \
26
OverlayPlotContainer, Plot, PlotAxis
27
from chaco import default_colormaps
28
from enable.component_editor import ComponentEditor
29
from chaco.tools.api import LineInspector, PanTool, ZoomTool
30
from traits.api import Array, Callable, CFloat, CInt, Enum, Event, Float, \
31
HasTraits, Int, Instance, Str, Trait, on_trait_change, Button, Bool, \
33
from traitsui.api import Group, HGroup, Item, View, UItem, spring
35
from pyface.timer.api import Timer
37
# Remove the most boring colormaps from consideration:
38
colormaps = default_colormaps.color_map_name_dict.keys()
39
for boring in 'bone gray yarg gist_gray gist_yarg Greys'.split():
40
colormaps.remove(boring)
42
class Model(HasTraits):
44
#Traits view definitions:
46
Group(Item('function'),
47
HGroup(Item('npts_x', label="Number X Points"),
48
Item('npts_y', label="Number Y Points")),
49
HGroup(Item('min_x', label="Min X value"),
50
Item('max_x', label="Max X value")),
51
HGroup(Item('min_y', label="Min Y value"),
52
Item('max_y', label="Max Y value"))),
53
buttons=["OK", "Cancel"])
55
function = Str("tanh(x**2+y)*cos(y)*jn(0,x+y*2)")
62
min_y = CFloat(-1.5*pi)
63
max_y = CFloat(1.5*pi)
74
def __init__(self, *args, **kwargs):
75
super(Model, self).__init__(*args, **kwargs)
78
def compute_model(self):
79
# The xs and ys used for the image plot range need to be the
81
self.xs = linspace(self.min_x, self.max_x, self.npts_x+1)
82
self.ys = linspace(self.min_y, self.max_y, self.npts_y+1)
84
# The grid of points at which we will evaluate the 2D function
85
# is located at cell centers, so use halfsteps from the
86
# min/max values (which are edges)
87
xstep = (self.max_x - self.min_x) / self.npts_x
88
#ystep = (self.max_y - self.min_y) / self.npts_y
89
gridx = linspace(self.min_x+xstep/2, self.max_x-xstep/2, self.npts_x)
90
gridy = linspace(self.min_y+xstep/2, self.max_y-xstep/2, self.npts_y)
91
x, y = meshgrid(gridx, gridy)
94
exec "from scipy import *" in d
95
exec "from scipy.special import *" in d
96
self.zs = eval(self.function, d)
97
self.minz = nanmin(self.zs)
98
self.maxz = nanmax(self.zs)
99
self.model_changed = True
100
self._function = self.function
102
self.set(function = self._function, trait_change_notify=False)
104
def _anytrait_changed(self, name, value):
105
if name in ['function', 'npts_x', 'npts_y',
106
'min_x', 'max_x', 'min_y', 'max_y']:
110
class PlotUI(HasTraits):
112
# container for all plots
113
container = Instance(HPlotContainer)
115
# Plot components within this container:
116
polyplot = Instance(ContourPolyPlot)
117
lineplot = Instance(ContourLinePlot)
118
cross_plot = Instance(Plot)
119
cross_plot2 = Instance(Plot)
120
colorbar = Instance(ColorBar)
123
pd = Instance(ArrayPlotData)
127
colormap = Enum(colormaps)
129
#Traits view definitions:
131
Group(UItem('container', editor=ComponentEditor(size=(800,600)))),
134
plot_edit_view = View(
135
Group(Item('num_levels'),
137
buttons=["OK","Cancel"])
140
#---------------------------------------------------------------------------
142
#---------------------------------------------------------------------------
144
_image_index = Instance(GridDataSource)
145
_image_value = Instance(ImageData)
147
_cmap = Trait(default_colormaps.jet, Callable)
149
#---------------------------------------------------------------------------
150
# Public View interface
151
#---------------------------------------------------------------------------
153
def __init__(self, *args, **kwargs):
154
super(PlotUI, self).__init__(*args, **kwargs)
155
# FIXME: 'with' wrapping is temporary fix for infinite range in initial
156
# color map, which can cause a distracting warning print. This 'with'
157
# wrapping should be unnecessary after fix in color_mapper.py.
158
with errstate(invalid='ignore'):
161
def create_plot(self):
163
# Create the mapper, etc
164
self._image_index = GridDataSource(array([]),
166
sort_order=("ascending","ascending"))
167
image_index_range = DataRange2D(self._image_index)
168
self._image_index.on_trait_change(self._metadata_changed,
171
self._image_value = ImageData(data=array([]), value_depth=1)
172
image_value_range = DataRange1D(self._image_value)
176
# Create the contour plots
177
self.polyplot = ContourPolyPlot(index=self._image_index,
178
value=self._image_value,
179
index_mapper=GridMapper(range=
182
self._cmap(image_value_range),
183
levels=self.num_levels)
185
self.lineplot = ContourLinePlot(index=self._image_index,
186
value=self._image_value,
187
index_mapper=GridMapper(range=
188
self.polyplot.index_mapper.range),
189
levels=self.num_levels)
192
# Add a left axis to the plot
193
left = PlotAxis(orientation='left',
195
mapper=self.polyplot.index_mapper._ymapper,
196
component=self.polyplot)
197
self.polyplot.overlays.append(left)
199
# Add a bottom axis to the plot
200
bottom = PlotAxis(orientation='bottom',
202
mapper=self.polyplot.index_mapper._xmapper,
203
component=self.polyplot)
204
self.polyplot.overlays.append(bottom)
207
# Add some tools to the plot
208
self.polyplot.tools.append(PanTool(self.polyplot,
209
constrain_key="shift"))
210
self.polyplot.overlays.append(ZoomTool(component=self.polyplot,
211
tool_mode="box", always_on=False))
212
self.polyplot.overlays.append(LineInspector(component=self.polyplot,
214
inspect_mode="indexed",
218
self.polyplot.overlays.append(LineInspector(component=self.polyplot,
220
inspect_mode="indexed",
225
# Add these two plots to one container
226
contour_container = OverlayPlotContainer(padding=20,
229
contour_container.add(self.polyplot)
230
contour_container.add(self.lineplot)
234
cbar_index_mapper = LinearMapper(range=image_value_range)
235
self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
237
padding_top=self.polyplot.padding_top,
238
padding_bottom=self.polyplot.padding_bottom,
243
self.pd = ArrayPlotData(line_index = array([]),
244
line_value = array([]),
245
scatter_index = array([]),
246
scatter_value = array([]),
247
scatter_color = array([]))
249
self.cross_plot = Plot(self.pd, resizable="h")
250
self.cross_plot.height = 100
251
self.cross_plot.padding = 20
252
self.cross_plot.plot(("line_index", "line_value"),
254
self.cross_plot.plot(("scatter_index","scatter_value","scatter_color"),
257
color_mapper=self._cmap(image_value_range),
261
self.cross_plot.index_range = self.polyplot.index_range.x_range
263
self.pd.set_data("line_index2", array([]))
264
self.pd.set_data("line_value2", array([]))
265
self.pd.set_data("scatter_index2", array([]))
266
self.pd.set_data("scatter_value2", array([]))
267
self.pd.set_data("scatter_color2", array([]))
269
self.cross_plot2 = Plot(self.pd, width = 140, orientation="v", resizable="v", padding=20, padding_bottom=160)
270
self.cross_plot2.plot(("line_index2", "line_value2"),
272
self.cross_plot2.plot(("scatter_index2","scatter_value2","scatter_color2"),
275
color_mapper=self._cmap(image_value_range),
279
self.cross_plot2.index_range = self.polyplot.index_range.y_range
283
# Create a container and add components
284
self.container = HPlotContainer(padding=40, fill_padding=True,
285
bgcolor = "white", use_backbuffer=False)
286
inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
287
inner_cont.add(self.cross_plot)
288
inner_cont.add(contour_container)
289
self.container.add(self.colorbar)
290
self.container.add(inner_cont)
291
self.container.add(self.cross_plot2)
294
def update(self, model):
    """Copy the latest results of *model* into the plot data and redraw.

    Reads ``minz``, ``maxz``, ``xs``, ``ys`` and ``zs`` from *model*.  The
    value extremes are cached on this view and applied to the colorbar
    range.  The grid/image data sources receive the new arrays, and the
    cross-section index arrays are refreshed.  Finally the container is
    invalidated and a redraw is requested.
    """
    self.minz = model.minz
    self.maxz = model.maxz

    cbar_range = self.colorbar.index_mapper.range
    cbar_range.low = self.minz
    cbar_range.high = self.maxz

    self._image_index.set_data(model.xs, model.ys)
    self._image_value.data = model.zs

    # The horizontal cross section is indexed by x, the vertical one by y.
    for key, index_values in (("line_index", model.xs),
                              ("line_index2", model.ys)):
        self.pd.set_data(key, index_values)

    container = self.container
    container.invalidate_draw()
    container.request_redraw()
307
#---------------------------------------------------------------------------
309
#---------------------------------------------------------------------------
311
def _metadata_changed(self, old, new):
312
""" This function takes out a cross section from the image data, based
313
on the line inspector selections, and updates the line and scatter
316
self.cross_plot.value_range.low = self.minz
317
self.cross_plot.value_range.high = self.maxz
318
self.cross_plot2.value_range.low = self.minz
319
self.cross_plot2.value_range.high = self.maxz
320
if self._image_index.metadata.has_key("selections"):
321
x_ndx, y_ndx = self._image_index.metadata["selections"]
323
self.pd.set_data("line_value",
324
self._image_value.data[y_ndx,:])
325
self.pd.set_data("line_value2",
326
self._image_value.data[:,x_ndx])
327
xdata, ydata = self._image_index.get_data()
328
xdata, ydata = xdata.get_data(), ydata.get_data()
329
self.pd.set_data("scatter_index", array([xdata[x_ndx]]))
330
self.pd.set_data("scatter_index2", array([ydata[y_ndx]]))
331
self.pd.set_data("scatter_value",
332
array([self._image_value.data[y_ndx, x_ndx]]))
333
self.pd.set_data("scatter_value2",
334
array([self._image_value.data[y_ndx, x_ndx]]))
335
self.pd.set_data("scatter_color",
336
array([self._image_value.data[y_ndx, x_ndx]]))
337
self.pd.set_data("scatter_color2",
338
array([self._image_value.data[y_ndx, x_ndx]]))
340
self.pd.set_data("scatter_value", array([]))
341
self.pd.set_data("scatter_value2", array([]))
342
self.pd.set_data("line_value", array([]))
343
self.pd.set_data("line_value2", array([]))
345
def _colormap_changed(self):
    """Switch every colour-mapped plot to the newly selected colormap.

    Always records the colormap factory in ``_cmap``.  If the plots have not
    been built yet there is nothing else to do; ``create_plot`` builds its
    colour mappers from ``_cmap``.  Otherwise a fresh mapper is built for the
    contour plot (over its existing value range), for the top cross plot,
    and for the "dot" renderer of both cross plots (over the top cross
    plot's range).  Then a redraw is requested.
    """
    self._cmap = default_colormaps.color_map_name_dict[self.colormap]
    if self.polyplot is None:
        return

    make_mapper = self._cmap
    poly_range = self.polyplot.color_mapper.range
    self.polyplot.color_mapper = make_mapper(poly_range)

    cross_range = self.cross_plot.color_mapper.range
    self.cross_plot.color_mapper = make_mapper(cross_range)

    # FIXME: change when we decide how best to update plots using
    # the shared colormap in plot object
    for cross in (self.cross_plot, self.cross_plot2):
        cross.plots["dot"][0].color_mapper = make_mapper(cross_range)

    self.container.request_redraw()
358
def _num_levels_changed(self):
    """Apply a new ``num_levels`` to both contour plots.

    Values of 3 or fewer are ignored and leave the current levels unchanged.

    The contour plots are only created in ``create_plot``, which ``__init__``
    runs after ``HasTraits`` initialisation.  A ``num_levels`` keyword passed
    to the constructor (as ``show_plot`` does) therefore triggers this
    handler while ``polyplot``/``lineplot`` are still None.  In that case
    there is nothing to update: ``create_plot`` reads ``num_levels`` itself.
    """
    if self.polyplot is None or self.lineplot is None:
        return
    if self.num_levels > 3:
        self.polyplot.levels = self.num_levels
        self.lineplot.levels = self.num_levels
366
# HasTraits class that supplies the callable for the timer event.
367
class TimerController(HasTraits):
369
# The plot view which will be affected by timed animation
370
view = Instance(PlotUI)
372
# The ModelView instance that contains the animation options:
373
model_view = Instance('ModelView')
375
# Whether the view is animated:
376
animated = DelegatesTo('model_view')
378
# whether color change is animated on each boundary:
379
animate_left = DelegatesTo('model_view')
380
animate_right = DelegatesTo('model_view')
381
animate_top = DelegatesTo('model_view')
382
animate_bottom = DelegatesTo('model_view')
384
# current increments of selected point, for animation
388
# Possible directions for 2D animated motion.
389
# One tuple will be selected randomly from these on each bounce.
390
# In each tuple, the first integer is the absolute value of
391
# the new delta of the component that reached a boundary.
392
# The second integer is the new delta of the other component.
393
motions = ((1,1), (1,2), (1,3), (2,1), (3,1), (3,2), (2,3),
394
(1,-1),(1,-2),(1,-3),(2,-1),(3,-1),(3,-2),(2,-2)
397
def onTimer(self, *args):
399
Callback function which responds to each timer tick
400
and animates the moving selection point and colors.
403
def randomize(new_direction=1, color_change=False):
405
Randomize 2D motion, and colors if desired.
408
* new_direction is the sign of the new motion delta for
409
the component that reached the boundary (the primary bounce
412
* color_change is whether to change the colormap if allowed.
414
Returns a pair of integers, which are the new motion deltas,
415
respectively, for primary bounce direction and the other.
419
self.view.colormap = random.choice(colormaps)
420
result0, result1 = random.choice(self.motions)
421
return result0 * new_direction, result1
425
metadata = self.view._image_index.metadata
426
indices = metadata.get("selections", ())
427
if len(indices) == 2:
428
# Indices are (x,y) but limits are (y,x)
430
ylim, xlim = self.view._image_value.data.shape
434
self.y_delta, self.x_delta = randomize(1,
438
self.y_delta, self.x_delta = randomize(-1,
444
self.x_delta, self.y_delta = randomize(1,
448
self.x_delta, self.y_delta = randomize(-1,
453
self.x_delta, self.y_delta = random.choice(self.motions)
455
metadata['selections'] = x,y
458
class ModelView(HasTraits):
460
model = Instance(Model)
461
view = Instance(PlotUI)
462
timer = Instance(Timer)
463
timer_controller = Instance(TimerController, ())
467
animated = Bool(False)
469
# Whether to animate colors on a bounce of each side:
470
animate_left = Bool(False)
471
animate_right = Bool(False)
472
animate_top = Bool(False)
473
animate_bottom = Bool(False)
475
traits_view = View(UItem('@view'),
476
HGroup(UItem('edit_model', enabled_when='not animated'),
479
Item('animate_left', enabled_when='animated',
480
label='Change colors at: Left'),
481
Item('animate_right', enabled_when='animated',
483
Item('animate_top', enabled_when='animated',
485
Item('animate_bottom', enabled_when='animated',
488
title = "Function Inspector",
491
@on_trait_change('model, model.model_changed, view')
def update_view(self):
    """Push the model's current data into the plot view.

    Runs when ``model`` or ``view`` is replaced, and also whenever the model
    fires ``model_changed``.  ``compute_model`` does this after recomputing,
    so edits made to the model are reflected in the plots; previously only
    replacing the model object did that.  Does nothing until both a model
    and a view are present.
    """
    if self.model is not None and self.view is not None:
        self.view.update(self.model)
496
def _edit_model_fired(self):
    """Handle the ``edit_model`` button by opening an editor on the model.

    The button is only enabled while animation is off.
    """
    model = self.model
    model.configure_traits()
499
def _edit_view_fired(self):
    """Handle the ``edit_view`` event by editing the plot view's settings.

    Uses the PlotUI's ``plot_edit_view`` layout rather than its default
    view.
    """
    plot_ui = self.view
    plot_ui.configure_traits(view="plot_edit_view")
502
def _model_changed(self):
    """Refresh the plot view from a newly assigned model, if a view exists."""
    plot_ui = self.view
    if plot_ui is None:
        return
    plot_ui.update(self.model)
507
def _start_timer(self):
    """Connect the timer controller to this object and start the timer.

    The controller receives the plot view and this ModelView, from which it
    delegates the ``animated``/``animate_*`` flags.  A pyface Timer is then
    created with a 40 ms interval that calls ``TimerController.onTimer``.
    """
    # Only start the timer once the demo is actually shown, not when the
    # demo object is merely constructed.
    # FIXME: close timer on exit.
    controller = self.timer_controller
    controller.view = self.view
    controller.model_view = self
    self.timer = Timer(40, controller.onTimer)
515
def edit_traits(self, *args, **kws):
517
return super(ModelView, self).edit_traits(*args, **kws)
519
def configure_traits(self, *args, **kws):
521
return super(ModelView, self).configure_traits(*args, **kws)
524
options_dict = {'colormap' : "jet",
526
'function' : "tanh(x**2+y)*cos(y)*jn(0,x+y*2)"}
527
model=Model(**options_dict)
528
view=PlotUI(**options_dict)
529
popup = ModelView(model=model, view=view)
531
def show_plot(**kwargs):
532
model = Model(**kwargs)
533
view = PlotUI(**kwargs)
534
modelview=ModelView(model=model, view=view)
535
modelview.configure_traits()
542
usage = "usage: %prog [options]"
543
parser = OptionParser(usage=usage, version="%prog 1.0")
545
parser.add_option("-c", "--colormap",
546
action="store", type="string", dest="colormap", default="jet",
547
metavar="CMAP", help="choose a default colormapper")
549
parser.add_option("-n", "--nlevels",
550
action="store", type="int", dest="num_levels", default=15,
551
help="number countour levels to plot [default: %default]")
553
parser.add_option("-f", "--function",
554
action="store", type="string", dest="function",
555
default="tanh(x**2+y)*cos(y)*jn(0,x+y*2)",
556
help="function of x and y [default: %default]")
558
opts, args = parser.parse_args(argv[1:])
561
parser.error("Incorrect number of arguments")
563
show_plot(colormap=opts.colormap, num_levels=opts.num_levels,
564
function=opts.function)
566
if __name__ == "__main__":