~ubuntu-branches/ubuntu/precise/python-chaco/precise

« back to all changes in this revision

Viewing changes to enthought/chaco/axis.py

  • Committer: Bazaar Package Importer
  • Author(s): Varun Hiremath
  • Date: 2008-12-29 02:34:05 UTC
  • Revision ID: james.westby@ubuntu.com-20081229023405-x7i4kp9mdxzmdnvu
Tags: upstream-3.0.1
ImportĀ upstreamĀ versionĀ 3.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
""" Defines the PlotAxis class, and associated validator and UI.
 
2
"""
 
3
# Major library import
 
4
from numpy import array, around, absolute, cos, dot, float64, inf, pi, \
 
5
                  sqrt, sin, transpose
 
6
 
 
7
# Enthought Library imports
 
8
from enthought.enable.api import ColorTrait, LineStyle
 
9
from enthought.kiva.traits.kiva_font_trait import KivaFont
 
10
from enthought.traits.api import Any, Float, Int, Str, Trait, Unicode, \
 
11
     Bool, Event, List, Array, Instance, Enum, Callable
 
12
 
 
13
# Local relative imports
 
14
from ticks import AbstractTickGenerator, DefaultTickGenerator
 
15
from abstract_mapper import AbstractMapper
 
16
from abstract_overlay import AbstractOverlay
 
17
from label import Label
 
18
from log_mapper import LogMapper
 
19
 
 
20
 
 
21
def DEFAULT_TICK_FORMATTER(val):
 
22
    return ("%f"%val).rstrip("0").rstrip(".")
 
23
 
 
24
class PlotAxis(AbstractOverlay):
 
25
    """
 
26
    The PlotAxis is a visual component that can be rendered on its own as
 
27
    a standalone component or attached as an overlay to another component.
 
28
    (To attach it as an overlay, set its **component** attribute.)
 
29
 
 
30
    When it is attached as an overlay, it draws into the padding around
 
31
    the component.
 
32
    """
 
33
 
 
34
    # The mapper that drives this axis.
 
35
    mapper = Instance(AbstractMapper)
 
36
 
 
37
    # The text of the axis title.
 
38
    title = Trait('', Str, Unicode) #May want to add PlotLabel option
 
39
 
 
40
    # The font of the title.
 
41
    title_font = KivaFont('modern 12')
 
42
 
 
43
    # The spacing between the axis line and the title
 
44
    title_spacing = Trait('auto', 'auto', Float)
 
45
 
 
46
    # The color of the title.
 
47
    title_color = ColorTrait("black")
 
48
 
 
49
    # Not used right now.
 
50
    markers = Any     # TODO: Implement this
 
51
 
 
52
    # The thickness (in pixels) of each tick.
 
53
    tick_weight = Float(1.0)
 
54
 
 
55
    # The color of the ticks.
 
56
    tick_color = ColorTrait("black")
 
57
 
 
58
    # The font of the tick labels.
 
59
    tick_label_font = KivaFont('modern 10')
 
60
 
 
61
    # The color of the tick labels.
 
62
    tick_label_color = ColorTrait("black")
 
63
 
 
64
    # A callable that is passed the numerical value of each tick label and
 
65
    # that returns a string.
 
66
    tick_label_formatter = Callable(DEFAULT_TICK_FORMATTER)
 
67
 
 
68
    # The number of pixels by which the ticks extend into the plot area.
 
69
    tick_in = Int(5)
 
70
 
 
71
    # The number of pixels by which the ticks extend into the label area.
 
72
    tick_out = Int(5)
 
73
 
 
74
    # Are ticks visible at all?
 
75
    tick_visible = Bool(True)
 
76
 
 
77
    # The dataspace interval between ticks.
 
78
    tick_interval = Trait('auto', 'auto', Float)
 
79
 
 
80
    # A callable that implements the AbstractTickGenerator interface.
 
81
    tick_generator = Instance(AbstractTickGenerator)
 
82
 
 
83
    # The location of the axis relative to the plot.  This determines where
 
84
    # the axis title is located relative to the axis line.
 
85
    orientation = Enum("top", "bottom", "left", "right")
 
86
 
 
87
    # Is the axis line visible?
 
88
    axis_line_visible = Bool(True)
 
89
 
 
90
    # The color of the axis line.
 
91
    axis_line_color = ColorTrait("black")
 
92
 
 
93
    # The line thickness (in pixels) of the axis line.
 
94
    axis_line_weight = Float(1.0)
 
95
 
 
96
    # The dash style of the axis line.
 
97
    axis_line_style = LineStyle('solid')
 
98
 
 
99
    # A special version of the axis line that is more useful for geophysical
 
100
    # plots.
 
101
    small_haxis_style = Bool(False)     # TODO: MOVE THIS OUT OF HERE!
 
102
 
 
103
    # Does the axis ensure that its end labels fall within its bounding area?
 
104
    ensure_labels_bounded = Bool(False)
 
105
 
 
106
    # Does the axis prevent the ticks from being rendered outside its bounds?
 
107
    # This flag is off by default because the standard axis *does* render ticks
 
108
    # that encroach on the plot area.
 
109
    ensure_ticks_bounded = Bool(False)
 
110
 
 
111
    # Fired when the axis's range bounds change.
 
112
    updated = Event
 
113
 
 
114
    #------------------------------------------------------------------------
 
115
    # Override default values of inherited traits
 
116
    #------------------------------------------------------------------------
 
117
 
 
118
    # Background color (overrides AbstractOverlay). Axes usually let the color of
 
119
    # the container show through.
 
120
    bgcolor = ColorTrait("transparent")
 
121
 
 
122
    # Dimensions that the axis is resizable in (overrides PlotComponent). 
 
123
    # Typically, axes are resizable in both dimensions.
 
124
    resizable = "hv"
 
125
 
 
126
    #------------------------------------------------------------------------
 
127
    # Private Traits
 
128
    #------------------------------------------------------------------------
 
129
 
 
130
    # Cached position calculations
 
131
 
 
132
    _tick_list = List  # These are caches of their respective positions
 
133
    _tick_positions = Any #List
 
134
    _tick_label_list = Any
 
135
    _tick_label_positions = Any
 
136
    _tick_label_bounding_boxes = List
 
137
    _major_axis_size = Float
 
138
    _minor_axis_size = Float
 
139
    _major_axis = Array
 
140
    _title_orientation = Array
 
141
    _title_angle = Float
 
142
    _origin_point = Array
 
143
    _inside_vector = Array
 
144
    _axis_vector = Array
 
145
    _axis_pixel_vector = Array
 
146
    _end_axis_point = Array
 
147
 
 
148
 
 
149
    ticklabel_cache = List
 
150
    _cache_valid = Bool(False)
 
151
 
 
152
 
 
153
    #------------------------------------------------------------------------
 
154
    # Public methods
 
155
    #------------------------------------------------------------------------
 
156
 
 
157
    def __init__(self, component=None, **kwargs):
 
158
        # TODO: change this back to a factory in the instance trait some day
 
159
        self.tick_generator = DefaultTickGenerator()
 
160
        # Override init so that our component gets set last.  We want the
 
161
        # _component_changed() event handler to get run last.
 
162
        super(PlotAxis, self).__init__(**kwargs)
 
163
        if component is not None:
 
164
            self.component = component
 
165
 
 
166
    def invalidate(self):
 
167
        """ Invalidates the pre-computed layout and scaling data.
 
168
        """
 
169
        self._reset_cache()
 
170
        self.invalidate_draw()
 
171
        return
 
172
 
 
173
    def traits_view(self):
 
174
        """ Returns a View instance for use with Traits UI.  This method is
 
175
        called automatically be the Traits framework when .edit_traits() is
 
176
        invoked.
 
177
        """
 
178
        from axis_view import AxisView
 
179
        return AxisView
 
180
 
 
181
 
 
182
    #------------------------------------------------------------------------
 
183
    # PlotComponent and AbstractOverlay interface
 
184
    #------------------------------------------------------------------------
 
185
 
 
186
    def _do_layout(self, *args, **kw):
 
187
        """ Tells this component to do layout at a given size.
 
188
        
 
189
        Overrides PlotComponent.
 
190
        """
 
191
        if self.use_draw_order and self.component is not None:
 
192
            self._layout_as_overlay(*args, **kw)
 
193
        else:
 
194
            super(PlotAxis, self)._do_layout(*args, **kw)
 
195
        return
 
196
 
 
197
    def overlay(self, component, gc, view_bounds=None, mode='normal'):
 
198
        """ Draws this component overlaid on another component.
 
199
        
 
200
        Overrides AbstractOverlay.
 
201
        """
 
202
        if not self.visible:
 
203
            return
 
204
        self._draw_component(gc, view_bounds, mode, component)
 
205
        return
 
206
 
 
207
    def _draw_overlay(self, gc, view_bounds=None, mode='normal'):
 
208
        """ Draws the overlay layer of a component.
 
209
        
 
210
        Overrides PlotComponent.
 
211
        """
 
212
        self._draw_component(gc, view_bounds, mode)
 
213
        return
 
214
 
 
215
    def _draw_component(self, gc, view_bounds=None, mode='normal', component=None):
 
216
        """ Draws the component.
 
217
 
 
218
        This method is preserved for backwards compatibility. Overrides 
 
219
        PlotComponent.
 
220
        """
 
221
        if not self.visible:
 
222
            return
 
223
 
 
224
        if not self._cache_valid:
 
225
            if component is not None:
 
226
                self._calculate_geometry(component)
 
227
            else:
 
228
                self._old_calculate_geometry()
 
229
            self._compute_tick_positions(gc, component)
 
230
            self._compute_labels(gc)
 
231
 
 
232
        try:
 
233
            gc.save_state()
 
234
 
 
235
            # slight optimization: if we set the font correctly on the
 
236
            # base gc before handing it in to our title and tick labels,
 
237
            # their set_font() won't have to do any work.
 
238
            gc.set_font(self.tick_label_font)
 
239
 
 
240
            if self.axis_line_visible:
 
241
                self._draw_axis_line(gc, self._origin_point, self._end_axis_point)
 
242
            if self.title:
 
243
                self._draw_title(gc)
 
244
 
 
245
            self._draw_ticks(gc)
 
246
            self._draw_labels(gc)
 
247
        finally:
 
248
            gc.restore_state()
 
249
 
 
250
        self._cache_valid = True
 
251
        return
 
252
 
 
253
 
 
254
    #------------------------------------------------------------------------
 
255
    # Private draw routines
 
256
    #------------------------------------------------------------------------
 
257
 
 
258
    def _layout_as_overlay(self, size=None, force=False):
 
259
        """ Lays out the axis as an overlay on another component.
 
260
        """
 
261
        if self.component is not None:
 
262
            if self.orientation in ("left", "right"):
 
263
                self.y = self.component.y
 
264
                self.height = self.component.height
 
265
                if self.orientation == "left":
 
266
                    self.width = self.component.padding_left
 
267
                    self.x = self.component.outer_x
 
268
                elif self.orientation == "right":
 
269
                    self.width = self.component.padding_right
 
270
                    self.x = self.component.x2 + 1
 
271
            else:
 
272
                self.x = self.component.x
 
273
                self.width = self.component.width
 
274
                if self.orientation == "bottom":
 
275
                    self.height = self.component.padding_bottom
 
276
                    self.y = self.component.outer_y
 
277
                elif self.orientation == "top":
 
278
                    self.height = self.component.padding_top
 
279
                    self.y = self.component.y2 + 1
 
280
        return
 
281
 
 
282
    def _draw_axis_line(self, gc, startpoint, endpoint):
 
283
        """ Draws the line for the axis.
 
284
        """
 
285
        gc.save_state()
 
286
        try:
 
287
            gc.set_antialias(0)
 
288
            gc.set_line_width(self.axis_line_weight)
 
289
            gc.set_stroke_color(self.axis_line_color_)
 
290
            gc.set_line_dash(self.axis_line_style_)
 
291
            gc.move_to(*around(startpoint))
 
292
            gc.line_to(*around(endpoint))
 
293
            gc.stroke_path()
 
294
        finally:
 
295
            gc.restore_state()
 
296
        return
 
297
 
 
298
    def _draw_title_old(self, gc, label=None, v_offset=20):
 
299
        """ Draws the title for the axis.
 
300
        """
 
301
        #put in rotation code for right side
 
302
 
 
303
        if label is None:
 
304
            title_label = Label(text=self.title,
 
305
                                font=self.title_font,
 
306
                                color=self.title_color,
 
307
                                rotate_angle=self.title_angle)
 
308
        else:
 
309
            title_label = label
 
310
        tl_bounds = array(title_label.get_width_height(gc), float64)
 
311
 
 
312
        if self.title_angle == 0:
 
313
            text_center_to_corner = -tl_bounds/2.0
 
314
            v_offset = max([l._bounding_box[1] for l in self.ticklabel_cache]) * 1.3
 
315
        else:
 
316
            v_offset = max([l._bounding_box[0] for l in self.ticklabel_cache]) * 1.3
 
317
            corner_vec = transpose(-tl_bounds/2.0)
 
318
            rotmatrix = self._rotmatrix(-self.title_angle*pi/180.0)
 
319
            text_center_to_corner = transpose(dot(rotmatrix, corner_vec))[0]
 
320
 
 
321
        offset = (self._origin_point+self._end_axis_point)/2
 
322
        center_dist = self._center_dist(-self._inside_vector, tl_bounds[0], tl_bounds[1], rotation=self.title_angle)
 
323
        offset -= self._inside_vector * (center_dist + v_offset)
 
324
        offset += text_center_to_corner
 
325
 
 
326
        if self.title_angle == 90.0:
 
327
            # Horrible hack to adjust for the fact that the generic math above isn't
 
328
            # actually putting the label in the right place...
 
329
            offset[1] = offset[1] - tl_bounds[0]/2.0
 
330
 
 
331
        gc.translate_ctm(*offset)
 
332
        title_label.draw(gc)
 
333
        gc.translate_ctm(*(-offset))
 
334
 
 
335
        return
 
336
 
 
337
 
 
338
    def _draw_title(self, gc, label=None, v_offset=None):
 
339
        """ Draws the title for the axis.
 
340
        """
 
341
        #put in rotation code for right side
 
342
 
 
343
 
 
344
        if label is None:
 
345
            title_label = Label(text=self.title,
 
346
                                font=self.title_font,
 
347
                                color=self.title_color,
 
348
                                rotate_angle=self.title_angle)
 
349
        else:
 
350
            title_label = label
 
351
        tl_bounds = array(title_label.get_width_height(gc), float64)
 
352
 
 
353
        if self.title_spacing != 'auto':
 
354
            v_offset = self.title_spacing
 
355
        calculate_v_offset = (self.title_spacing) and (v_offset is None )
 
356
        
 
357
        if self.title_angle == 0:
 
358
            text_center_to_corner = -tl_bounds/2.0
 
359
            if calculate_v_offset:
 
360
                if not self.ticklabel_cache:
 
361
                    v_offset = 25
 
362
                else:
 
363
                    v_offset = max([l._bounding_box[1] for l in self.ticklabel_cache]) * 1.3
 
364
                           
 
365
            offset = (self._origin_point+self._end_axis_point)/2
 
366
            center_dist = self._center_dist(-self._inside_vector, tl_bounds[0], tl_bounds[1], rotation=self.title_angle)
 
367
            offset -= self._inside_vector * (center_dist + v_offset)
 
368
            offset += text_center_to_corner
 
369
        
 
370
        elif self.title_angle == 90:
 
371
            # Center the text vertically
 
372
            if calculate_v_offset:
 
373
                if not self.ticklabel_cache:
 
374
                    v_offset = 25
 
375
                else:
 
376
                    v_offset = (self._end_axis_point[1] - self._origin_point[1] - tl_bounds[0])/2.0
 
377
            h_offset = self.tick_out + tl_bounds[1] + 8
 
378
            if len(self.ticklabel_cache) > 0:
 
379
                h_offset += max([l._bounding_box[0] for l in self.ticklabel_cache]) 
 
380
            offset = array([self._origin_point[0] - h_offset, self._origin_point[1] + v_offset])
 
381
 
 
382
        elif self.title_angle == 270:
 
383
            # Center the text vertically
 
384
            if calculate_v_offset:
 
385
                if not self.ticklabel_cache:
 
386
                    v_offset = 25
 
387
                else:
 
388
                    v_offset = (self._end_axis_point[1] - self._origin_point[1] + tl_bounds[0])/2.0
 
389
            h_offset = self.tick_out + tl_bounds[1] + 8
 
390
            if len(self.ticklabel_cache) > 0:
 
391
                h_offset += max([l._bounding_box[0] for l in self.ticklabel_cache]) 
 
392
            offset = array([self._origin_point[0] + h_offset, self._origin_point[1] + v_offset])
 
393
 
 
394
        else:
 
395
            if calculate_v_offset:
 
396
                if not self.ticklabel_cache:
 
397
                    v_offset = 25
 
398
                else:
 
399
                    v_offset = max([l._bounding_box[0] for l in self.ticklabel_cache]) * 1.3
 
400
            corner_vec = transpose(-tl_bounds/2.0)
 
401
            rotmatrix = self._rotmatrix(-self.title_angle*pi/180.0)
 
402
            text_center_to_corner = transpose(dot(rotmatrix, corner_vec))[0]
 
403
            offset = (self._origin_point+self._end_axis_point)/2
 
404
            center_dist = self._center_dist(-self._inside_vector, tl_bounds[0], tl_bounds[1], rotation=self.title_angle)
 
405
            offset -= self._inside_vector * (center_dist + v_offset)
 
406
            offset += text_center_to_corner
 
407
 
 
408
        gc.translate_ctm(*offset)
 
409
        title_label.draw(gc)
 
410
        gc.translate_ctm(*(-offset))
 
411
 
 
412
        return
 
413
 
 
414
 
 
415
    def _draw_ticks(self, gc):
 
416
        """ Draws the tick marks for the axis.
 
417
        """
 
418
        if not self.tick_visible:
 
419
            return
 
420
        gc.set_stroke_color(self.tick_color_)
 
421
        gc.set_line_width(self.tick_weight)
 
422
        gc.set_antialias(False)
 
423
        gc.begin_path()
 
424
        tick_in_vector = self._inside_vector*self.tick_in
 
425
        tick_out_vector = self._inside_vector*self.tick_out
 
426
        for tick_pos in self._tick_positions:
 
427
            gc.move_to(*(tick_pos + tick_in_vector))
 
428
            gc.line_to(*(tick_pos - tick_out_vector))
 
429
        gc.stroke_path()
 
430
        return
 
431
 
 
432
    def _draw_labels(self, gc):
 
433
        """ Draws the tick labels for the axis.
 
434
        """
 
435
        for i in range(len(self._tick_label_positions)):
 
436
            #We want a more sophisticated scheme than just 2 decimals all the time
 
437
            ticklabel = self.ticklabel_cache[i]
 
438
            tl_bounds = self._tick_label_bounding_boxes[i]
 
439
 
 
440
            #base_position puts the tick label at a point where the vector
 
441
            #extending from the tick mark inside 8 units
 
442
            #just touches the rectangular bounding box of the tick label.
 
443
            #Note: This is not necessarily optimal for non
 
444
            #horizontal/vertical axes.  More work could be done on this.
 
445
 
 
446
            base_position = (self._center_dist(-self._inside_vector, *tl_bounds)+8) \
 
447
                                * -self._inside_vector \
 
448
                                - tl_bounds/2.0 + self._tick_label_positions[i]
 
449
 
 
450
            if self.ensure_labels_bounded:
 
451
                pushdir = 0
 
452
                if i == 0:
 
453
                    pushdir = 1
 
454
                elif i == len(self._tick_label_positions)-1:
 
455
                    pushdir = -1
 
456
                push_pixel_vector = self._axis_pixel_vector * pushdir
 
457
                tlpos = around((self._center_dist(push_pixel_vector,*tl_bounds)+4) \
 
458
                                          * push_pixel_vector + base_position)
 
459
 
 
460
            else:
 
461
                tlpos = around(base_position)
 
462
 
 
463
            gc.translate_ctm(*tlpos)
 
464
            ticklabel.draw(gc)
 
465
            gc.translate_ctm(*(-tlpos))
 
466
        return
 
467
 
 
468
 
 
469
    #------------------------------------------------------------------------
 
470
    # Private methods for computing positions and layout
 
471
    #------------------------------------------------------------------------
 
472
 
 
473
    def _reset_cache(self):
 
474
        """ Clears the cached tick positions, labels, and label positions.
 
475
        """
 
476
        self._tick_positions = []
 
477
        self._tick_label_list = []
 
478
        self._tick_label_positions = []
 
479
        return
 
480
 
 
481
    def _compute_tick_positions(self, gc, overlay_component=None):
 
482
        """ Calculates the positions for the tick marks.
 
483
        """
 
484
        if (self.mapper is None):
 
485
            self._reset_cache()
 
486
            self._cache_valid = True
 
487
            return
 
488
 
 
489
        datalow = self.mapper.range.low
 
490
        datahigh = self.mapper.range.high
 
491
        screenhigh = self.mapper.high_pos
 
492
        screenlow = self.mapper.low_pos
 
493
        if overlay_component is not None:
 
494
            origin = getattr(overlay_component, 'origin', 'bottom left')
 
495
            if self.orientation in ("top", "bottom"):
 
496
                if "right" in origin: 
 
497
                    flip_from_gc = True
 
498
                else: 
 
499
                    flip_from_gc = False
 
500
            elif self.orientation in ("left", "right"):
 
501
                if "top" in origin: 
 
502
                    flip_from_gc = True
 
503
                else: 
 
504
                    flip_from_gc = False
 
505
 
 
506
            if flip_from_gc:
 
507
                screenlow, screenhigh = screenhigh, screenlow
 
508
 
 
509
 
 
510
        if (datalow == datahigh) or (screenlow == screenhigh) or \
 
511
           (datalow in [inf, -inf]) or (datahigh in [inf, -inf]):
 
512
            self._reset_cache()
 
513
            self._cache_valid = True
 
514
            return
 
515
 
 
516
        if datalow > datahigh:
 
517
            raise RuntimeError, "DataRange low is greater than high; unable to compute axis ticks."
 
518
 
 
519
        if not self.tick_generator:
 
520
            return
 
521
 
 
522
        if isinstance(self.mapper, LogMapper):
 
523
            scale = 'log'
 
524
        else:
 
525
            scale = 'linear'
 
526
 
 
527
        tick_list = array(self.tick_generator.get_ticks(datalow, datahigh,
 
528
                                                        datalow, datahigh,
 
529
                                                        self.tick_interval,
 
530
                                                        use_endpoints=False,
 
531
                                                        scale=scale), float64)
 
532
 
 
533
        mapped_tick_positions = (array(self.mapper.map_screen(tick_list))-screenlow) / \
 
534
                                            (screenhigh-screenlow)
 
535
        self._tick_positions = around(array([self._axis_vector*tickpos + self._origin_point \
 
536
                                for tickpos in mapped_tick_positions]))
 
537
 
 
538
        if self.small_haxis_style:
 
539
            # If we're a small axis, we want the endpoints to be the labels regardless of
 
540
            # where the ticks are, as the labels represent the bounds, not where the tick
 
541
            # marks are.
 
542
            self._tick_label_list = array([datalow, datahigh])
 
543
            mapped_label_positions = (array(self.mapper.map_screen(self._tick_label_list))-screenlow) / \
 
544
                                     (screenhigh-screenlow)
 
545
            self._tick_label_positions = [self._axis_vector*tickpos + self._origin_point \
 
546
                                          for tickpos in mapped_label_positions]
 
547
        else:
 
548
            self._tick_label_list = tick_list
 
549
            self._tick_label_positions = self._tick_positions
 
550
        return
 
551
 
 
552
 
 
553
    def _compute_labels(self, gc):
 
554
        """Generates the labels for tick marks.  
 
555
        
 
556
        Waits for the cache to become invalid.
 
557
        """
 
558
        self.ticklabel_cache = []
 
559
        formatter = self.tick_label_formatter
 
560
        for i in range(len(self._tick_label_positions)):
 
561
            val = self._tick_label_list[i]
 
562
            if formatter is not None:
 
563
                tickstring = formatter(val)
 
564
            else:
 
565
                tickstring = str(val)
 
566
            ticklabel = Label(text=tickstring,
 
567
                              font=self.tick_label_font,
 
568
                              color=self.tick_label_color)
 
569
            self.ticklabel_cache.append(ticklabel)
 
570
 
 
571
        # TODO: Right now we are hardcoding this handling of a scaled CTM,
 
572
        # eventually it would be nice if we didn't have to do this.
 
573
        ctm = gc.get_ctm()
 
574
        if len(ctm) == 6:
 
575
            # AffineMatrix class
 
576
            scale = array((ctm[0], ctm[3]))
 
577
        elif len(ctm) == 3:
 
578
            # Mac GC
 
579
            scale = array((ctm[0][0], ctm[1][1]))
 
580
        else:
 
581
            scale = array((1.0, 1.0))
 
582
        self._tick_label_bounding_boxes = [array(ticklabel.get_bounding_box(gc), float) / scale
 
583
                                               for ticklabel in self.ticklabel_cache]
 
584
        return
 
585
 
 
586
 
 
587
    def _calculate_geometry(self, overlay_component=None):
 
588
        if overlay_component is not None:
 
589
            if self.orientation == "top":
 
590
                new_origin = [overlay_component.x, overlay_component.y2]
 
591
                inside_vec = [0.0, -1.0]
 
592
            elif self.orientation == "bottom":
 
593
                new_origin = [overlay_component.x, overlay_component.y]
 
594
                inside_vec = [0.0, 1.0]
 
595
            elif self.orientation == "left":
 
596
                new_origin = [overlay_component.x, overlay_component.y]
 
597
                inside_vec = [1.0, 0.0]
 
598
            else:  # self.orientation == "right":
 
599
                new_origin = [overlay_component.x2, overlay_component.y]
 
600
                inside_vec = [-1.0, 0.0]
 
601
            self._origin_point = array(new_origin)
 
602
            self._inside_vector = array(inside_vec)
 
603
        else:
 
604
            #FIXME: Why aren't we setting self._inside_vector here?
 
605
            overlay_component = self
 
606
            new_origin = array(self.position)
 
607
 
 
608
        origin = getattr(overlay_component, "origin", 'bottom left')
 
609
        if self.orientation in ('top', 'bottom'):
 
610
            self._major_axis_size = overlay_component.bounds[0]
 
611
            self._minor_axis_size = overlay_component.bounds[1]
 
612
            self._major_axis = array([1., 0.])
 
613
            self._title_orientation = array([0.,1.])
 
614
            if "right" in origin: 
 
615
                flip_from_gc = True
 
616
            else: 
 
617
                flip_from_gc = False
 
618
            #this could be calculated...
 
619
            self.title_angle = 0.0
 
620
        elif self.orientation in ('left', 'right'):
 
621
            self._major_axis_size = overlay_component.bounds[1]
 
622
            self._minor_axis_size = overlay_component.bounds[0]
 
623
            self._major_axis = array([0., 1.])
 
624
            self._title_orientation = array([-1., 0])
 
625
            origin = getattr(overlay_component, "origin", 'bottom left')
 
626
            if "top" in origin: 
 
627
                flip_from_gc = True
 
628
            else: 
 
629
                flip_from_gc = False
 
630
            if self.orientation == 'left':
 
631
                self.title_angle = 90.0
 
632
            else:
 
633
                self.title_angle = 270.0                
 
634
 
 
635
        if self.ensure_ticks_bounded:
 
636
            self._origin_point -= self._inside_vector*self.tick_in
 
637
 
 
638
        screenhigh = self.mapper.high_pos
 
639
        screenlow = self.mapper.low_pos
 
640
        # TODO: should this be here, or not?
 
641
        if flip_from_gc:
 
642
            screenlow, screenhigh = screenhigh, screenlow
 
643
        
 
644
        self._end_axis_point = (screenhigh-screenlow)*self._major_axis + self._origin_point
 
645
        self._axis_vector = self._end_axis_point - self._origin_point
 
646
        # This is the vector that represents one unit of data space in terms of screen space.
 
647
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector,self._axis_vector))
 
648
        return
 
649
 
 
650
 
 
651
    def _old_calculate_geometry(self):
 
652
        if hasattr(self, 'mapper') and self.mapper is not None:
 
653
            screenhigh = self.mapper.high_pos
 
654
            screenlow = self.mapper.low_pos
 
655
        else:
 
656
            # fixme: this should take into account axis orientation
 
657
            screenhigh = self.x2
 
658
            screenlow = self.x
 
659
 
 
660
        if self.orientation in ('top', 'bottom'):
 
661
            self._major_axis_size = self.bounds[0]
 
662
            self._minor_axis_size = self.bounds[1]
 
663
            self._major_axis = array([1., 0.])
 
664
            self._title_orientation = array([0.,1.])
 
665
            #this could be calculated...
 
666
            self.title_angle = 0.0
 
667
            if self.orientation == 'top':
 
668
                self._origin_point = array(self.position) + self._major_axis * screenlow
 
669
                self._inside_vector = array([0.,-1.])
 
670
            else: #self.oriention == 'bottom'
 
671
                self._origin_point = array(self.position) + array([0., self.bounds[1]]) + self._major_axis*screenlow
 
672
                self._inside_vector = array([0., 1.])
 
673
        elif self.orientation in ('left', 'right'):
 
674
            self._major_axis_size = self.bounds[1]
 
675
            self._minor_axis_size = self.bounds[0]
 
676
            self._major_axis = array([0., 1.])
 
677
            self._title_orientation = array([-1., 0])
 
678
            self.title_angle = 90.0
 
679
            if self.orientation == 'left':
 
680
                self._origin_point = array(self.position) + array([self.bounds[0], 0.]) + self._major_axis*screenlow
 
681
                self._inside_vector = array([1., 0.])
 
682
            else: #self.orientation == 'right'
 
683
                self._origin_point = array(self.position) + self._major_axis*screenlow
 
684
                self._inside_vector = array([-1., 0.])
 
685
 
 
686
#        if self.mapper.high_pos<self.mapper.low_pos:
 
687
#            self._origin_point = self._origin_point + self._axis_
 
688
 
 
689
        if self.ensure_ticks_bounded:
 
690
            self._origin_point -= self._inside_vector*self.tick_in
 
691
 
 
692
        self._end_axis_point = (screenhigh-screenlow)*self._major_axis + self._origin_point
 
693
        self._axis_vector = self._end_axis_point - self._origin_point
 
694
        # This is the vector that represents one unit of data space in terms of screen space.
 
695
        self._axis_pixel_vector = self._axis_vector/sqrt(dot(self._axis_vector,self._axis_vector))
 
696
        return
 
697
 
 
698
 
 
699
    #------------------------------------------------------------------------
 
700
    # Private helper methods
 
701
    #------------------------------------------------------------------------
 
702
 
 
703
    def _rotmatrix(self, theta):
 
704
        """Returns a 2x2 rotation matrix for angle *theta*.
 
705
        """
 
706
        return array([[cos(theta), sin(theta)], [-sin(theta), cos(theta)]], float64)
 
707
 
 
708
    def _center_dist(self, vect, width, height, rotation=0.0):
 
709
        """Given a width and height of a rectangle, this method finds the
 
710
        distance in units of the vector, in the direction of the vector, from
 
711
        the center of the rectangle, to wherever the vector leaves the
 
712
        rectangle. This method is useful for determining where to place text so
 
713
        it doesn't run into other components. """
 
714
        rotvec = transpose(dot(self._rotmatrix(rotation*pi/180.0), transpose(array([vect], float64))))[0]
 
715
        absvec = absolute(rotvec)
 
716
        if absvec[1] != 0:
 
717
            heightdist = (float(height)/2)/float(absvec[1])
 
718
        else:
 
719
            heightdist = 9999999
 
720
        if absvec[0] != 0:
 
721
            widthdist = (float(width)/2)/float(absvec[0])
 
722
        else:
 
723
            widthdist = 99999999
 
724
 
 
725
        return min(heightdist, widthdist)
 
726
 
 
727
 
 
728
    #------------------------------------------------------------------------
 
729
    # Event handlers
 
730
    #------------------------------------------------------------------------
 
731
 
 
732
    def _bounds_changed(self, old, new):
 
733
        super(PlotAxis, self)._bounds_changed(old, new)
 
734
        self._layout_needed = True
 
735
        self._invalidate()
 
736
 
 
737
    def _bounds_items_changed(self, event):
 
738
        super(PlotAxis, self)._bounds_items_changed(event)
 
739
        self._layout_needed = True
 
740
        self._invalidate()
 
741
 
 
742
    def _mapper_changed(self, old, new):
 
743
        if old is not None:
 
744
            old.on_trait_change(self.mapper_updated, "updated", remove=True)
 
745
        if new is not None:
 
746
            new.on_trait_change(self.mapper_updated, "updated")
 
747
        self._invalidate()
 
748
 
 
749
    def mapper_updated(self):
 
750
        """
 
751
        Event handler that is bound to this axis's mapper's **updated** event
 
752
        """
 
753
        self._invalidate()
 
754
 
 
755
    def _position_changed(self, old, new):
 
756
        super(PlotAxis, self)._position_changed(old, new)
 
757
        self._cache_valid = False
 
758
 
 
759
    def _position_items_changed(self, event):
 
760
        super(PlotAxis, self)._position_items_changed(event)
 
761
        self._cache_valid = False
 
762
    
 
763
    def _position_changed_for_component(self):
 
764
        self._cache_valid = False
 
765
 
 
766
    def _position_items_changed_for_component(self):
 
767
        self._cache_valid = False
 
768
 
 
769
    def _bounds_changed_for_component(self):
 
770
        self._cache_valid = False
 
771
        self._layout_needed = True
 
772
 
 
773
    def _bounds_items_changed_for_component(self):
 
774
        self._cache_valid = False
 
775
        self._layout_needed = True
 
776
 
 
777
    def _origin_changed_for_component(self):
 
778
        self._invalidate()
 
779
    
 
780
    def _updated_fired(self):
 
781
        """If the axis bounds changed, redraw."""
 
782
        self._cache_valid = False
 
783
        return
 
784
 
 
785
    def _invalidate(self):
 
786
        self._cache_valid = False
 
787
        self.invalidate_draw()
 
788
        if self.component:
 
789
            self.component.invalidate_draw()
 
790
#            self.component.request_redraw()
 
791
#        else:
 
792
#            self.request_redraw()
 
793
        return
 
794
 
 
795
    def _component_changed(self):
 
796
        if self.mapper is not None:
 
797
            # If there is a mapper set, just leave it be.
 
798
            return
 
799
 
 
800
        # Try to pick the most appropriate mapper for our orientation 
 
801
        # and what information we can glean from our component.
 
802
        attrmap = { "left": ("ymapper", "y_mapper", "value_mapper"),
 
803
                    "bottom": ("xmapper", "x_mapper", "index_mapper"), }
 
804
        attrmap["right"] = attrmap["left"]
 
805
        attrmap["top"] = attrmap["bottom"]
 
806
 
 
807
        component = self.component
 
808
        attr1, attr2, attr3 = attrmap[self.orientation]
 
809
        for attr in attrmap[self.orientation]:
 
810
            if hasattr(component, attr):
 
811
                self.mapper = getattr(component, attr)
 
812
                break
 
813
        return
 
814
 
 
815
 
 
816
    #------------------------------------------------------------------------
 
817
    # The following event handlers just invalidate our previously computed
 
818
    # Label instances and backbuffer if any of our visual attributes change.
 
819
    # TODO: refactor this stuff and the caching of contained objects (e.g. Label)
 
820
    #------------------------------------------------------------------------
 
821
 
 
822
    def _title_changed(self):
 
823
        self.invalidate_draw()
 
824
        if self.component:
 
825
            self.component.invalidate_draw()
 
826
#            self.component.request_redraw()
 
827
#        else:
 
828
#            self.request_redraw()
 
829
 
 
830
    def _title_color_changed(self):
 
831
        return self._invalidate()
 
832
 
 
833
    def _title_font_changed(self):
 
834
        return self._invalidate()
 
835
 
 
836
    def _tick_weight_changed(self):
 
837
        return self._invalidate()
 
838
 
 
839
    def _tick_color_changed(self):
 
840
        return self._invalidate()
 
841
 
 
842
    def _tick_font_changed(self):
 
843
        return self._invalidate()
 
844
 
 
845
    def _tick_label_font_changed(self):
 
846
        return self._invalidate()
 
847
 
 
848
    def _tick_label_color_changed(self):
 
849
        return self._invalidate()
 
850
 
 
851
    def _tick_in_changed(self):
 
852
        return self._invalidate()
 
853
 
 
854
    def _tick_out_changed(self):
 
855
        return self._invalidate()
 
856
 
 
857
    def _tick_visible_changed(self):
 
858
        return self._invalidate()
 
859
 
 
860
    def _tick_interval_changed(self):
 
861
        return self._invalidate()
 
862
 
 
863
    def _axis_line_color_changed(self):
 
864
        return self._invalidate()
 
865
 
 
866
    def _axis_line_weight_changed(self):
 
867
        return self._invalidate()
 
868
 
 
869
    def _axis_line_style_changed(self):
 
870
        return self._invalidate()
 
871
 
 
872
    def _orientation_changed(self):
 
873
        return self._invalidate()
 
874
 
 
875
    #------------------------------------------------------------------------
 
876
    # Persistence-related methods
 
877
    #------------------------------------------------------------------------
 
878
 
 
879
    def __getstate__(self):
 
880
        dont_pickle = [
 
881
            '_tick_list',
 
882
            '_tick_positions',
 
883
            '_tick_label_list',
 
884
            '_tick_label_positions',
 
885
            '_tick_label_bounding_boxes',
 
886
            '_major_axis_size',
 
887
            '_minor_axis_size',
 
888
            '_major_axis',
 
889
            '_title_orientation',
 
890
            '_title_angle',
 
891
            '_origin_point',
 
892
            '_inside_vector',
 
893
            '_axis_vector',
 
894
            '_axis_pixel_vector',
 
895
            '_end_axis_point',
 
896
            '_ticklabel_cache',
 
897
            '_cache_valid'
 
898
           ]
 
899
 
 
900
        state = super(PlotAxis,self).__getstate__()
 
901
        for key in dont_pickle:
 
902
            if state.has_key(key):
 
903
                del state[key]
 
904
 
 
905
        return state
 
906
 
 
907
    def __setstate__(self, state):
 
908
        super(PlotAxis,self).__setstate__(state)
 
909
        self._mapper_changed(None, self.mapper)
 
910
        self._reset_cache()
 
911
        self._cache_valid = False
 
912
        return
 
913
 
 
914
 
 
915
# EOF ########################################################################