~ubuntu-branches/ubuntu/precise/python-chaco/precise

« back to all changes in this revision

Viewing changes to examples/advanced/data_cube.py

  • Committer: Bazaar Package Importer
  • Author(s): Varun Hiremath
  • Date: 2008-12-29 02:34:05 UTC
  • Revision ID: james.westby@ubuntu.com-20081229023405-x7i4kp9mdxzmdnvu
Tags: upstream-3.0.1
ImportĀ upstreamĀ versionĀ 3.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
"""
 
2
Allows isometric viewing of a 3D data cube.
 
3
 
 
4
Click or click-drag in any data window to set the slice to view.
 
5
"""
 
6
 
 
7
# Outstanding TODOs:
 
8
#  - need to add line inspectors to side and bottom plots, and synchronize
 
9
#    with center plot
 
10
#  - need to set the various image plots to use the same colormap instance,
 
11
#    and that colormap's range needs to be set to min/max of the entire cube
 
12
#  - refactor create_window() so there is less code duplication
 
13
#  - try to eliminate the use of model.xs, ys, zs in favor of bounds tuples
 
14
from numpy import amin, amax, zeros, fromfile, transpose, uint8
 
15
 
 
16
# Standard library imports
 
17
import os, sys, shutil
 
18
 
 
19
# Major library imports
 
20
from numpy import arange, linspace, nanmin, nanmax, newaxis, pi, sin, cos
 
21
 
 
22
# Enthought library imports
 
23
from enthought.chaco.api import ArrayPlotData, Plot, GridPlotContainer, \
 
24
                                 BaseTool, DataRange1D
 
25
from enthought.chaco.default_colormaps import *
 
26
from enthought.chaco.tools.api import LineInspector, SimpleZoom
 
27
from enthought.enable.example_support import DemoFrame, demo_main
 
28
from enthought.enable.api import Window
 
29
from enthought.traits.api import Any, Array, Bool, Callable, CFloat, CInt, \
 
30
        Event, Float, HasTraits, Int, Trait, on_trait_change
 
31
        
 
32
# Will hold the path that the user chooses to download to. Will be an empty
 
33
# string if the user decides to download to the current directory.
 
34
dl_path = ''
 
35
 
 
36
# Determines if the script should ask the user if they would like to remove the
 
37
# downloaded data.  This defaults to False, because data deletion is 
 
38
# irreversible, and in the worst case, the user will have to remove it
 
39
# manually themselves. 
 
40
run_cleanup = False
 
41
 
 
42
class Model(HasTraits):
 
43
    npts_x = CInt(256)
 
44
    npts_y = CInt(256)
 
45
    npts_z = CInt(109)
 
46
 
 
47
    min_x = CFloat(-2*pi)
 
48
    max_x = CFloat(2*pi)
 
49
    min_y = CFloat(-2*pi)
 
50
    max_y = CFloat(2*pi)
 
51
    min_z = CFloat(-pi)
 
52
    max_z = CFloat(pi)
 
53
 
 
54
    xs = Array
 
55
    ys = Array
 
56
    vals = Array
 
57
 
 
58
    minval = Float
 
59
    maxval = Float
 
60
 
 
61
    model_changed = Event
 
62
 
 
63
    def __init__(self, *args, **kwargs):
 
64
        super(Model, self).__init__(*args, **kwargs)
 
65
        self.compute_model()
 
66
 
 
67
    @on_trait_change("npts_+", "min_+", "max_+")
 
68
    def compute_model(self):
 
69
        def vfunc(x, y, z):
 
70
            return sin(x*z) * cos(y)*sin(z) + sin(0.5*z)
 
71
 
 
72
        # Create the axes
 
73
        self.xs = linspace(self.min_x, self.max_x, self.npts_x)
 
74
        self.ys = linspace(self.min_y, self.max_y, self.npts_y)
 
75
        self.zs = linspace(self.min_z, self.max_z, self.npts_z)
 
76
 
 
77
        # Generate a cube of values by using newaxis to span new dimensions
 
78
        self.vals = vfunc(self.xs[:, newaxis, newaxis],
 
79
                          self.ys[newaxis, :, newaxis],
 
80
                          self.zs[newaxis, newaxis, :])
 
81
 
 
82
        self.minval = nanmin(self.vals)
 
83
        self.maxval = nanmax(self.vals)
 
84
        self.model_changed = True
 
85
 
 
86
 
 
87
class BrainModel(Model):
 
88
    def __init__(self, *args, **kwargs):
 
89
        download_data()
 
90
        super(BrainModel, self).__init__(*args, **kwargs)
 
91
 
 
92
    def compute_model(self):
 
93
        global dl_path
 
94
        mrbrain_path = os.path.join(dl_path, 'voldata', 'MRbrain.')
 
95
        nx = 256
 
96
        ny = 256
 
97
        nz = 109
 
98
        full_arr = zeros((nx, ny, nz), dtype='f')
 
99
        for i in range(1, 110):
 
100
            arr = fromfile(mrbrain_path + str(i), dtype='>u2')
 
101
            arr.shape = (256,256)
 
102
            full_arr[:,:,i-1] = arr
 
103
        self.vals = full_arr
 
104
 
 
105
        # Create the axes
 
106
        self.xs = arange(nx)
 
107
        self.ys = arange(ny)
 
108
        self.zs = arange(nz)
 
109
 
 
110
        # Generate a cube of values by using newaxis to span new dimensions
 
111
        self.minval = nanmin(self.vals)
 
112
        self.maxval = nanmax(self.vals)
 
113
        self.model_changed = True
 
114
 
 
115
 
 
116
class ImageIndexTool(BaseTool):
 
117
    """ A tool to set the slice of a cube based on the user's mouse movements
 
118
    or clicks.
 
119
    """
 
120
 
 
121
    # This callback will be called with the index into self.component's
 
122
    # index and value:
 
123
    #     callback(tool, x_index, y_index)
 
124
    # where *tool* is a reference to this tool instance.  The callback
 
125
    # can then use tool.token.
 
126
    callback = Any()
 
127
 
 
128
    # This callback (if it exists) will be called with the integer number
 
129
    # of mousewheel clicks
 
130
    wheel_cb = Any()
 
131
 
 
132
    # This token can be used by the callback to decide how to process
 
133
    # the event.
 
134
    token  = Any()
 
135
 
 
136
    # Whether or not to update the slice info; we enter select mode when
 
137
    # the left mouse button is pressed and exit it when the mouse button
 
138
    # is released
 
139
    # FIXME: This is not used right now.
 
140
    select_mode = Bool(False)
 
141
 
 
142
    def normal_left_down(self, event):
 
143
        self._update_slices(event)
 
144
 
 
145
    def normal_right_down(self, event):
 
146
        self._update_slices(event)
 
147
 
 
148
    def normal_mouse_move(self, event):
 
149
        if event.left_down or event.right_down:
 
150
            self._update_slices(event)
 
151
 
 
152
    def _update_slices(self, event):
 
153
            plot = self.component
 
154
            ndx = plot.map_index((event.x, event.y), 
 
155
                                 threshold=5.0, index_only=True)
 
156
            if ndx:
 
157
                self.callback(self, *ndx)
 
158
 
 
159
    def normal_mouse_wheel(self, event):
 
160
        if self.wheel_cb is not None:
 
161
            self.wheel_cb(self, event.mouse_wheel)
 
162
 
 
163
 
 
164
class PlotFrame(DemoFrame):
 
165
 
 
166
    # These are the indices into the cube that each of the image plot views
 
167
    # will show; the default values are non-zero just to make it a little
 
168
    # interesting.
 
169
    slice_x = 10
 
170
    slice_y = 10
 
171
    slice_z = 10
 
172
 
 
173
    num_levels = Int(15)
 
174
    colormap = Any
 
175
    colorcube = Any
 
176
 
 
177
    #---------------------------------------------------------------------------
 
178
    # Private Traits
 
179
    #---------------------------------------------------------------------------
 
180
        
 
181
    _cmap = Trait(jet, Callable)
 
182
 
 
183
    def _index_callback(self, tool, x_index, y_index):
 
184
        plane = tool.token
 
185
        if plane == "xy":
 
186
            self.slice_x = x_index
 
187
            self.slice_y = y_index
 
188
        elif plane == "yz":
 
189
            # transposed because the plot is oriented vertically
 
190
            self.slice_z = x_index
 
191
            self.slice_y = y_index
 
192
        elif plane == "xz":
 
193
            self.slice_x = x_index
 
194
            self.slice_z = y_index
 
195
        else:
 
196
            warnings.warn("Unrecognized plane for _index_callback: %s" % plane)
 
197
        self._update_images()
 
198
        self.center.invalidate_and_redraw()
 
199
        self.right.invalidate_and_redraw()
 
200
        self.bottom.invalidate_and_redraw()
 
201
        return
 
202
 
 
203
    def _wheel_callback(self, tool, wheelamt):
 
204
        plane_slice_dict = {"xy": ("slice_z", 2), 
 
205
                            "yz": ("slice_x", 0),
 
206
                            "xz": ("slice_y", 1)}
 
207
        attr, shape_ndx = plane_slice_dict[tool.token]
 
208
        val = getattr(self, attr)
 
209
        max = self.model.vals.shape[shape_ndx]
 
210
        if val + wheelamt > max:
 
211
            setattr(self, attr, max-1)
 
212
        elif val + wheelamt < 0:
 
213
            setattr(self, attr, 0)
 
214
        else:
 
215
            setattr(self, attr, val + wheelamt)
 
216
 
 
217
        self._update_images()
 
218
        self.center.invalidate_and_redraw()
 
219
        self.right.invalidate_and_redraw()
 
220
        self.bottom.invalidate_and_redraw()
 
221
        return
 
222
    
 
223
    def _create_window(self):
 
224
        # Create the model
 
225
        try:
 
226
            self.model = model = BrainModel()
 
227
            cmap = bone
 
228
        except SystemExit:
 
229
            sys.exit()
 
230
        except:
 
231
            print "Unable to load BrainModel, using generated data cube."
 
232
            self.model = model = Model()
 
233
            cmap = jet
 
234
        self._update_model(cmap)
 
235
 
 
236
        datacube = self.colorcube
 
237
 
 
238
        # Create the plot
 
239
        self.plotdata = ArrayPlotData()
 
240
        self._update_images()
 
241
 
 
242
        # Center Plot
 
243
        centerplot = Plot(self.plotdata, padding=0)
 
244
        imgplot = centerplot.img_plot("xy", xbounds=model.xs, ybounds=model.ys, 
 
245
                            colormap=cmap)[0]
 
246
        self._add_plot_tools(imgplot, "xy")
 
247
        self.center = imgplot
 
248
 
 
249
        # Right Plot
 
250
        rightplot = Plot(self.plotdata, width=150, resizable="v", padding=0)
 
251
        rightplot.value_range = centerplot.value_range
 
252
        imgplot = rightplot.img_plot("yz", xbounds=model.zs, ybounds=model.ys,
 
253
                                     colormap=cmap)[0]
 
254
        self._add_plot_tools(imgplot, "yz")
 
255
        self.right = imgplot
 
256
 
 
257
        # Bottom Plot
 
258
        bottomplot = Plot(self.plotdata, height=150, resizable="h", padding=0)
 
259
        bottomplot.index_range = centerplot.index_range
 
260
        imgplot = bottomplot.img_plot("xz", xbounds=model.xs, ybounds=model.zs,
 
261
                                      colormap=cmap)[0]
 
262
        self._add_plot_tools(imgplot, "xz")
 
263
        self.bottom = imgplot
 
264
 
 
265
        # Create Container and add all Plots
 
266
        container = GridPlotContainer(padding=20, fill_padding=True,
 
267
                                      bgcolor="white", use_backbuffer=True,
 
268
                                      shape=(2,2), spacing=(12,12))
 
269
        container.add(centerplot)
 
270
        container.add(rightplot)
 
271
        container.add(bottomplot)
 
272
 
 
273
        self.container = container
 
274
        return Window(self, -1, component=container)
 
275
    
 
276
    def _add_plot_tools(self, imgplot, token):
 
277
        """ Add LineInspectors, ImageIndexTool, and SimpleZoom to the image plots. """
 
278
        
 
279
        imgplot.overlays.append(SimpleZoom(component=imgplot, tool_mode="box",
 
280
                                           enable_wheel=False, always_on=False))
 
281
        imgplot.overlays.append(LineInspector(imgplot, axis="index_y", color="white",
 
282
            inspect_mode="indexed", write_metadata=True, is_listener=True))
 
283
        imgplot.overlays.append(LineInspector(imgplot, axis="index_x", color="white",
 
284
            inspect_mode="indexed", write_metadata=True, is_listener=True))
 
285
        imgplot.tools.append(ImageIndexTool(imgplot, token=token, 
 
286
            callback=self._index_callback, wheel_cb=self._wheel_callback))
 
287
 
 
288
    def _update_model(self, cmap):
 
289
        range = DataRange1D(low=amin(self.model.vals), 
 
290
                            high=amax(self.model.vals))
 
291
        self.colormap = cmap(range)
 
292
        self.colorcube = (self.colormap.map_screen(self.model.vals) * 255).astype(uint8)
 
293
        
 
294
    def _update_images(self):
 
295
        """ Updates the image data in self.plotdata to correspond to the 
 
296
        slices given.
 
297
        """
 
298
        cube = self.colorcube
 
299
        pd = self.plotdata
 
300
        # These are transposed because img_plot() expects its data to be in 
 
301
        # row-major order
 
302
        pd.set_data("xy", transpose(cube[:, :, self.slice_z], (1,0,2)))
 
303
        pd.set_data("xz", transpose(cube[:, self.slice_y, :], (1,0,2)))
 
304
        pd.set_data("yz", cube[self.slice_x, :, :])
 
305
 
 
306
def download_data():
 
307
    global dl_path, run_cleanup
 
308
    
 
309
    print 'Please enter the location of the "voldata" subdirectory containing'
 
310
    print 'the data files for this demo, or enter a path to download to (7.8MB).'
 
311
    print 'Press <ENTER> to download to the current directory.'
 
312
    dl_path = raw_input('Path: ').strip().rstrip("/").rstrip("\\")
 
313
    
 
314
    if not dl_path.endswith("voldata"):
 
315
        voldata_path = os.path.join(dl_path, 'voldata')
 
316
    else:
 
317
        voldata_path = dl_path
 
318
    tar_path = os.path.join(dl_path, 'MRbrain.tar.gz')
 
319
    
 
320
    data_good = True
 
321
    try:
 
322
        for i in range(1,110):
 
323
            if not os.path.isfile(os.path.join(voldata_path, "MRbrain.%d" % i)):
 
324
                data_good = False
 
325
                break
 
326
        else:
 
327
            data_good = True
 
328
    except:
 
329
        data_good = False
 
330
    
 
331
    if not data_good:
 
332
        import urllib
 
333
        import tarfile
 
334
 
 
335
        if len(dl_path) > 0 and not os.path.exists(dl_path):
 
336
            print 'The given path does not exist.'
 
337
            run_cleanup = False
 
338
            sys.exit()
 
339
 
 
340
        if not os.path.isabs(dl_path):
 
341
            print 'Downloading to: ' + os.path.join(os.getcwd(), dl_path)
 
342
        else:
 
343
            print 'Downloading to: ' + dl_path
 
344
        
 
345
        try:
 
346
            # download and extract the file
 
347
            print "Downloading data, Please Wait (7.8MB)"
 
348
            opener = urllib.urlopen('http://www-graphics.stanford.edu/data/voldata/MRbrain.tar.gz')
 
349
        except:
 
350
            print 'Download error. Opening backup data.'
 
351
            run_cleanup = False
 
352
            raise
 
353
        
 
354
        try:
 
355
            open(tar_path, 'wb').write(opener.read())
 
356
        except:
 
357
            print 'Cannot write to the destination directory specified. ' \
 
358
                  'Opening backup data.'
 
359
            run_cleanup = False
 
360
            raise
 
361
        
 
362
        tar_file = tarfile.open(tar_path)
 
363
        try:
 
364
            os.mkdir(voldata_path)
 
365
        except:
 
366
            pass
 
367
        tar_file.extractall(voldata_path)
 
368
        tar_file.close()
 
369
        os.unlink(tar_path)
 
370
    else:
 
371
        print 'Previously downloaded data detected.'
 
372
        
 
373
def cleanup_data():
 
374
    global dl_path
 
375
    
 
376
    answer = raw_input('Remove downloaded files? [Y/N]: ')
 
377
    if answer.lower() == 'y':
 
378
        try:
 
379
            shutil.rmtree(os.path.join(dl_path, 'voldata'))
 
380
        except:
 
381
            pass
 
382
        
 
383
if __name__ == "__main__":
 
384
    demo_main(PlotFrame, size=(800,700), title="Cube analyzer")
 
385
    if run_cleanup:
 
386
        cleanup_data()
 
387