1
""" Defines the BarPlot class.
5
from numpy import array, compress, column_stack, invert, isnan, transpose, zeros
6
from enthought.traits.api import Any, Bool, Enum, Float, Instance, Property
7
from enthought.enable.api import black_color_trait
8
from enthought.kiva import FILL_STROKE
10
# Local relative imports
11
from enthought.chaco.abstract_plot_renderer import AbstractPlotRenderer
12
from abstract_mapper import AbstractMapper
13
from array_data_source import ArrayDataSource
14
from base import reverse_map_1d
17
logger = logging.getLogger(__name__)
20
class BarPlot(AbstractPlotRenderer):
22
A renderer for bar charts.
24
# The data source to use for the index coordinate.
25
index = Instance(ArrayDataSource)
27
# The data source to use as value points.
28
value = Instance(ArrayDataSource)
30
# The data source to use as "starting" values for the bars.
31
starting_value = Instance(ArrayDataSource)
33
# Labels for the indices.
34
index_mapper = Instance(AbstractMapper)
35
# Labels for the values.
36
value_mapper = Instance(AbstractMapper)
38
# The orientation of the index axis.
39
orientation = Enum("h", "v")
41
# The direction of the index axis with respect to the graphics context's
43
index_direction = Enum("normal", "flipped")
45
# The direction of the value axis with respect to the graphics context's
47
value_direction = Enum("normal", "flipped")
49
# Type of width used for bars:
52
# The width is in the units along the x-dimension of the data space.
54
# The width uses a fixed width of pixels.
55
bar_width_type = Enum("data", "screen")
57
# Width of the bars, in data or screen space (determined by
58
# **bar_width_type**).
61
# Round on rectangle dimensions? This is not strictly an "antialias", but
62
# it has the same effect through exact pixel drawing.
63
antialias = Bool(True)
65
# Width of the border of the bars.
66
line_width = Float(1.0)
67
# Color of the border of the bars.
68
line_color = black_color_trait
69
# Color to fill the bars.
70
fill_color = black_color_trait
72
#use_draw_order = False
74
# Convenience properties that correspond to either index_mapper or
75
# value_mapper, depending on the orientation of the plot.
77
# Corresponds to either **index_mapper** or **value_mapper**, depending on
78
# the orientation of the plot.
80
# Corresponds to either **value_mapper** or **index_mapper**, depending on
81
# the orientation of the plot.
84
# Corresponds to either **index_direction** or **value_direction**,
85
# depending on the orientation of the plot.
86
x_direction = Property
87
# Corresponds to either **value_direction** or **index_direction**,
88
# depending on the orientation of the plot
89
y_direction = Property
91
# Convenience property for accessing the index data range.
92
index_range = Property
93
# Convenience property for accessing the value data range.
94
value_range = Property
97
#------------------------------------------------------------------------
99
#------------------------------------------------------------------------
101
# Indicates whether or not the data cache is valid
102
_cache_valid = Bool(False)
104
# Cached data values from the datasources. If **bar_width_type** is "data",
105
# then this is an Nx4 array of (bar_left, bar_right, start, end) for a
106
# bar plot in normal orientation. If **bar_width_type** is "screen", then
107
# this is an Nx3 array of (bar_center, start, end).
108
_cached_data_pts = Any
111
#------------------------------------------------------------------------
112
# AbstractPlotRenderer interface
113
#------------------------------------------------------------------------
115
def map_screen(self, data_array):
116
""" Maps an array of data points into screen space and returns it as
119
Implements the AbstractPlotRenderer interface.
121
# data_array is Nx2 array
122
if len(data_array) == 0:
124
x_ary, y_ary = transpose(data_array)
125
sx = self.index_mapper.map_screen(x_ary)
126
sy = self.value_mapper.map_screen(y_ary)
128
# reverse the directions as indicated
129
if self.index_direction == "flipped":
131
x_offset = self.bounds[0]
136
if self.value_direction == "flipped":
138
y_offset = self.bounds[1]
143
# now return map based on orientation
144
if self.orientation == "h":
145
return transpose(array((x_sign*sx+x_offset,y_sign*sy+y_offset)))
147
return transpose(array((y_sign*sy+y_offset,x_sign*sx+x_offset)))
149
def map_data(self, screen_pt):
150
""" Maps a screen space point into the "index" space of the plot.
152
Implements the AbstractPlotRenderer interface.
154
if self.orientation == "h":
155
screen_coord = screen_pt[0]
157
screen_coord = screen_pt[1]
158
return self.index_mapper.map_data(screen_coord)
160
def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True, \
162
""" Maps a screen space point to an index into the plot's index array(s).
164
Implements the AbstractPlotRenderer interface.
166
data_pt = self.map_data(screen_pt)
167
if ((data_pt < self.index_mapper.range.low) or \
168
(data_pt > self.index_mapper.range.high)) and outside_returns_none:
170
half = threshold / 2.0
171
index_data = self.index.get_data()
172
value_data = self.value.get_data()
174
if len(value_data) == 0 or len(index_data) == 0:
178
ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
185
result = self.map_screen(array([[x,y]]))
190
if index_only and ((screen_pt[0]-sx) < threshold):
192
elif ((screen_pt[0]-sx)**2 + (screen_pt[1]-sy)**2 < threshold*threshold):
197
#------------------------------------------------------------------------
198
# PlotComponent interface
199
#------------------------------------------------------------------------
201
def _gather_points(self):
202
""" Collects data points that are within the range of the plot, and
203
caches them in **_cached_data_pts**.
205
index, index_mask = self.index.get_data_mask()
206
value, value_mask = self.value.get_data_mask()
208
if not self.index or not self.value:
211
if len(index) == 0 or len(value) == 0 or len(index) != len(value):
212
logger.warn("Chaco: using empty dataset; index_len=%d, value_len=%d." \
213
% (len(index), len(value)))
214
self._cached_data_pts = array([])
215
self._cache_valid = True
218
# TODO: Until we code up a better handling of value-based culling that
219
# takes into account starting_value and dataspace bar widths, just use
220
# the index culling for now.
221
# value_range_mask = self.value_mapper.range.mask_data(value)
222
# nan_mask = invert(isnan(index_mask)) & invert(isnan(value_mask))
223
# point_mask = index_mask & value_mask & nan_mask & \
224
# index_range_mask & value_range_mask
226
index_range_mask = self.index_mapper.range.mask_data(index)
227
nan_mask = invert(isnan(index_mask))
228
point_mask = index_mask & nan_mask & index_range_mask
230
if self.starting_value is None:
231
starting_values = zeros(len(index))
233
starting_values = self.starting_value.get_data()
235
if self.bar_width_type == "data":
236
half_width = self.bar_width / 2.0
237
points = column_stack((index-half_width, index+half_width,
238
starting_values, value))
240
points = column_stack((index, starting_values, value))
241
self._cached_data_pts = compress(point_mask, points, axis=0)
243
self._cache_valid = True
246
def _draw_plot(self, gc, view_bounds=None, mode="normal"):
247
""" Draws the 'plot' layer.
249
if not self._cache_valid:
250
self._gather_points()
252
data = self._cached_data_pts
258
gc.clip_to_rect(self.x, self.y, self.width, self.height)
259
gc.set_antialias(self.antialias)
260
gc.set_stroke_color(self.line_color_)
261
gc.set_fill_color(self.fill_color_)
262
gc.set_line_width(self.line_width)
264
if self.bar_width_type == "data":
265
# map the bar start and stop locations into screen space
266
lower_left_pts = self.map_screen(data[:,(0,2)])
267
upper_right_pts = self.map_screen(data[:,(1,3)])
269
half_width = self.bar_width / 2.0
270
# map the bar centers into screen space and then compute the bar
271
# start and end positions
272
lower_left_pts = self.map_screen(data[:,(0,1)])
273
upper_right_pts = self.map_screen(data[:,(0,2)])
274
lower_left_pts[:,0] -= half_width
275
upper_right_pts[:,0] += half_width
277
bounds = upper_right_pts - lower_left_pts
278
gc.rects(column_stack((lower_left_pts, bounds)))
283
def _draw_default_axes(self, gc):
284
if not self.origin_axis_visible:
287
gc.set_stroke_color(self.origin_axis_color_)
288
gc.set_line_width(self.origin_axis_width)
289
gc.set_line_dash(None)
291
for range in (self.index_mapper.range, self.value_mapper.range):
292
if (range.low < 0) and (range.high > 0):
293
if range == self.index_mapper.range:
294
dual = self.value_mapper.range
295
data_pts = array([[0.0,dual.low], [0.0, dual.high]])
297
dual = self.index_mapper.range
298
data_pts = array([[dual.low,0.0], [dual.high,0.0]])
299
start,end = self.map_screen(data_pts)
300
gc.move_to(int(start[0])+0.5, int(start[1])+0.5)
301
gc.line_to(int(end[0])+0.5, int(end[1])+0.5)
306
def _render_icon(self, gc, x, y, width, height):
308
gc.set_fill_color(self.fill_color_)
309
gc.set_stroke_color(self.line_color_)
310
gc.rect(x+width/4, y+height/4, width/2, height/2)
311
gc.draw_path(FILL_STROKE)
314
def _post_load(self):
315
super(BarPlot, self)._post_load()
319
#------------------------------------------------------------------------
321
#------------------------------------------------------------------------
323
def _get_index_range(self):
324
return self.index_mapper.range
326
def _set_index_range(self, val):
327
self.index_mapper.range = val
329
def _get_value_range(self):
330
return self.value_mapper.range
332
def _set_value_range(self, val):
333
self.value_mapper.range = val
335
def _get_x_mapper(self):
336
if self.orientation == "h":
337
return self.index_mapper
339
return self.value_mapper
341
def _get_y_mapper(self):
342
if self.orientation == "h":
343
return self.value_mapper
345
return self.index_mapper
347
def _get_x_direction(self):
348
if self.orientation == "h":
349
return self.index_direction
351
return self.value_direction
353
def _get_y_direction(self):
354
if self.orientation == "h":
355
return self.value_direction
357
return self.index_direction
359
#------------------------------------------------------------------------
360
# Event handlers - these are mostly copied from BaseXYPlot
361
#------------------------------------------------------------------------
363
def _update_mappers(self):
364
""" Updates the index and value mappers. Called by trait change handlers
367
x_mapper = self.index_mapper
368
y_mapper = self.value_mapper
369
x_dir = self.index_direction
370
y_dir = self.value_direction
372
if self.orientation == "v":
373
x_mapper, y_mapper = y_mapper, x_mapper
374
x_dir, y_dir = y_dir, x_dir
381
if x_mapper is not None:
384
x_mapper.high_pos = x2
386
x_mapper.low_pos = x2
387
x_mapper.high_pos = x
389
if y_mapper is not None:
390
if y_dir == "normal":
392
y_mapper.high_pos = y2
394
y_mapper.low_pos = y2
395
y_mapper.high_pos = y
397
self.invalidate_draw()
398
self._cache_valid = False
400
def _bounds_changed(self, old, new):
401
super(BarPlot, self)._bounds_changed(old, new)
402
self._update_mappers()
404
def _bounds_items_changed(self, event):
405
super(BarPlot, self)._bounds_items_changed(event)
406
self._update_mappers()
408
def _orientation_changed(self):
409
self._update_mappers()
411
def _index_changed(self, old, new):
413
old.on_trait_change(self._either_data_changed, "data_changed", remove=True)
415
new.on_trait_change(self._either_data_changed, "data_changed")
416
self._either_data_changed()
418
def _index_direction_changed(self):
419
m = self.index_mapper
420
m.low_pos, m.high_pos = m.high_pos, m.low_pos
421
self.invalidate_draw()
423
def _value_direction_changed(self):
424
m = self.value_mapper
425
m.low_pos, m.high_pos = m.high_pos, m.low_pos
426
self.invalidate_draw()
428
def _either_data_changed(self):
429
self.invalidate_draw()
430
self._cache_valid = False
431
self.request_redraw()
433
def _value_changed(self, old, new):
435
old.on_trait_change(self._either_data_changed, "data_changed", remove=True)
437
new.on_trait_change(self._either_data_changed, "data_changed")
438
self._either_data_changed()
440
def _index_mapper_changed(self, old, new):
441
return self._either_mapper_changed(old, new)
443
def _value_mapper_changed(self, old, new):
444
return self._either_mapper_changed(old, new)
446
def _either_mapper_changed(self, old, new):
448
old.on_trait_change(self._mapper_updated_handler, "updated", remove=True)
450
new.on_trait_change(self._mapper_updated_handler, "updated")
451
self.invalidate_draw()
453
def _mapper_updated_handler(self):
454
self._cache_valid = False
455
self.invalidate_draw()
456
self.request_redraw()
458
def _bar_width_changed(self):
459
self._cache_valid = False
460
self.invalidate_draw()
461
self.request_redraw()
463
def _bar_width_type_changed(self):
464
self._cache_valid = False
465
self.invalidate_draw()
466
self.request_redraw()
470
### EOF ####################################################################