"""A way over-simplified SRUMMA matrix multiplication implementation.

Assumes square matrices whose dimension is a multiple of the block size.
"""
# NOTE(review): CHUNK_SIZE and MULTIPLIER (and the imports of `ga` and `np`
# used further down) do not appear anywhere in this listing; they must be
# defined before this line -- TODO restore.
N = CHUNK_SIZE*MULTIPLIER

### assign to 'me' this processor's ID
### assign to 'nproc' how many processors in the world
class Task(object):
    """One block product of the tiled multiply, as built by get_task_list().

    Each ``*lo``/``*hi`` pair holds the [row, col] lower and upper corners of
    a block of A (``a``), B (``b``) or C (``c``).
    """

    def __init__(self, alo, ahi, blo, bhi, clo, chi):
        self.alo = alo
        self.ahi = ahi
        self.blo = blo
        self.bhi = bhi
        self.clo = clo
        self.chi = chi

    def __repr__(self):
        return "Task(%s,%s,%s,%s,%s,%s)" % (
            self.alo, self.ahi, self.blo, self.bhi, self.clo, self.chi)
def get_task_list(chunk_size, multiplier):
    """Return the list of block-multiply tasks for the tiled product C = A*B.

    The matrices are tiled into a multiplier x multiplier grid of
    chunk_size x chunk_size blocks.  For each block (row_chunk, col_chunk) of
    C there is one task per inner index i, pairing block (row_chunk, i) of A
    with block (i, col_chunk) of B.  That gives multiplier**3 tasks, ordered
    by row_chunk, then col_chunk, then i.  Corners are [row, col] lists.
    """
    # Built with append: indexing a pre-sized list by a hand-kept `count`
    # needs that counter initialized, advanced, and the list returned.
    task_list = []
    for row_chunk in range(multiplier):
        for col_chunk in range(multiplier):
            # corners of the C block that every task in this group feeds
            clo = [row_chunk * chunk_size, col_chunk * chunk_size]
            chi = [(row_chunk + 1) * chunk_size, (col_chunk + 1) * chunk_size]
            for i in range(multiplier):
                alo = [row_chunk * chunk_size, i * chunk_size]
                ahi = [(row_chunk + 1) * chunk_size, (i + 1) * chunk_size]
                blo = [i * chunk_size, col_chunk * chunk_size]
                bhi = [(i + 1) * chunk_size, (col_chunk + 1) * chunk_size]
                task_list.append(Task(alo, ahi, blo, bhi, clo, chi))
    return task_list
def srumma(g_a, g_b, g_c, chunk_size, multiplier):
    """Accumulate the product of g_a and g_b into g_c, one block at a time.

    The multiplier**3 tasks from get_task_list() are split statically among
    the ``nproc`` processes (module globals ``me``/``nproc``).  The last
    process also takes the leftover tasks.  The loop is software-pipelined:
    the blocks for the next task are requested with non-blocking gets before
    the previous task's blocks are waited on and multiplied, so communication
    overlaps computation.

    This is an exercise template: the steps marked ``###`` still have to be
    written against the Global Arrays API, so it cannot run as-is.
    """
    # statically partition the task list among nprocs
    task_list = get_task_list(chunk_size, multiplier)
    ntasks = multiplier**3 // nproc
    # NOTE(review): the lines defining start/stop were missing from the
    # listing; reconstructed around the surviving `stop +=` remainder line.
    start = me*ntasks
    stop = (me+1)*ntasks
    if me == nproc-1:
        stop += multiplier**3 % nproc
    if start >= stop:
        # No tasks for this process (only when nproc > multiplier**3).
        # Without this, task_list[start] would be fetched and accumulated
        # again, double-counting it.
        return

    # the srumma algorithm, more or less
    task_prev = task_list[start]
    ### use a nonblocking get to request first block and nb handle from 'g_a'
    ### and assign to 'a_prev' and 'a_nb_prev'
    ### use a nonblocking get to request first block and nb handle from 'g_b'
    ### and assign to 'b_prev' and 'b_nb_prev'
    for i in range(start+1, stop):
        task_next = task_list[i]
        ### use a nonblocking get to request next block and nb handle from 'g_a'
        ### and assign to 'a_next' and 'a_nb_next'
        ### use a nonblocking get to request next block and nb handle from 'g_b'
        ### and assign to 'b_next' and 'b_nb_next'
        ### wait on the previous nb handle for 'g_a'
        ### wait on the previous nb handle for 'g_b'
        result = np.dot(a_prev, b_prev)
        ### accumulate the result into 'g_c' at the previous block location
        # Shift the pipeline: the task/blocks just requested become "previous".
        # task_prev must advance too, or every accumulate would target the
        # first task's C block.
        task_prev = task_next
        a_prev, a_nb_prev = a_next, a_nb_next
        b_prev, b_nb_prev = b_next, b_nb_next
    # drain the pipeline: the last task's blocks are still outstanding
    ### wait on the previous nb handle for 'g_a'
    ### wait on the previous nb handle for 'g_b'
    result = np.dot(a_prev, b_prev)
    ### accumulate the result into 'g_c' at the previous block location
def verify_using_ga(g_a, g_b, g_c):
# Check g_c against a reference product computed with ga.gemm.
81
# scratch array to hold the reference result, presumably shaped like g_c
g_v = ga.duplicate(g_c)
82
# no transposes, N x N x N, alpha=1, beta=0: presumably g_v = A*B -- confirm
ga.gemm(False,False,N,N,N,1,g_a,g_b,0,g_v)
86
# NOTE(review): the listing skips lines 83-85 and 87-92. `c` and `v` are never
# assigned in this file and no return statement is visible -- TODO restore.
# val is 1 when |sum(c - v)| exceeds 1e-4, else 0.  Summing the signed
# differences lets errors of opposite sign cancel; np.abs(c-v).max() would be
# a stricter test.
val = int(np.abs(np.sum(c-v))>0.0001)
93
def verify_using_np(g_a, g_b, g_c):
# Check g_c against a reference product computed locally, presumably with
# np.dot given the function name -- confirm.
# NOTE(review): the listing skips lines 94-97 and 99+. The lines that obtain
# `c` and the reference `v` are missing, and no return statement is visible
# -- TODO restore.
98
# val is 1 when |sum(c - v)| exceeds 1e-4, else 0.  Summing the signed
# differences lets errors of opposite sign cancel; np.abs(c-v).max() would be
# a stricter test.
val = int(np.abs(np.sum(c-v))>0.0001)
102
if __name__ == '__main__':
# Script entry (Python 2 syntax).  It creates three N x N double-precision
# global arrays, fills A and B with random values, multiplies them with
# srumma(), then checks C two ways.
103
# srumma() divides MULTIPLIER**3 tasks among nproc processes, so every
# process gets at least one task only when nproc <= MULTIPLIER**3.
if nproc > MULTIPLIER**3:
105
# NOTE(review): listing lines 104 and 106 are missing; nothing visible here
# stops the script after this message -- confirm it exits.
print "You must use less than %s processors" % (MULTIPLIER**3+1)
107
g_a = ga.create(ga.C_DBL, [N,N])
108
g_b = ga.create(ga.C_DBL, [N,N])
109
g_c = ga.create(ga.C_DBL, [N,N])
110
# put some fake data into input arrays A and B
112
# NOTE(review): listing lines 111 and 114 are missing.  If every process runs
# these puts, each writes different random data over the whole array.
# Presumably they are meant to run on one process only -- confirm.
ga.put(g_a, np.random.random(N*N))
113
ga.put(g_b, np.random.random(N*N))
117
# NOTE(review): srumma accumulates into g_c, so g_c must start at zero; the
# lines between the puts and this call are missing -- confirm it is zeroed
# and that the processes synchronize first.
srumma(g_a, g_b, g_c, CHUNK_SIZE, MULTIPLIER)
121
# trailing comma: Python 2 print without a newline
print "verifying using ga.gemm...",
122
# NOTE(review): the lines that report `ok` (listing 123-128, 131+) are missing.
ok = verify_using_ga(g_a, g_b, g_c)
129
print "verifying using np.dot...",
130
ok = verify_using_np(g_a, g_b, g_c)