~ubuntu-branches/ubuntu/utopic/python-chaco/utopic

« back to all changes in this revision

Viewing changes to examples/demo/advanced/data_cube.py

  • Committer: Package Import Robot
  • Author(s): Andrew Starr-Bochicchio
  • Date: 2014-06-01 17:04:08 UTC
  • mfrom: (7.2.5 sid)
  • Revision ID: package-import@ubuntu.com-20140601170408-m86xvdjd83a4qon0
Tags: 4.4.1-1ubuntu1
* Merge from Debian unstable. Remaining Ubuntu changes:
 - Let the binary-predeb target work on the usr/lib/python* directory
   as we don't have usr/share/pyshared anymore.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
"""
2
 
Allows isometric viewing of a 3D data cube.
3
 
 
4
 
Click or click-drag in any data window to set the slice to view.
5
 
"""
6
 
 
7
 
# Outstanding TODOs:
8
 
#  - need to add line inspectors to side and bottom plots, and synchronize
9
 
#    with center plot
10
 
#  - need to set the various image plots to use the same colormap instance,
11
 
#    and that colormap's range needs to be set to min/max of the entire cube
12
 
#  - refactor create_window() so there is less code duplication
13
 
#  - try to eliminate the use of model.xs, ys, zs in favor of bounds tuples
14
 
from numpy import amin, amax, zeros, fromfile, transpose, uint8
15
 
 
16
 
# Standard library imports
17
 
import os, sys, shutil
18
 
 
19
 
# Major library imports
20
 
from numpy import arange, linspace, nanmin, nanmax, newaxis, pi, sin, cos
21
 
 
22
 
# Enthought library imports
23
 
from chaco.api import ArrayPlotData, Plot, GridPlotContainer, \
24
 
                                 BaseTool, DataRange1D
25
 
from chaco.default_colormaps import *
26
 
from chaco.tools.api import LineInspector, ZoomTool
27
 
from enable.example_support import DemoFrame, demo_main
28
 
from enable.api import Window
29
 
from traits.api import Any, Array, Bool, Callable, CFloat, CInt, \
30
 
        Event, Float, HasTraits, Int, Trait, on_trait_change
31
 
 
32
 
# Will hold the path that the user chooses to download to. Will be an empty
33
 
# string if the user decides to download to the current directory.
34
 
dl_path = ''
35
 
 
36
 
# Determines if the script should ask the user if they would like to remove the
37
 
# downloaded data.  This defaults to False, because data deletion is
38
 
# irreversible, and in the worst case, the user will have to remove it
39
 
# manually themselves.
40
 
run_cleanup = False
41
 
 
42
 
class Model(HasTraits):
43
 
    npts_x = CInt(256)
44
 
    npts_y = CInt(256)
45
 
    npts_z = CInt(109)
46
 
 
47
 
    min_x = CFloat(-2*pi)
48
 
    max_x = CFloat(2*pi)
49
 
    min_y = CFloat(-2*pi)
50
 
    max_y = CFloat(2*pi)
51
 
    min_z = CFloat(-pi)
52
 
    max_z = CFloat(pi)
53
 
 
54
 
    xs = Array
55
 
    ys = Array
56
 
    vals = Array
57
 
 
58
 
    minval = Float
59
 
    maxval = Float
60
 
 
61
 
    model_changed = Event
62
 
 
63
 
    def __init__(self, *args, **kwargs):
64
 
        super(Model, self).__init__(*args, **kwargs)
65
 
        self.compute_model()
66
 
 
67
 
    @on_trait_change("npts_+", "min_+", "max_+")
68
 
    def compute_model(self):
69
 
        def vfunc(x, y, z):
70
 
            return sin(x*z) * cos(y)*sin(z) + sin(0.5*z)
71
 
 
72
 
        # Create the axes
73
 
        self.xs = linspace(self.min_x, self.max_x, self.npts_x)
74
 
        self.ys = linspace(self.min_y, self.max_y, self.npts_y)
75
 
        self.zs = linspace(self.min_z, self.max_z, self.npts_z)
76
 
 
77
 
        # Generate a cube of values by using newaxis to span new dimensions
78
 
        self.vals = vfunc(self.xs[:, newaxis, newaxis],
79
 
                          self.ys[newaxis, :, newaxis],
80
 
                          self.zs[newaxis, newaxis, :])
81
 
 
82
 
        self.minval = nanmin(self.vals)
83
 
        self.maxval = nanmax(self.vals)
84
 
        self.model_changed = True
85
 
 
86
 
 
87
 
class BrainModel(Model):
88
 
    def __init__(self, *args, **kwargs):
89
 
        download_data()
90
 
        super(BrainModel, self).__init__(*args, **kwargs)
91
 
 
92
 
    def compute_model(self):
93
 
        global dl_path
94
 
        mrbrain_path = os.path.join(dl_path, 'voldata', 'MRbrain.')
95
 
        nx = 256
96
 
        ny = 256
97
 
        nz = 109
98
 
        full_arr = zeros((nx, ny, nz), dtype='f')
99
 
        for i in range(1, 110):
100
 
            arr = fromfile(mrbrain_path + str(i), dtype='>u2')
101
 
            arr.shape = (256,256)
102
 
            full_arr[:,:,i-1] = arr
103
 
        self.vals = full_arr
104
 
 
105
 
        # Create the axes
106
 
        self.xs = arange(nx)
107
 
        self.ys = arange(ny)
108
 
        self.zs = arange(nz)
109
 
 
110
 
        # Generate a cube of values by using newaxis to span new dimensions
111
 
        self.minval = nanmin(self.vals)
112
 
        self.maxval = nanmax(self.vals)
113
 
        self.model_changed = True
114
 
 
115
 
 
116
 
class ImageIndexTool(BaseTool):
117
 
    """ A tool to set the slice of a cube based on the user's mouse movements
118
 
    or clicks.
119
 
    """
120
 
 
121
 
    # This callback will be called with the index into self.component's
122
 
    # index and value:
123
 
    #     callback(tool, x_index, y_index)
124
 
    # where *tool* is a reference to this tool instance.  The callback
125
 
    # can then use tool.token.
126
 
    callback = Any()
127
 
 
128
 
    # This callback (if it exists) will be called with the integer number
129
 
    # of mousewheel clicks
130
 
    wheel_cb = Any()
131
 
 
132
 
    # This token can be used by the callback to decide how to process
133
 
    # the event.
134
 
    token  = Any()
135
 
 
136
 
    # Whether or not to update the slice info; we enter select mode when
137
 
    # the left mouse button is pressed and exit it when the mouse button
138
 
    # is released
139
 
    # FIXME: This is not used right now.
140
 
    select_mode = Bool(False)
141
 
 
142
 
    def normal_left_down(self, event):
143
 
        self._update_slices(event)
144
 
 
145
 
    def normal_right_down(self, event):
146
 
        self._update_slices(event)
147
 
 
148
 
    def normal_mouse_move(self, event):
149
 
        if event.left_down or event.right_down:
150
 
            self._update_slices(event)
151
 
 
152
 
    def _update_slices(self, event):
153
 
            plot = self.component
154
 
            ndx = plot.map_index((event.x, event.y),
155
 
                                 threshold=5.0, index_only=True)
156
 
            if ndx:
157
 
                self.callback(self, *ndx)
158
 
 
159
 
    def normal_mouse_wheel(self, event):
160
 
        if self.wheel_cb is not None:
161
 
            self.wheel_cb(self, event.mouse_wheel)
162
 
 
163
 
 
164
 
class PlotFrame(DemoFrame):
165
 
 
166
 
    # These are the indices into the cube that each of the image plot views
167
 
    # will show; the default values are non-zero just to make it a little
168
 
    # interesting.
169
 
    slice_x = 10
170
 
    slice_y = 10
171
 
    slice_z = 10
172
 
 
173
 
    num_levels = Int(15)
174
 
    colormap = Any
175
 
    colorcube = Any
176
 
 
177
 
    #---------------------------------------------------------------------------
178
 
    # Private Traits
179
 
    #---------------------------------------------------------------------------
180
 
 
181
 
    _cmap = Trait(jet, Callable)
182
 
 
183
 
    def _index_callback(self, tool, x_index, y_index):
184
 
        plane = tool.token
185
 
        if plane == "xy":
186
 
            self.slice_x = x_index
187
 
            self.slice_y = y_index
188
 
        elif plane == "yz":
189
 
            # transposed because the plot is oriented vertically
190
 
            self.slice_z = x_index
191
 
            self.slice_y = y_index
192
 
        elif plane == "xz":
193
 
            self.slice_x = x_index
194
 
            self.slice_z = y_index
195
 
        else:
196
 
            warnings.warn("Unrecognized plane for _index_callback: %s" % plane)
197
 
        self._update_images()
198
 
        self.center.invalidate_and_redraw()
199
 
        self.right.invalidate_and_redraw()
200
 
        self.bottom.invalidate_and_redraw()
201
 
        return
202
 
 
203
 
    def _wheel_callback(self, tool, wheelamt):
204
 
        plane_slice_dict = {"xy": ("slice_z", 2),
205
 
                            "yz": ("slice_x", 0),
206
 
                            "xz": ("slice_y", 1)}
207
 
        attr, shape_ndx = plane_slice_dict[tool.token]
208
 
        val = getattr(self, attr)
209
 
        max = self.model.vals.shape[shape_ndx]
210
 
        if val + wheelamt > max:
211
 
            setattr(self, attr, max-1)
212
 
        elif val + wheelamt < 0:
213
 
            setattr(self, attr, 0)
214
 
        else:
215
 
            setattr(self, attr, val + wheelamt)
216
 
 
217
 
        self._update_images()
218
 
        self.center.invalidate_and_redraw()
219
 
        self.right.invalidate_and_redraw()
220
 
        self.bottom.invalidate_and_redraw()
221
 
        return
222
 
 
223
 
    def _create_window(self):
224
 
        # Create the model
225
 
        #try:
226
 
        #    self.model = model = BrainModel()
227
 
        #    cmap = bone
228
 
        #except SystemExit:
229
 
        #    sys.exit()
230
 
        #except:
231
 
        #    print "Unable to load BrainModel, using generated data cube."
232
 
        self.model = model = Model()
233
 
        cmap = jet
234
 
        self._update_model(cmap)
235
 
 
236
 
        datacube = self.colorcube
237
 
 
238
 
        # Create the plot
239
 
        self.plotdata = ArrayPlotData()
240
 
        self._update_images()
241
 
 
242
 
        # Center Plot
243
 
        centerplot = Plot(self.plotdata, padding=0)
244
 
        imgplot = centerplot.img_plot("xy",
245
 
                                xbounds=(model.xs[0], model.xs[-1]),
246
 
                                ybounds=(model.ys[0], model.ys[-1]),
247
 
                                colormap=cmap)[0]
248
 
        self._add_plot_tools(imgplot, "xy")
249
 
        self.center = imgplot
250
 
 
251
 
        # Right Plot
252
 
        rightplot = Plot(self.plotdata, width=150, resizable="v", padding=0)
253
 
        rightplot.value_range = centerplot.value_range
254
 
        imgplot = rightplot.img_plot("yz",
255
 
                                xbounds=(model.zs[0], model.zs[-1]),
256
 
                                ybounds=(model.ys[0], model.ys[-1]),
257
 
                                colormap=cmap)[0]
258
 
        self._add_plot_tools(imgplot, "yz")
259
 
        self.right = imgplot
260
 
 
261
 
        # Bottom Plot
262
 
        bottomplot = Plot(self.plotdata, height=150, resizable="h", padding=0)
263
 
        bottomplot.index_range = centerplot.index_range
264
 
        imgplot = bottomplot.img_plot("xz",
265
 
                                xbounds=(model.xs[0], model.xs[-1]),
266
 
                                ybounds=(model.zs[0], model.zs[-1]),
267
 
                                colormap=cmap)[0]
268
 
        self._add_plot_tools(imgplot, "xz")
269
 
        self.bottom = imgplot
270
 
 
271
 
        # Create Container and add all Plots
272
 
        container = GridPlotContainer(padding=20, fill_padding=True,
273
 
                                      bgcolor="white", use_backbuffer=True,
274
 
                                      shape=(2,2), spacing=(12,12))
275
 
        container.add(centerplot)
276
 
        container.add(rightplot)
277
 
        container.add(bottomplot)
278
 
 
279
 
        self.container = container
280
 
        return Window(self, -1, component=container)
281
 
 
282
 
    def _add_plot_tools(self, imgplot, token):
283
 
        """ Add LineInspectors, ImageIndexTool, and ZoomTool to the image plots. """
284
 
 
285
 
        imgplot.overlays.append(ZoomTool(component=imgplot, tool_mode="box",
286
 
                                           enable_wheel=False, always_on=False))
287
 
        imgplot.overlays.append(LineInspector(imgplot, axis="index_y", color="white",
288
 
            inspect_mode="indexed", write_metadata=True, is_listener=True))
289
 
        imgplot.overlays.append(LineInspector(imgplot, axis="index_x", color="white",
290
 
            inspect_mode="indexed", write_metadata=True, is_listener=True))
291
 
        imgplot.tools.append(ImageIndexTool(imgplot, token=token,
292
 
            callback=self._index_callback, wheel_cb=self._wheel_callback))
293
 
 
294
 
    def _update_model(self, cmap):
295
 
        range = DataRange1D(low=amin(self.model.vals),
296
 
                            high=amax(self.model.vals))
297
 
        self.colormap = cmap(range)
298
 
        self.colorcube = (self.colormap.map_screen(self.model.vals) * 255).astype(uint8)
299
 
 
300
 
    def _update_images(self):
301
 
        """ Updates the image data in self.plotdata to correspond to the
302
 
        slices given.
303
 
        """
304
 
        cube = self.colorcube
305
 
        pd = self.plotdata
306
 
        # These are transposed because img_plot() expects its data to be in
307
 
        # row-major order
308
 
        pd.set_data("xy", transpose(cube[:, :, self.slice_z], (1,0,2)))
309
 
        pd.set_data("xz", transpose(cube[:, self.slice_y, :], (1,0,2)))
310
 
        pd.set_data("yz", cube[self.slice_x, :, :])
311
 
 
312
 
def download_data():
313
 
    global dl_path, run_cleanup
314
 
 
315
 
    print 'Please enter the location of the "voldata" subdirectory containing'
316
 
    print 'the data files for this demo, or enter a path to download to (7.8MB).'
317
 
    print 'Press <ENTER> to download to the current directory.'
318
 
    dl_path = raw_input('Path: ').strip().rstrip("/").rstrip("\\")
319
 
 
320
 
    if not dl_path.endswith("voldata"):
321
 
        voldata_path = os.path.join(dl_path, 'voldata')
322
 
    else:
323
 
        voldata_path = dl_path
324
 
    tar_path = os.path.join(dl_path, 'MRbrain.tar.gz')
325
 
 
326
 
    data_good = True
327
 
    try:
328
 
        for i in range(1,110):
329
 
            if not os.path.isfile(os.path.join(voldata_path, "MRbrain.%d" % i)):
330
 
                data_good = False
331
 
                break
332
 
        else:
333
 
            data_good = True
334
 
    except:
335
 
        data_good = False
336
 
 
337
 
    if not data_good:
338
 
        import urllib
339
 
        import tarfile
340
 
 
341
 
        if len(dl_path) > 0 and not os.path.exists(dl_path):
342
 
            print 'The given path does not exist.'
343
 
            run_cleanup = False
344
 
            sys.exit()
345
 
 
346
 
        if not os.path.isabs(dl_path):
347
 
            print 'Downloading to: ' + os.path.join(os.getcwd(), dl_path)
348
 
        else:
349
 
            print 'Downloading to: ' + dl_path
350
 
 
351
 
        try:
352
 
            # download and extract the file
353
 
            print "Downloading data, Please Wait (7.8MB)"
354
 
            opener = urllib.urlopen('http://www-graphics.stanford.edu/data/voldata/MRbrain.tar.gz')
355
 
        except:
356
 
            print 'Download error. Opening backup data.'
357
 
            run_cleanup = False
358
 
            raise
359
 
 
360
 
        try:
361
 
            open(tar_path, 'wb').write(opener.read())
362
 
        except:
363
 
            print 'Cannot write to the destination directory specified. ' \
364
 
                  'Opening backup data.'
365
 
            run_cleanup = False
366
 
            raise
367
 
 
368
 
        tar_file = tarfile.open(tar_path)
369
 
        try:
370
 
            os.mkdir(voldata_path)
371
 
        except:
372
 
            pass
373
 
        tar_file.extractall(voldata_path)
374
 
        tar_file.close()
375
 
        os.unlink(tar_path)
376
 
    else:
377
 
        print 'Previously downloaded data detected.'
378
 
 
379
 
def cleanup_data():
380
 
    global dl_path
381
 
 
382
 
    answer = raw_input('Remove downloaded files? [Y/N]: ')
383
 
    if answer.lower() == 'y':
384
 
        try:
385
 
            shutil.rmtree(os.path.join(dl_path, 'voldata'))
386
 
        except:
387
 
            pass
388
 
 
389
 
if __name__ == "__main__":
390
 
    demo_main(PlotFrame, size=(800,700), title="Cube analyzer")
391
 
    if run_cleanup:
392
 
        cleanup_data()
393