1
from __future__ import with_statement
3
from numpy import array, transpose, ndarray, empty
4
from traits.api import Instance, DelegatesTo, Bool, Int
6
from enable.api import transparent_color_trait
7
from chaco.color_mapper import ColorMapper
8
from chaco.base_xy_plot import BaseXYPlot
9
from chaco.linear_mapper import LinearMapper
11
class BandedMapper(LinearMapper):
14
def map_screen(self, data_array):
17
if self._null_data_range:
18
if isinstance(data_array, (tuple, list, ndarray)):
19
x = empty(data_array.shape)
23
return array([self.low_pos])
25
# Scale the data by the number of bands
26
return (data_array*self.bands - self.range.low) * self._scale + self.low_pos
28
class HorizonPlot(BaseXYPlot):
30
bands = DelegatesTo('value_mapper')
31
color_mapper = Instance(ColorMapper)
35
# FIXME There should be a way to automatically detect whether the data has
37
negative_bands = Bool(True)
39
# Override parent traits
43
def _color_mapper_changed(self, new):
44
# change the number of steps to match the number of bands
45
if not self.negative_bands:
46
new.steps = self.bands+1
48
new.steps = self.bands*2+1
50
def _gather_points(self):
51
""" Collects the data points that are within the bounds of the plot and
57
index = self.index.get_data()
58
value = self.value.get_data()
60
if not self.index or not self.value:
63
if len(index) == 0 or len(value) == 0 or len(index) != len(value):
64
self._cached_data_pts = []
65
self._cache_valid = True
68
points = transpose(array((index,value)))
69
self._cached_data_pts = points
71
self._cache_valid = True
73
def _render(self, gc, points):
77
ox, oy = self.map_screen([[0,0]])[0]
78
ylow, yhigh = self.value_mapper.screen_bounds
80
y_plus_height = yhigh - oy
83
bands = array(self.color_mapper._get_color_bands())
86
gc.clip_to_rect(self.x, self.y, self.width, self.height)
88
inc = -1 * array([0, y_plus_height])
89
if self.negative_bands: render_bands = bands[self.bands+1:]
90
else: render_bands = bands[1:]
91
for i, col in enumerate(render_bands):
92
self._render_fill(gc, col, points+i*inc, ox, oy)
95
if self.negative_bands:
97
points[:,1] = oy - points[:,1]
100
points[:,1] += y_plus_height
102
zeroy = int(yhigh) + 2
103
for i, col in enumerate(bands[self.bands-1::-1]):
104
self._render_fill(gc, col, points+i*inc, ox, zeroy)
106
gc.set_stroke_color((.75, .75, .75))
109
gc.move_to(self.x, self.y)
110
gc.line_to(self.x+self.width, self.y)
113
def _render_fill(self, gc, face_col, points, ox, oy):
114
gc.set_fill_color(tuple(face_col))
116
startx, starty = points[0]
117
gc.move_to(startx, oy)
118
gc.line_to(startx, starty)
122
endx, endy = points[-1]
124
gc.line_to(startx, oy)