~ubuntu-branches/ubuntu/precise/python-chaco/precise

« back to all changes in this revision

Viewing changes to enthought/chaco/plot_factory.py

  • Committer: Bazaar Package Importer
  • Author(s): Varun Hiremath
  • Date: 2008-12-29 02:34:05 UTC
  • Revision ID: james.westby@ubuntu.com-20081229023405-x7i4kp9mdxzmdnvu
Tags: upstream-3.0.1
Import upstream version 3.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
"""
 
2
Contains convenience functions to create ready-made PlotRenderer
 
3
and PlotFrame instances of various types.
 
4
"""
 
5
 
 
6
from numpy import array, ndarray, transpose, cos, sin
 
7
 
 
8
# Local relative imports
 
9
from abstract_data_source import AbstractDataSource
 
10
from array_data_source import ArrayDataSource
 
11
from axis import PlotAxis
 
12
from barplot import BarPlot
 
13
from data_range_1d import DataRange1D
 
14
from grid import PlotGrid
 
15
from linear_mapper import LinearMapper
 
16
from scatterplot import ScatterPlot
 
17
from polar_mapper import PolarMapper
 
18
from lineplot import LinePlot
 
19
from polar_line_renderer import PolarLineRenderer
 
20
 
 
21
def _create_data_sources(data, index_sort="none"):
 
22
    """
 
23
    Returns datasources for index and value based on the inputs.  Assumes that
 
24
    the index data is unsorted unless otherwise specified.
 
25
    """
 
26
    if (type(data) == ndarray) or (len(data) == 2):
 
27
        index, value = data
 
28
        if type(index) in (list, tuple, ndarray):
 
29
            index = ArrayDataSource(array(index), sort_order=index_sort)
 
30
        elif not isinstance(index, AbstractDataSource):
 
31
            raise RuntimeError, "Need an array or list of values or a DataSource, got %s instead." % type(index)
 
32
        
 
33
        if type(value) in (list, tuple, ndarray):
 
34
            value = ArrayDataSource(array(value))
 
35
        elif not isinstance(value, AbstractDataSource):
 
36
            raise RuntimeError, "Need an array or list of values or a DataSource, got %s instead." % type(index)
 
37
        
 
38
        return index, value
 
39
    else:
 
40
        raise RuntimeError, "Unable to create datasources."
 
41
    
 
42
 
 
43
def create_scatter_plot(data=[], index_bounds=None, value_bounds=None,
                        orientation="h", color="green", marker="square",
                        marker_size=4,
                        bgcolor="transparent", outline_color="black",
                        border_visible=True,
                        add_grid=False, add_axis=False,
                        index_sort="none"):
    """
    Creates a ScatterPlot from a single Nx2 data array or a tuple of
    two length-N 1-D arrays.  The data must be sorted on the index if any
    reverse-mapping tools are to be used.

    Pre-existing "index" and "value" datasources can be passed in.

    Parameters
    ----------
    data
        Anything accepted by ``_create_data_sources``.  The default empty
        list is rejected there with a RuntimeError, so callers must supply
        data.
    index_bounds, value_bounds : (low, high) or None
        When given, used as ``low``/``high`` of the corresponding
        DataRange1D; when None, a DataRange1D is created without bounds.
    orientation : str
        Passed to the ScatterPlot and to the default grid/axis helpers.
    color, marker, marker_size, bgcolor, outline_color, border_visible
        Passed straight through to the ScatterPlot.
    add_grid, add_axis : bool
        Attach default grids / axes via add_default_grids /
        add_default_axes.
    index_sort : str
        Sort order declared for the index datasource when one is built
        from raw data (forwarded to ``_create_data_sources``).

    Returns the new ScatterPlot.
    """
    # Forward index_sort; it used to be accepted here and silently dropped.
    index, value = _create_data_sources(data, index_sort)

    if index_bounds is not None:
        index_range = DataRange1D(low=index_bounds[0], high=index_bounds[1])
    else:
        index_range = DataRange1D()
    index_range.add(index)
    index_mapper = LinearMapper(range=index_range)

    if value_bounds is not None:
        value_range = DataRange1D(low=value_bounds[0], high=value_bounds[1])
    else:
        value_range = DataRange1D()
    value_range.add(value)
    value_mapper = LinearMapper(range=value_range)

    plot = ScatterPlot(index=index, value=value,
                       index_mapper=index_mapper,
                       value_mapper=value_mapper,
                       orientation=orientation,
                       marker=marker,
                       marker_size=marker_size,
                       color=color,
                       bgcolor=bgcolor,
                       outline_color=outline_color,
                       border_visible=border_visible,)

    if add_grid:
        add_default_grids(plot, orientation)
    if add_axis:
        add_default_axes(plot, orientation)
    return plot
 
90
 
 
91
 
 
92
def create_line_plot(data=[], index_bounds=None, value_bounds=None,
                     orientation="h", color="red", width=1.0,
                     dash="solid", value_mapper_class=LinearMapper,
                     bgcolor="transparent", border_visible=False,
                     add_grid=False, add_axis=False,
                     index_sort="none"):
    """
    Builds a LinePlot from *data* (any form accepted by
    ``_create_data_sources``; *index_sort* is forwarded to it).

    The index axis always uses a LinearMapper; the value axis uses an
    instance of *value_mapper_class*.  *index_bounds* / *value_bounds*, when
    not None, supply ``low``/``high`` for the corresponding DataRange1D.
    *width* and *dash* become the renderer's ``line_width`` and
    ``line_style``.  If *add_grid* / *add_axis* are true, default grids /
    axes are attached with add_default_grids / add_default_axes using
    *orientation*.

    Returns the new LinePlot.
    """
    index_ds, value_ds = _create_data_sources(data, index_sort)

    if index_bounds is None:
        index_range = DataRange1D()
    else:
        index_range = DataRange1D(low=index_bounds[0], high=index_bounds[1])
    index_range.add(index_ds)
    index_mapper = LinearMapper(range=index_range)

    if value_bounds is None:
        value_range = DataRange1D()
    else:
        value_range = DataRange1D(low=value_bounds[0], high=value_bounds[1])
    value_range.add(value_ds)
    value_mapper = value_mapper_class(range=value_range)

    line = LinePlot(index=index_ds, value=value_ds,
                    index_mapper=index_mapper,
                    value_mapper=value_mapper,
                    orientation=orientation,
                    color=color,
                    bgcolor=bgcolor,
                    line_width=width,
                    line_style=dash,
                    border_visible=border_visible)

    if add_grid:
        add_default_grids(line, orientation)
    if add_axis:
        add_default_axes(line, orientation)
    return line
 
130
 
 
131
 
 
132
def create_bar_plot(data=[], index_bounds=None, value_bounds=None,
                    orientation="h", color="red", bar_width=10.0,
                    value_mapper_class=LinearMapper,
                    line_color="black",
                    fill_color="red", line_width=1,
                    bgcolor="transparent", border_visible=False,
                    antialias=True,
                    add_grid=False, add_axis=False):
    """
    Creates a BarPlot from *data* (any form accepted by
    ``_create_data_sources``).

    Parameters
    ----------
    index_bounds, value_bounds : (low, high) or None
        When given, used as ``low``/``high`` of the corresponding
        DataRange1D; when None, a DataRange1D is created without bounds.
    orientation : str
        Passed to the BarPlot and to the default grid/axis helpers.
    color
        Not used by this function; it is kept for signature compatibility.
        Use *fill_color* and *line_color* to color the bars.
    bar_width, line_color, fill_color, line_width, antialias
        Passed straight through to the BarPlot.
    value_mapper_class
        Class instantiated as the value mapper; the index mapper is always
        a LinearMapper.
    bgcolor, border_visible
        Passed through to the BarPlot, as the other factory functions do.
    add_grid, add_axis : bool
        Attach default grids / axes via add_default_grids /
        add_default_axes.

    Returns the new BarPlot.
    """
    index, value = _create_data_sources(data)

    if index_bounds is not None:
        index_range = DataRange1D(low=index_bounds[0], high=index_bounds[1])
    else:
        index_range = DataRange1D()
    index_range.add(index)
    index_mapper = LinearMapper(range=index_range)

    if value_bounds is not None:
        value_range = DataRange1D(low=value_bounds[0], high=value_bounds[1])
    else:
        value_range = DataRange1D()
    value_range.add(value)
    value_mapper = value_mapper_class(range=value_range)

    # Create the plot.  bgcolor and border_visible used to be accepted but
    # silently dropped; forward them like create_line_plot does.
    plot = BarPlot(index=index,
                   value=value,
                   value_mapper=value_mapper,
                   index_mapper=index_mapper,
                   orientation=orientation,
                   line_color=line_color,
                   fill_color=fill_color,
                   line_width=line_width,
                   bar_width=bar_width,
                   antialias=antialias,
                   bgcolor=bgcolor,
                   border_visible=border_visible)

    if add_grid:
        add_default_grids(plot, orientation)
    if add_axis:
        add_default_axes(plot, orientation)
    return plot
 
174
 
 
175
 
 
176
def create_polar_plot(data, orientation='h', color='black', width=1.0,
                      dash="solid", grid="dot", value_mapper_class=PolarMapper):
    """
    Builds a PolarLineRenderer from polar (radius, angle) data.

    *data* is an Nx2 array-like whose columns are radius and angle.  A
    non-ndarray of length 2 is instead treated as the pair
    ``(radii, angles)``.  Points are converted to Cartesian coordinates
    (``r*cos(t)``, ``r*sin(t)``), which become the index and value
    datasources.

    The index mapper is always a PolarMapper; the value mapper is an
    instance of *value_mapper_class*.  *width*, *dash* and *grid* become the
    renderer's ``line_width``, ``line_style`` and ``grid_style``.

    Returns the new PolarLineRenderer.
    """
    # Normalise a (radii, angles) pair into the Nx2 layout, so the single
    # transpose below splits every input form into a radius row and an
    # angle row.
    if type(data) is not ndarray and len(data) == 2:
        data = transpose(array(data))

    radii, angles = transpose(data)
    x_coords = radii * cos(angles)
    y_coords = radii * sin(angles)

    # NOTE(review): r*cos(t) is not sorted in general, yet the index source
    # is declared 'ascending'; confirm whether PolarLineRenderer relies on it.
    x_source = ArrayDataSource(x_coords, sort_order='ascending')
    # The value data is typically unsorted.
    y_source = ArrayDataSource(y_coords)

    x_range = DataRange1D()
    x_range.add(x_source)
    x_mapper = PolarMapper(range=x_range)

    y_range = DataRange1D()
    y_range.add(y_source)
    y_mapper = value_mapper_class(range=y_range)

    return PolarLineRenderer(index=x_source, value=y_source,
                             index_mapper=x_mapper,
                             value_mapper=y_mapper,
                             orientation=orientation,
                             color=color,
                             line_width=width,
                             line_style=dash,
                             grid_style=grid)
 
207
 
 
208
 
 
209
def add_default_axes(plot, orientation="normal", vtitle="", htitle=""):
    """
    Attaches a left and a bottom PlotAxis to *plot* as underlays.

    For *orientation* "normal" or "h", the left axis follows the plot's
    value mapper and the bottom axis follows its index mapper.  Any other
    orientation swaps the two.  *vtitle* and *htitle* title the left and
    bottom axes respectively.

    Returns the ``(left, bottom)`` axes.
    """
    if orientation not in ("normal", "h"):
        vert_mapper, horiz_mapper = plot.index_mapper, plot.value_mapper
    else:
        vert_mapper, horiz_mapper = plot.value_mapper, plot.index_mapper

    left_axis = PlotAxis(orientation='left', title=vtitle,
                         mapper=vert_mapper, component=plot)
    bottom_axis = PlotAxis(orientation='bottom', title=htitle,
                           mapper=horiz_mapper, component=plot)

    for axis in (left_axis, bottom_axis):
        plot.underlays.append(axis)
    return left_axis, bottom_axis
 
235
 
 
236
 
 
237
def add_default_grids(plot, orientation="normal"):
    """
    Attaches light-gray, dotted vertical and horizontal PlotGrids to *plot*
    as underlays.

    For *orientation* "normal" or "h", the vertical gridlines use the
    plot's index mapper and the horizontal ones use its value mapper.  Any
    other orientation swaps the two.

    Returns ``(hgrid, vgrid)``.  Note that the horizontal grid comes first
    here, which is the reverse of the order in which the grids are appended.
    """
    is_normal = orientation in ("normal", "h")
    vgrid_mapper = plot.index_mapper if is_normal else plot.value_mapper
    hgrid_mapper = plot.value_mapper if is_normal else plot.index_mapper

    shared_style = dict(component=plot, line_color="lightgray", line_style="dot")
    vgrid = PlotGrid(mapper=vgrid_mapper, orientation='vertical', **shared_style)
    hgrid = PlotGrid(mapper=hgrid_mapper, orientation='horizontal', **shared_style)

    plot.underlays.append(vgrid)
    plot.underlays.append(hgrid)
    return hgrid, vgrid
 
261
 
 
262
 
 
263
 
 
264
# EOF