~ubuntu-branches/ubuntu/utopic/python-chaco/utopic

« back to all changes in this revision

Viewing changes to examples/demo/nonlinear_color_mapping.py

  • Committer: Package Import Robot
  • Author(s): Andrew Starr-Bochicchio
  • Date: 2014-06-01 17:04:08 UTC
  • mfrom: (7.2.5 sid)
  • Revision ID: package-import@ubuntu.com-20140601170408-m86xvdjd83a4qon0
Tags: 4.4.1-1ubuntu1
* Merge from Debian unstable. Remaining Ubuntu changes:
 - Let the binary-predeb target work on the usr/lib/python* directory
   as we don't have usr/share/pyshared anymore.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
#!/usr/bin/env python
2
 
"""
3
 
    Demonstrates usage of the TransformColorMapper class.
4
 
    - The colorbar is zoomable and panable.
5
 
"""
6
 
 
7
 
# Major library imports
8
 
from numpy import linspace, meshgrid, pi, cos, sin, log10
9
 
 
10
 
# Enthought library imports
11
 
from enable.api import Component, ComponentEditor
12
 
from traits.api import HasTraits, Instance, Property, Float, \
13
 
            Enum, Array, Tuple, Int, Callable, cached_property
14
 
from traitsui.api import Item, Group, HGroup, VGroup, View, RangeEditor
15
 
 
16
 
# Chaco imports
17
 
from chaco.api import ArrayPlotData, Plot, ColorBar, HPlotContainer, \
18
 
                      LinearMapper, LogMapper, CMapImagePlot, Spectral, \
19
 
                      TransformColorMapper, jet, hot
20
 
from chaco.tools.api import PanTool, ZoomTool
21
 
 
22
 
 
23
 
class DataGrid(HasTraits):
24
 
    """Holds a grid of 2D data that represents a function z = f(x,y)."""
25
 
 
26
 
    #------------------------------------------------------
27
 
    # Primary Traits
28
 
    #------------------------------------------------------
29
 
 
30
 
    # (xmin, ymin xmax, ymax)
31
 
    domain_bounds = Tuple(Float, Float, Float, Float)
32
 
    
33
 
    # grid dimensions: (Nx, Ny)
34
 
    grid_size = Tuple(Int, Int)
35
 
 
36
 
    # The function to evaluate on the grid.
37
 
    func = Callable
38
 
 
39
 
    #------------------------------------------------------
40
 
    # Properties
41
 
    #------------------------------------------------------
42
 
 
43
 
    # 1D array of x coordinates.
44
 
    x_array = Property(Array, depends_on=['domain_bounds, grid_size'])
45
 
    
46
 
    # 1D array of y coordinates.
47
 
    y_array = Property(Array, depends_on=['domain_bounds, grid_size'])
48
 
 
49
 
    # 2D array of function values, z = f(x,y)
50
 
    data = Property(Array, depends_on=['func, x_array, y_array'])
51
 
 
52
 
    data_min = Property(Float, depends_on=['data'])
53
 
    data_max = Property(Float, depends_on=['data'])
54
 
 
55
 
    #------------------------------------------------------
56
 
    # Trait handlers
57
 
    #------------------------------------------------------
58
 
       
59
 
    @cached_property
60
 
    def _get_x_array(self):
61
 
        xmin = self.domain_bounds[0]
62
 
        xmax = self.domain_bounds[2]
63
 
        nx = self.grid_size[0]
64
 
        x_array = linspace(xmin, xmax, nx)
65
 
        return x_array
66
 
 
67
 
    @cached_property
68
 
    def _get_y_array(self):
69
 
        ymin = self.domain_bounds[1]
70
 
        ymax = self.domain_bounds[3]
71
 
        ny = self.grid_size[1]
72
 
        y_array = linspace(ymin, ymax, ny)
73
 
        return y_array
74
 
    
75
 
    @cached_property
76
 
    def _get_data(self):
77
 
        # This might be called with func == None during initialization.
78
 
        if self.func is None:
79
 
            return None
80
 
        # Create a scalar field to colormap.
81
 
        xs = self.x_array
82
 
        ys = self.y_array
83
 
        x, y = meshgrid(xs,ys)
84
 
        z = self.func(x,y)[:-1,:-1]
85
 
        return z
86
 
 
87
 
    @cached_property
88
 
    def _get_data_min(self):
89
 
        return self.data.min()
90
 
 
91
 
    @cached_property
92
 
    def _get_data_max(self):
93
 
        return self.data.max()    
94
 
 
95
 
 
96
 
def _create_plot_component(model):
97
 
 
98
 
    # Create a plot data object and give it the model's data array.
99
 
    pd = ArrayPlotData()
100
 
    pd.set_data("imagedata", model.data)
101
 
 
102
 
    # Create the "main" Plot.
103
 
    plot = Plot(pd, padding=50)
104
 
 
105
 
    # Use a TransformColorMapper for the color map.
106
 
    tcm = TransformColorMapper.from_color_map(jet)
107
 
 
108
 
    # Create the image plot renderer in the main plot.
109
 
    renderer = plot.img_plot("imagedata", 
110
 
                    xbounds=model.x_array,
111
 
                    ybounds=model.y_array,
112
 
                    colormap=tcm)[0]
113
 
 
114
 
    # Create the colorbar.
115
 
    lm = LinearMapper(range=renderer.value_range,
116
 
                      domain_limits=(renderer.value_range.low,
117
 
                                     renderer.value_range.high))
118
 
    colorbar = ColorBar(index_mapper=lm,
119
 
                        plot=plot,
120
 
                        orientation='v',
121
 
                        resizable='v',
122
 
                        width=30,
123
 
                        padding=20)
124
 
 
125
 
    colorbar.padding_top = plot.padding_top
126
 
    colorbar.padding_bottom = plot.padding_bottom
127
 
 
128
 
    # Add pan and zoom tools to the colorbar.
129
 
    colorbar.tools.append(PanTool(colorbar,
130
 
                                  constrain_direction="y",
131
 
                                  constrain=True))
132
 
    zoom_overlay = ZoomTool(colorbar, axis="index", tool_mode="range",
133
 
                            always_on=True, drag_button="right")
134
 
    colorbar.overlays.append(zoom_overlay)
135
 
 
136
 
    # Create a container to position the plot and the colorbar side-by-side
137
 
    container = HPlotContainer(use_backbuffer = True)
138
 
    container.add(plot)
139
 
    container.add(colorbar)
140
 
 
141
 
    return container
142
 
 
143
 
 
144
 
class DataGridView(HasTraits):
145
 
 
146
 
    # The DataGrid instance plotted by this view.
147
 
    model = Instance(DataGrid)
148
 
 
149
 
    colormap_scale = Enum('linear [default]', 'log [data_func]',
150
 
                            'power [data_func]', 'power [unit_func]',
151
 
                            'cos [unit_func]', 'sin [unit_func]')
152
 
    
153
 
    power = Float(1.0)
154
 
    
155
 
    colorbar_scale = Enum('linear', 'log')
156
 
 
157
 
    plot = Instance(Component)
158
 
    
159
 
    img_renderer = Property(Instance(CMapImagePlot), depends_on=['plot'])
160
 
    
161
 
    colorbar = Property(Instance(ColorBar), depends_on=['plot'])
162
 
 
163
 
 
164
 
    traits_view = View(
165
 
                    VGroup(
166
 
                        HGroup(
167
 
                            Item('colormap_scale'),
168
 
                            Item('power',
169
 
                                 editor=RangeEditor(low=0.1,
170
 
                                                    high=3.0,
171
 
                                                    format="%4.2f"),
172
 
                                 visible_when='colormap_scale.startswith("power")',
173
 
                                 springy=True),
174
 
                            Item('colorbar_scale'),
175
 
                            springy=True),
176
 
                        Group(
177
 
                            Item('plot',
178
 
                                 editor=ComponentEditor(size=(750,500)), 
179
 
                            show_label=False)),
180
 
                        ),
181
 
                    resizable=True, title="TransformColorMapper Demo",
182
 
                    )
183
 
 
184
 
 
185
 
    def _plot_default(self):
186
 
        return _create_plot_component(self.model)
187
 
 
188
 
    def _get_main_plot(self):
189
 
        return self.plot.components[0]
190
 
 
191
 
    def _get_img_renderer(self):
192
 
        return self.plot.components[0].components[0]
193
 
 
194
 
    def _get_colorbar(self):
195
 
        return self.plot.components[1]
196
 
 
197
 
    def _colormap_scale_changed(self):
198
 
        if self.colormap_scale == 'linear [default]':
199
 
            self.img_renderer.color_mapper.data_func = None
200
 
            self.img_renderer.color_mapper.unit_func = None
201
 
        elif self.colormap_scale == 'log [data_func]':
202
 
            self.img_renderer.color_mapper.data_func = log10
203
 
            self.img_renderer.color_mapper.unit_func = None
204
 
        elif self.colormap_scale == 'power [data_func]':
205
 
            self.img_renderer.color_mapper.data_func = lambda x: x**self.power
206
 
            self.img_renderer.color_mapper.unit_func = None
207
 
        elif self.colormap_scale == 'power [unit_func]':
208
 
            self.img_renderer.color_mapper.data_func = None
209
 
            self.img_renderer.color_mapper.unit_func = lambda x: x**self.power            
210
 
        elif self.colormap_scale == 'cos [unit_func]':
211
 
            self.img_renderer.color_mapper.data_func = None
212
 
            self.img_renderer.color_mapper.unit_func = lambda x: cos(0.5*pi*x) 
213
 
        elif self.colormap_scale == 'sin [unit_func]':
214
 
            self.img_renderer.color_mapper.data_func = None
215
 
            self.img_renderer.color_mapper.unit_func = lambda x: sin(0.5*pi*x)
216
 
        # FIXME: This call to request_redraw() should not be necessary.
217
 
        self.img_renderer.request_redraw()
218
 
 
219
 
    def _power_changed(self):
220
 
        if self.colormap_scale == 'power [data_func]':
221
 
            self.img_renderer.color_mapper.data_func = lambda x: x**self.power
222
 
        elif self.colormap_scale == 'power [unit_func]':
223
 
            self.img_renderer.color_mapper.unit_func = lambda x: x**self.power
224
 
        self.img_renderer.request_redraw()
225
 
 
226
 
    def _colorbar_scale_changed(self):
227
 
        rng = self.colorbar.index_mapper.range
228
 
        dlim = self.colorbar.index_mapper.domain_limits
229
 
        if self.colorbar_scale == 'linear':
230
 
            new_mapper = LinearMapper(range=rng, domain_limits=dlim)
231
 
        else:  # 'log'
232
 
            new_mapper = LogMapper(range=rng, domain_limits=dlim)
233
 
        self.colorbar.index_mapper = new_mapper
234
 
 
235
 
 
236
 
 
237
 
if __name__ == "__main__":
238
 
    grid = DataGrid(
239
 
                func = lambda x,y: 3.0**(x**2 + 2*(cos(2*pi*y)-1)),
240
 
                domain_bounds=(0.0,0.0, 2.0,2.0),
241
 
                grid_size=(200, 200))
242
 
    #print "data bounds: ", grid.data_min, grid.data_max
243
 
    demo = DataGridView(model=grid)
244
 
    demo.configure_traits()