~ubuntu-branches/debian/squeeze/pyopencl/squeeze

« back to all changes in this revision

Viewing changes to examples/demo_mandelbrot.py

  • Committer: Bazaar Package Importer
  • Author(s): Tomasz Rybak
  • Date: 2010-05-31 19:29:00 UTC
  • Revision ID: james.westby@ubuntu.com-20100531192900-ll7guuro37nntr4y
Tags: upstream-0.92~beta+git20100709
ImportĀ upstreamĀ versionĀ 0.92~beta+git20100709

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# I found this example for PyCuda here:
 
2
# http://wiki.tiker.net/PyCuda/Examples/Mandelbrot
 
3
#
 
4
# I adapted it for PyOpenCL. Hopefully it is useful to someone.
 
5
# July 2010, HolgerRapp@gmx.net
 
6
#
 
7
# Original readme below these lines. 
 
8
 
 
9
# Mandelbrot calculate using GPU, Serial numpy and faster numpy
 
10
# Use to show the speed difference between CPU and GPU calculations
 
11
# ian@ianozsvald.com March 2010
 
12
 
 
13
# Based on vegaseat's TKinter/numpy example code from 2006
 
14
# http://www.daniweb.com/code/snippet216851.html#
 
15
# with minor changes to move to numpy from the obsolete Numeric
 
16
 
 
17
import numpy as np
 
18
import time
 
19
 
 
20
import numpy
 
21
import numpy.linalg as la
 
22
 
 
23
import pyopencl as cl
 
24
 
 
25
# You can choose a calculation routine below (calc_fractal), uncomment
 
26
# one of the three lines to test the three variations
 
27
# Speed notes are listed in the same place
 
28
 
 
29
# set width and height of window, more pixels take longer to calculate
 
30
w = 400
 
31
h = 400
 
32
 
 
33
def calc_fractal_opencl(q, maxiter):
 
34
    ctx = cl.Context(cl.get_platforms()[0].get_devices())
 
35
    queue = cl.CommandQueue(ctx)
 
36
 
 
37
    output = np.empty(q.shape, dtype=np.uint64)# resize(np.array(0,), q.shape)
 
38
 
 
39
    mf = cl.mem_flags
 
40
    q_opencl = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=q)
 
41
    output_opencl = cl.Buffer(ctx, mf.WRITE_ONLY, output.nbytes)
 
42
 
 
43
    prg = cl.Program(ctx, """
 
44
    __kernel void mandelbrot(__global float2 *q,
 
45
                     __global long *output, long const maxiter)
 
46
    {
 
47
        int gid = get_global_id(0);
 
48
        float nreal, real = 0;
 
49
        float imag = 0;
 
50
        for(int curiter = 0; curiter < maxiter; curiter++) {
 
51
            nreal = real*real - imag*imag + q[gid][0];
 
52
            imag = 2* real*imag + q[gid][1];
 
53
            real = nreal;
 
54
 
 
55
            if (real*real + imag*imag > 4.) {
 
56
                 output[gid] = curiter;
 
57
                 break;
 
58
            }
 
59
        }
 
60
    }
 
61
    """).build()
 
62
 
 
63
    prg.mandelbrot(queue, output.shape, None, q_opencl,
 
64
            output_opencl, np.int32(maxiter))
 
65
 
 
66
    cl.enqueue_read_buffer(queue, output_opencl, output).wait()
 
67
 
 
68
    return output
 
69
 
 
70
 
 
71
 
 
72
def calc_fractal_serial(q, maxiter):
 
73
    # calculate z using numpy
 
74
    # this routine unrolls calc_fractal_numpy as an intermediate
 
75
    # step to the creation of calc_fractal_opencl
 
76
    # it runs slower than calc_fractal_numpy
 
77
    z = np.zeros(q.shape, np.complex64)
 
78
    output = np.resize(np.array(0,), q.shape)
 
79
    for i in range(len(q)):
 
80
        for iter in range(maxiter):
 
81
            z[i] = z[i]*z[i] + q[i]
 
82
            if abs(z[i]) > 2.0:
 
83
                q[i] = 0+0j
 
84
                z[i] = 0+0j
 
85
                output[i] = iter
 
86
    return output
 
87
 
 
88
def calc_fractal_numpy(q, maxiter):
 
89
    # calculate z using numpy, this is the original
 
90
    # routine from vegaseat's URL
 
91
    output = np.resize(np.array(0,), q.shape)
 
92
    z = np.zeros(q.shape, np.complex64)
 
93
 
 
94
    for iter in range(maxiter):
 
95
        z = z*z + q
 
96
        done = np.greater(abs(z), 2.0)
 
97
        q = np.where(done,0+0j, q)
 
98
        z = np.where(done,0+0j, z)
 
99
        output = np.where(done, iter, output)
 
100
    return output
 
101
 
 
102
# choose your calculation routine here by uncommenting one of the options
 
103
calc_fractal = calc_fractal_opencl
 
104
# calc_fractal = calc_fractal_serial
 
105
# calc_fractal = calc_fractal_numpy
 
106
 
 
107
if __name__ == '__main__':
 
108
    import Tkinter as tk
 
109
    import Image          # PIL
 
110
    import ImageTk        # PIL
 
111
 
 
112
 
 
113
    class Mandelbrot(object):
 
114
        def __init__(self):
 
115
            # create window
 
116
            self.root = tk.Tk()
 
117
            self.root.title("Mandelbrot Set")
 
118
            self.create_image()
 
119
            self.create_label()
 
120
            # start event loop
 
121
            self.root.mainloop()
 
122
 
 
123
 
 
124
        def draw(self, x1, x2, y1, y2, maxiter=300):
 
125
            # draw the Mandelbrot set, from numpy example
 
126
            xx = np.arange(x1, x2, (x2-x1)/w*2)
 
127
            yy = np.arange(y2, y1, (y1-y2)/h*2) * 1j
 
128
            q = np.ravel(xx+yy[:, np.newaxis]).astype(np.complex64)
 
129
 
 
130
            start_main = time.time()
 
131
            output = calc_fractal(q, maxiter)
 
132
            end_main = time.time()
 
133
 
 
134
            secs = end_main - start_main
 
135
            print "Main took", secs
 
136
 
 
137
            output = (output + (256*output) + (256**2)*output) * 8
 
138
            # convert output to a string
 
139
            self.mandel = output.tostring()
 
140
 
 
141
        def create_image(self):
 
142
            """"
 
143
            create the image from the draw() string
 
144
            """
 
145
            self.im = Image.new("RGB", (w/2, h/2))
 
146
            # you can experiment with these x and y ranges
 
147
            self.draw(-2.13, 0.77, -1.3, 1.3)
 
148
            self.im.fromstring(self.mandel, "raw", "RGBX", 0, -1)
 
149
 
 
150
        def create_label(self):
 
151
            # put the image on a label widget
 
152
            self.image = ImageTk.PhotoImage(self.im)
 
153
            self.label = tk.Label(self.root, image=self.image)
 
154
            self.label.pack()
 
155
 
 
156
    # test the class
 
157
    test = Mandelbrot()
 
158