~ubuntu-branches/ubuntu/precise/python-chaco/precise

« back to all changes in this revision

Viewing changes to enthought/chaco/barplot.py

  • Committer: Bazaar Package Importer
  • Author(s): Varun Hiremath
  • Date: 2008-12-29 02:34:05 UTC
  • Revision ID: james.westby@ubuntu.com-20081229023405-x7i4kp9mdxzmdnvu
Tags: upstream-3.0.1
ImportĀ upstreamĀ versionĀ 3.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
""" Defines the BarPlot class.
 
2
"""
 
3
import logging
 
4
 
 
5
from numpy import array, compress, column_stack, invert, isnan, transpose, zeros
 
6
from enthought.traits.api import Any, Bool, Enum, Float, Instance, Property
 
7
from enthought.enable.api import black_color_trait
 
8
from enthought.kiva import FILL_STROKE
 
9
 
 
10
# Local relative imports
 
11
from enthought.chaco.abstract_plot_renderer import AbstractPlotRenderer
 
12
from abstract_mapper import AbstractMapper
 
13
from array_data_source import ArrayDataSource
 
14
from base import reverse_map_1d
 
15
 
 
16
 
 
17
logger = logging.getLogger(__name__)
 
18
 
 
19
 
 
20
class BarPlot(AbstractPlotRenderer):
 
21
    """
 
22
    A renderer for bar charts.
 
23
    """
 
24
    # The data source to use for the index coordinate.
 
25
    index = Instance(ArrayDataSource)
 
26
 
 
27
    # The data source to use as value points.
 
28
    value = Instance(ArrayDataSource)
 
29
 
 
30
    # The data source to use as "starting" values for the bars.
 
31
    starting_value = Instance(ArrayDataSource)
 
32
 
 
33
    # Labels for the indices.
 
34
    index_mapper = Instance(AbstractMapper)
 
35
    # Labels for the values.
 
36
    value_mapper = Instance(AbstractMapper)
 
37
 
 
38
    # The orientation of the index axis.
 
39
    orientation = Enum("h", "v")
 
40
 
 
41
    # The direction of the index axis with respect to the graphics context's 
 
42
    # direction.
 
43
    index_direction = Enum("normal", "flipped")
 
44
 
 
45
    # The direction of the value axis with respect to the graphics context's 
 
46
    # direction.
 
47
    value_direction = Enum("normal", "flipped")
 
48
 
 
49
    # Type of width used for bars:
 
50
    #
 
51
    # 'data' 
 
52
    #     The width is in the units along the x-dimension of the data space.  
 
53
    # 'screen' 
 
54
    #     The width uses a fixed width of pixels.
 
55
    bar_width_type = Enum("data", "screen")
 
56
 
 
57
    # Width of the bars, in data or screen space (determined by 
 
58
    # **bar_width_type**).
 
59
    bar_width = Float(10)
 
60
 
 
61
    # Round on rectangle dimensions? This is not strictly an "antialias", but
 
62
    # it has the same effect through exact pixel drawing.
 
63
    antialias = Bool(True)
 
64
 
 
65
    # Width of the border of the bars.
 
66
    line_width = Float(1.0)
 
67
    # Color of the border of the bars.
 
68
    line_color = black_color_trait
 
69
    # Color to fill the bars.
 
70
    fill_color = black_color_trait
 
71
 
 
72
    #use_draw_order = False
 
73
 
 
74
    # Convenience properties that correspond to either index_mapper or
 
75
    # value_mapper, depending on the orientation of the plot.
 
76
 
 
77
    # Corresponds to either **index_mapper** or **value_mapper**, depending on
 
78
    # the orientation of the plot.
 
79
    x_mapper = Property
 
80
    # Corresponds to either **value_mapper** or **index_mapper**, depending on
 
81
    # the orientation of the plot.
 
82
    y_mapper = Property
 
83
 
 
84
    # Corresponds to either **index_direction** or **value_direction**, 
 
85
    # depending on the orientation of the plot.
 
86
    x_direction = Property
 
87
    # Corresponds to either **value_direction** or **index_direction**, 
 
88
    # depending on the orientation of the plot
 
89
    y_direction = Property
 
90
 
 
91
    # Convenience property for accessing the index data range.
 
92
    index_range = Property
 
93
    # Convenience property for accessing the value data range.
 
94
    value_range = Property
 
95
 
 
96
 
 
97
    #------------------------------------------------------------------------
 
98
    # Private traits
 
99
    #------------------------------------------------------------------------
 
100
 
 
101
    # Indicates whether or not the data cache is valid
 
102
    _cache_valid = Bool(False)
 
103
 
 
104
    # Cached data values from the datasources.  If **bar_width_type** is "data",
 
105
    # then this is an Nx4 array of (bar_left, bar_right, start, end) for a
 
106
    # bar plot in normal orientation.  If **bar_width_type** is "screen", then
 
107
    # this is an Nx3 array of (bar_center, start, end).
 
108
    _cached_data_pts = Any
 
109
 
 
110
 
 
111
    #------------------------------------------------------------------------
 
112
    # AbstractPlotRenderer interface
 
113
    #------------------------------------------------------------------------
 
114
 
 
115
    def map_screen(self, data_array):
 
116
        """ Maps an array of data points into screen space and returns it as
 
117
        an array. 
 
118
        
 
119
        Implements the AbstractPlotRenderer interface.
 
120
        """
 
121
        # data_array is Nx2 array
 
122
        if len(data_array) == 0:
 
123
            return []
 
124
        x_ary, y_ary = transpose(data_array)
 
125
        sx = self.index_mapper.map_screen(x_ary)
 
126
        sy = self.value_mapper.map_screen(y_ary)
 
127
 
 
128
        # reverse the directions as indicated
 
129
        if self.index_direction == "flipped":
 
130
            x_sign = -1.0
 
131
            x_offset = self.bounds[0]
 
132
        else:
 
133
            x_sign = 1.0
 
134
            x_offset = 0.0
 
135
 
 
136
        if self.value_direction == "flipped":
 
137
            y_sign = -1.0
 
138
            y_offset = self.bounds[1]
 
139
        else:
 
140
            y_sign = 1.0
 
141
            y_offset = 0.0
 
142
 
 
143
        # now return map based on orientation
 
144
        if self.orientation == "h":
 
145
            return transpose(array((x_sign*sx+x_offset,y_sign*sy+y_offset)))
 
146
        else:
 
147
            return transpose(array((y_sign*sy+y_offset,x_sign*sx+x_offset)))
 
148
 
 
149
    def map_data(self, screen_pt):
 
150
        """ Maps a screen space point into the "index" space of the plot.
 
151
        
 
152
        Implements the AbstractPlotRenderer interface.
 
153
        """
 
154
        if self.orientation == "h":
 
155
            screen_coord = screen_pt[0]
 
156
        else:
 
157
            screen_coord = screen_pt[1]
 
158
        return self.index_mapper.map_data(screen_coord)
 
159
 
 
160
    def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True, \
 
161
                  index_only=False):
 
162
        """ Maps a screen space point to an index into the plot's index array(s).
 
163
        
 
164
        Implements the AbstractPlotRenderer interface.
 
165
        """
 
166
        data_pt = self.map_data(screen_pt)
 
167
        if ((data_pt < self.index_mapper.range.low) or \
 
168
            (data_pt > self.index_mapper.range.high)) and outside_returns_none:
 
169
            return None
 
170
        half = threshold / 2.0
 
171
        index_data = self.index.get_data()
 
172
        value_data = self.value.get_data()
 
173
 
 
174
        if len(value_data) == 0 or len(index_data) == 0:
 
175
            return None
 
176
 
 
177
        try:
 
178
            ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
 
179
        except IndexError:
 
180
            return None
 
181
 
 
182
        x = index_data[ndx]
 
183
        y = value_data[ndx]
 
184
        
 
185
        result = self.map_screen(array([[x,y]]))
 
186
        if result is None:
 
187
            return None
 
188
 
 
189
        sx, sy = result[0]
 
190
        if index_only and ((screen_pt[0]-sx) < threshold):
 
191
            return ndx
 
192
        elif ((screen_pt[0]-sx)**2 + (screen_pt[1]-sy)**2 < threshold*threshold):
 
193
            return ndx
 
194
        else:
 
195
            return None
 
196
 
 
197
    #------------------------------------------------------------------------
 
198
    # PlotComponent interface
 
199
    #------------------------------------------------------------------------
 
200
 
 
201
    def _gather_points(self):
 
202
        """ Collects data points that are within the range of the plot, and
 
203
        caches them in **_cached_data_pts**.
 
204
        """
 
205
        index, index_mask = self.index.get_data_mask()
 
206
        value, value_mask = self.value.get_data_mask()
 
207
 
 
208
        if not self.index or not self.value:
 
209
            return
 
210
 
 
211
        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
 
212
            logger.warn("Chaco: using empty dataset; index_len=%d, value_len=%d." \
 
213
                                % (len(index), len(value)))
 
214
            self._cached_data_pts = array([])
 
215
            self._cache_valid = True
 
216
            return
 
217
 
 
218
        # TODO: Until we code up a better handling of value-based culling that
 
219
        # takes into account starting_value and dataspace bar widths, just use
 
220
        # the index culling for now.
 
221
#        value_range_mask = self.value_mapper.range.mask_data(value)
 
222
#        nan_mask = invert(isnan(index_mask)) & invert(isnan(value_mask))
 
223
#        point_mask = index_mask & value_mask & nan_mask & \
 
224
#                     index_range_mask & value_range_mask
 
225
 
 
226
        index_range_mask = self.index_mapper.range.mask_data(index)
 
227
        nan_mask = invert(isnan(index_mask))
 
228
        point_mask = index_mask & nan_mask & index_range_mask
 
229
 
 
230
        if self.starting_value is None:
 
231
            starting_values = zeros(len(index))
 
232
        else:
 
233
            starting_values = self.starting_value.get_data()
 
234
 
 
235
        if self.bar_width_type == "data":
 
236
            half_width = self.bar_width / 2.0
 
237
            points = column_stack((index-half_width, index+half_width,
 
238
                                   starting_values, value))
 
239
        else:
 
240
            points = column_stack((index, starting_values, value))
 
241
        self._cached_data_pts = compress(point_mask, points, axis=0)
 
242
 
 
243
        self._cache_valid = True
 
244
        return
 
245
 
 
246
    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
 
247
        """ Draws the 'plot' layer.
 
248
        """
 
249
        if not self._cache_valid:
 
250
            self._gather_points()
 
251
 
 
252
        data = self._cached_data_pts
 
253
        if data.size == 0:
 
254
            # Nothing to draw.
 
255
            return
 
256
 
 
257
        gc.save_state()
 
258
        gc.clip_to_rect(self.x, self.y, self.width, self.height)
 
259
        gc.set_antialias(self.antialias)
 
260
        gc.set_stroke_color(self.line_color_)
 
261
        gc.set_fill_color(self.fill_color_)
 
262
        gc.set_line_width(self.line_width)
 
263
 
 
264
        if self.bar_width_type == "data":
 
265
            # map the bar start and stop locations into screen space
 
266
            lower_left_pts = self.map_screen(data[:,(0,2)])
 
267
            upper_right_pts = self.map_screen(data[:,(1,3)])
 
268
        else:
 
269
            half_width = self.bar_width / 2.0
 
270
            # map the bar centers into screen space and then compute the bar
 
271
            # start and end positions
 
272
            lower_left_pts = self.map_screen(data[:,(0,1)])
 
273
            upper_right_pts = self.map_screen(data[:,(0,2)])
 
274
            lower_left_pts[:,0] -= half_width
 
275
            upper_right_pts[:,0] += half_width
 
276
 
 
277
        bounds = upper_right_pts - lower_left_pts
 
278
        gc.rects(column_stack((lower_left_pts, bounds)))
 
279
        gc.draw_path()
 
280
        gc.restore_state()
 
281
 
 
282
 
 
283
    def _draw_default_axes(self, gc):
 
284
        if not self.origin_axis_visible:
 
285
            return
 
286
        gc.save_state()
 
287
        gc.set_stroke_color(self.origin_axis_color_)
 
288
        gc.set_line_width(self.origin_axis_width)
 
289
        gc.set_line_dash(None)
 
290
 
 
291
        for range in (self.index_mapper.range, self.value_mapper.range):
 
292
            if (range.low < 0) and (range.high > 0):
 
293
                if range == self.index_mapper.range:
 
294
                    dual = self.value_mapper.range
 
295
                    data_pts = array([[0.0,dual.low], [0.0, dual.high]])
 
296
                else:
 
297
                    dual = self.index_mapper.range
 
298
                    data_pts = array([[dual.low,0.0], [dual.high,0.0]])
 
299
                start,end = self.map_screen(data_pts)
 
300
                gc.move_to(int(start[0])+0.5, int(start[1])+0.5)
 
301
                gc.line_to(int(end[0])+0.5, int(end[1])+0.5)
 
302
                gc.stroke_path()
 
303
        gc.restore_state()
 
304
        return
 
305
 
 
306
    def _render_icon(self, gc, x, y, width, height):
 
307
        gc.save_state()
 
308
        gc.set_fill_color(self.fill_color_)
 
309
        gc.set_stroke_color(self.line_color_)
 
310
        gc.rect(x+width/4, y+height/4, width/2, height/2)
 
311
        gc.draw_path(FILL_STROKE)
 
312
        gc.restore_state()
 
313
 
 
314
    def _post_load(self):
 
315
        super(BarPlot, self)._post_load()
 
316
        return
 
317
 
 
318
 
 
319
    #------------------------------------------------------------------------
 
320
    # Properties
 
321
    #------------------------------------------------------------------------
 
322
 
 
323
    def _get_index_range(self):
 
324
        return self.index_mapper.range
 
325
 
 
326
    def _set_index_range(self, val):
 
327
        self.index_mapper.range = val
 
328
 
 
329
    def _get_value_range(self):
 
330
        return self.value_mapper.range
 
331
 
 
332
    def _set_value_range(self, val):
 
333
        self.value_mapper.range = val
 
334
 
 
335
    def _get_x_mapper(self):
 
336
        if self.orientation == "h":
 
337
            return self.index_mapper
 
338
        else:
 
339
            return self.value_mapper
 
340
 
 
341
    def _get_y_mapper(self):
 
342
        if self.orientation == "h":
 
343
            return self.value_mapper
 
344
        else:
 
345
            return self.index_mapper
 
346
 
 
347
    def _get_x_direction(self):
 
348
        if self.orientation == "h":
 
349
            return self.index_direction
 
350
        else:
 
351
            return self.value_direction
 
352
 
 
353
    def _get_y_direction(self):
 
354
        if self.orientation == "h":
 
355
            return self.value_direction
 
356
        else:
 
357
            return self.index_direction
 
358
 
 
359
    #------------------------------------------------------------------------
 
360
    # Event handlers - these are mostly copied from BaseXYPlot
 
361
    #------------------------------------------------------------------------
 
362
 
 
363
    def _update_mappers(self):
 
364
        """ Updates the index and value mappers. Called by trait change handlers
 
365
        for various traits.
 
366
        """
 
367
        x_mapper = self.index_mapper
 
368
        y_mapper = self.value_mapper
 
369
        x_dir = self.index_direction
 
370
        y_dir = self.value_direction
 
371
 
 
372
        if self.orientation == "v":
 
373
            x_mapper, y_mapper = y_mapper, x_mapper
 
374
            x_dir, y_dir = y_dir, x_dir
 
375
 
 
376
        x = self.x
 
377
        x2 = self.x2
 
378
        y = self.y
 
379
        y2 = self.y2
 
380
 
 
381
        if x_mapper is not None:
 
382
            if x_dir =="normal":
 
383
                x_mapper.low_pos = x
 
384
                x_mapper.high_pos = x2
 
385
            else:
 
386
                x_mapper.low_pos = x2
 
387
                x_mapper.high_pos = x
 
388
 
 
389
        if y_mapper is not None:
 
390
            if y_dir == "normal":
 
391
                y_mapper.low_pos = y
 
392
                y_mapper.high_pos = y2
 
393
            else:
 
394
                y_mapper.low_pos = y2
 
395
                y_mapper.high_pos = y
 
396
 
 
397
        self.invalidate_draw()
 
398
        self._cache_valid = False
 
399
 
 
400
    def _bounds_changed(self, old, new):
 
401
        super(BarPlot, self)._bounds_changed(old, new)
 
402
        self._update_mappers()
 
403
 
 
404
    def _bounds_items_changed(self, event):
 
405
        super(BarPlot, self)._bounds_items_changed(event)
 
406
        self._update_mappers()
 
407
 
 
408
    def _orientation_changed(self):
 
409
        self._update_mappers()
 
410
 
 
411
    def _index_changed(self, old, new):
 
412
        if old is not None:
 
413
            old.on_trait_change(self._either_data_changed, "data_changed", remove=True)
 
414
        if new is not None:
 
415
            new.on_trait_change(self._either_data_changed, "data_changed")
 
416
        self._either_data_changed()
 
417
 
 
418
    def _index_direction_changed(self):
 
419
        m = self.index_mapper
 
420
        m.low_pos, m.high_pos = m.high_pos, m.low_pos
 
421
        self.invalidate_draw()
 
422
 
 
423
    def _value_direction_changed(self):
 
424
        m = self.value_mapper
 
425
        m.low_pos, m.high_pos = m.high_pos, m.low_pos
 
426
        self.invalidate_draw()
 
427
 
 
428
    def _either_data_changed(self):
 
429
        self.invalidate_draw()
 
430
        self._cache_valid = False
 
431
        self.request_redraw()
 
432
 
 
433
    def _value_changed(self, old, new):
 
434
        if old is not None:
 
435
            old.on_trait_change(self._either_data_changed, "data_changed", remove=True)
 
436
        if new is not None:
 
437
            new.on_trait_change(self._either_data_changed, "data_changed")
 
438
        self._either_data_changed()
 
439
 
 
440
    def _index_mapper_changed(self, old, new):
 
441
        return self._either_mapper_changed(old, new)
 
442
 
 
443
    def _value_mapper_changed(self, old, new):
 
444
        return self._either_mapper_changed(old, new)
 
445
 
 
446
    def _either_mapper_changed(self, old, new):
 
447
        if old is not None:
 
448
            old.on_trait_change(self._mapper_updated_handler, "updated", remove=True)
 
449
        if new is not None:
 
450
            new.on_trait_change(self._mapper_updated_handler, "updated")
 
451
        self.invalidate_draw()
 
452
 
 
453
    def _mapper_updated_handler(self):
 
454
        self._cache_valid = False
 
455
        self.invalidate_draw()
 
456
        self.request_redraw()
 
457
 
 
458
    def _bar_width_changed(self):
 
459
        self._cache_valid = False
 
460
        self.invalidate_draw()
 
461
        self.request_redraw()
 
462
 
 
463
    def _bar_width_type_changed(self):
 
464
        self._cache_valid = False
 
465
        self.invalidate_draw()
 
466
        self.request_redraw()
 
467
 
 
468
 
 
469
 
 
470
### EOF ####################################################################