2
* Automatically Tuned Linear Algebra Software v3.2
3
* (C) Copyright 2000 R. Clint Whaley
5
* Redistribution and use in source and binary forms, with or without
6
* modification, are permitted provided that the following conditions
8
* 1. Redistributions of source code must retain the above copyright
9
* notice, this list of conditions and the following disclaimer.
10
* 2. Redistributions in binary form must reproduce the above copyright
11
* notice, this list of conditions, and the following disclaimer in the
12
* documentation and/or other materials provided with the distribution.
13
* 3. The name of the University of Tennessee, the ATLAS group,
14
* or the names of its contributers may not be used to endorse
15
* or promote products derived from this software without specific
18
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
22
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
* POSSIBILITY OF SUCH DAMAGE.
39
#define Mlowcase(C) ( ((C) > 64 && (C) < 91) ? (C) | 32 : (C) )
40
#define Mmin(x, y) ( (x) > (y) ? (y) : (x) )
41
#define Mmax(x, y) ( (x) > (y) ? (x) : (y) )
46
if (n == 1) return(0);
47
for (pwr2=0, i=1; i < n; i <<= 1, pwr2++);
53
char *GetDiv(int N, char *inc)
56
int pwr2 = GetPower2(N);
57
if (N == 1) sprintf(ln, "%s", inc);
58
else if (pwr2) sprintf(ln, "((%s) >> %d)", inc, pwr2);
59
else sprintf(ln, "((%s) / %d)", inc, N);
63
char *GetInc(int N, char *inc)
77
for (i=0; n >= (1<<i); i++);
78
if ( (1 << i) > n) i--;
79
if (iPLUS++) *p++ = '+';
80
sprintf(p, "((%s) << %d)", inc, i);
86
if (iPLUS++) *p++ = '+';
87
sprintf(p, "%s", inc);
89
if (iPLUS > ShiftThresh) sprintf(ln0, "(%d*(%s))", N, inc);
90
else if (iPLUS) sprintf(ln0, "(%s)", ln);
91
else sprintf(ln0, "%s", ln);
95
static void GetAij(int nY, int ngap, int gap, int gapmul, int ia,
96
int *icol, int *igap, int *gapi, int *imul)
98
ia = ia % (nY * ngap * gap * gapmul);
99
*imul = ia / (nY * ngap * gap);
100
ia -= *imul * nY * ngap * gap;
101
*gapi = ia / (nY * ngap);
102
ia -= *gapi * nY * ngap;
108
static int GetAi(int ngap, int gap, int igap, int gapi, int imul)
110
return(imul*gap*ngap+igap*gap+gapi);
113
static void mvTXbody(FILE *fpout, char *spc, int lat, int nY, int ngap, int gap,
114
int gapmul, int pfA, int pfX, int Yregs, int CLEANUP)
116
int i, j, k, nops, xelts, op=0;
117
int FETCHA=1, FETCHX=1;
118
int ia=pfA, ix=pfX, iy=0, myregs;
119
int icol, iacc, igap, imul, Acol, Aacc, agap, Amul, xcol, xacc, xgap, xmul;
121
nops = gapmul * gap * ngap * nY;
122
xelts = gapmul * gap * ngap;
127
if (ix > xelts) ix = ix % xelts;
130
for (i=0; i < nops; i++, ia++, iy++)
135
for (j=0; j < nY; j++)
137
if (xelts > 1) fprintf(fpout, "%s pA%d += %d;\n", spc, j, xelts);
138
else fprintf(fpout, "%s pA%d++;\n", spc, j);
140
if (CLEANUP) FETCHA = 0;
145
if (xelts > 1) fprintf(fpout, "%s X += %d;\n", spc, xelts);
146
else fprintf(fpout, "%s X++;\n", spc);
147
if (CLEANUP) FETCHX = 0;
149
GetAij(nY, ngap, gap, gapmul, i, &icol, &iacc, &igap, &imul);
150
GetAij(nY, ngap, gap, gapmul, ia, &Acol, &Aacc, &agap, &Amul);
152
if (icol < Yregs % nY) myregs++;
153
if (lat) /* seperate multiply & add code */
155
fprintf(fpout, "%s rY%d += rA%d;\n", spc,
156
icol+((iy/nY)%myregs)*nY, i %pfA);
157
if (!CLEANUP || op < nops)
158
fprintf(fpout, "%s rA%d *= rX%d;\n", spc, op % pfA,
161
else fprintf(fpout, "%s rY%d += rA%d * rX%d;\n",
162
spc, icol+((iy/nY)%myregs)*nY, i%pfA, (i/nY)%pfX);
166
k = GetAi(ngap, gap, Aacc, agap, Amul);
167
if (k) fprintf(fpout, "%s rA%d = pA%d[%d];\n", spc, i%pfA, Acol, k);
168
else fprintf(fpout, "%s rA%d = *pA%d;\n", spc, i%pfA, Acol);
170
if (op%nY==0 && FETCHX)
172
GetAij(1, ngap, gap, gapmul, ix, &xcol, &xacc, &xgap, &xmul);
173
k = GetAi(ngap, gap, xacc, xgap, xmul);
174
if (k) fprintf(fpout, "%s rX%d = X[%d];\n",spc, ((i+lat)/nY)%pfX, k);
175
else fprintf(fpout, "%s rX%d = *X;\n", spc, ((i+lat)/nY)%pfX);
181
static void FetchY(FILE *fpout, char *spc, int pre, int beta, int nY, int Yregs,
188
fprintf(fpout, "%s rY0 = *Y;\n", spc);
189
for (j=1; j < nY; j++) fprintf(fpout, "%s rY%d = Y[%d];\n", spc, j, j);
191
if (beta != 0 && beta != 1)
193
fprintf(fpout, "%s %s = beta;\n", spc, breg);
194
for (j=0; j < nY; j++) fprintf(fpout, "%s rY%d *= %s;\n", spc,j, breg);
196
if (Yregs > nY || beta == 0)
198
fprintf(fpout, "%s ", spc);
199
for (i=(beta!=0)*nY; i < Yregs; i++) fprintf(fpout, "rY%d = ", i);
200
if (pre == 's' || pre == 'c') fprintf(fpout, "0.0f;\n");
201
else fprintf(fpout, "0.0;\n");
205
static void FetchAX(FILE *fpout, char *spc, int nY, int ngap, int gap,
206
int gapmul, int pfA, int pfX)
208
int ix, ia, icol, iacc, igap, imul, k;
210
for (ix=0; ix < pfX; ix++)
212
GetAij(1, ngap, gap, gapmul, ix, &icol, &iacc, &igap, &imul);
213
k = GetAi(ngap, gap, iacc, igap, imul);
214
if (k) fprintf(fpout, "%s rX%d = X[%d];\n", spc, ix, k);
215
else fprintf(fpout, "%s rX%d = *X;\n", spc, ix);
217
for (ia = 0; ia < pfA; ia++)
219
GetAij(nY, ngap, gap, gapmul, ia, &icol, &iacc, &igap, &imul);
220
k = GetAi(ngap, gap, iacc, igap, imul);
221
if (k) fprintf(fpout, "%s rA%d = pA%d[%d];\n", spc, ia, icol, k);
222
else fprintf(fpout, "%s rA%d = *pA%d;\n", spc, ia, icol);
226
static void StartPipe(FILE *fpout, char *spc, int lat, int nY,
227
int ngap, int gap, int gapmul, int pfX)
229
int i, k, icol, wgap, gapi, imul;
231
for (i=0; i < lat; i++)
233
fprintf(fpout, "%s rA%d *= rX%d;\n", spc, i, (i/nY)%pfX);
236
GetAij(1, ngap, gap, gapmul, pfX+i/nY, &icol, &wgap, &gapi, &imul);
237
k = GetAi(ngap, gap, wgap, gapi, imul);
238
if (k) fprintf(fpout, "%s rX%d = X[%d];\n", spc, (i/nY)%pfX, k);
239
else fprintf(fpout, "%s rX%d = *X;\n", spc, (i/nY)%pfX);
244
static void CombY(FILE *fpout, char *spc, int nY, int Yregs)
246
* use binary tree for adding up multiple accumulators
249
int i, j, d, n=(Yregs+nY-1)/nY;
250
for (d=1; d < n; d <<= 1)
251
for (i=0; i < n; i += (d<<1))
252
for (j=0; j < nY; j++)
253
if (j+(i+d)*nY < Yregs)
254
fprintf(fpout, "%s rY%d += rY%d;\n", spc, j+i*nY, j+(i+d)*nY);
257
static void XCleanup(FILE *fpout, char *spc, char pre, int lat, int nY,
258
int ngap, int gap, int gapmul, int pfA, int pfX, int Yregs)
261
int mingap, j, ia, ix, iy, il, FirstTime=1;
263
fprintf(fpout, "%s if (X != stXN)\n%s {\n", spc, spc);
265
if (pre == 's') mingap = 8;
267
for (j=1; j < ngap*gap*gapmul; j <<= 1);
268
if (j > ngap*gap*gapmul) /* not power of two */
272
if (!GetPower2(ngap)) ngap = 2;
273
gapmul = j / (ngap*gap);
281
if (j <= 2) FirstTime=0;
285
fprintf(fpout, "%s if ( (ptrdiff_t)(stXN-X) < %d ) goto cu%d;\n",
288
fprintf(fpout, "%s if ( (ptrdiff_t)(stXN-X) == 1 )\n", spc);
289
else fprintf(fpout, "%s if ( (ptrdiff_t)(stXN-X) >= %d )\n", spc, j);
290
for (ia=pfA; (nY*ngap*gap*gapmul) % ia; ia--);
291
for (ix=pfX; (ngap * gap * gapmul)% ix; ix--);
294
for (il=Mmin(ia,lat); (nY*ngap*gap*gapmul) % il; il--);
295
if (lat - il > 2 || il < 2 || il < pfA) il = 0;
298
iy = Mmin(Yregs, nY*ngap*gap*gapmul);
300
"%s { /* lat=%d, ngap=%d, gap=%d, gapmul=%d, pfA=%d, pfX=%d, yregs=%d */\n",
301
spc, il, ngap, gap, gapmul, ia, ix, iy);
303
FetchAX(fpout, spc, nY, ngap, gap, gapmul, ia, ix);
304
if (il) StartPipe(fpout, spc, il, nY, ngap, gap, gapmul, pfX);
305
mvTXbody(fpout, spc, il, nY, ngap, gap, gapmul, ia, ix, iy, 1);
307
fprintf(fpout, "%s }\n", spc);
308
if (FirstTime) fprintf(fpout, "cu%d:\n", j>>1);
309
if (gapmul > 1) gapmul >>= 1;
310
else if (ngap > 2) ngap >>= 1;
311
else if (gap > mingap) gap >>= 1;
312
else if (ngap > 1) ngap >>= 1;
313
else if (gap > 1) gap >>= 1;
317
fprintf(fpout, "%s } /* done X cleanup */\n\n", spc);
321
fprintf(fpout, "%s if (X != stXN)\n%s {\n", spc, spc);
324
else for (i=pfA; nY % i; i--);
325
FetchAX(fpout, spc, nY, 1, 1, 1, i, 1);
327
fprintf(fpout, "%s if (X != stXN_1)\n%s {\n", spc, spc);
329
fprintf(fpout, "%s do /* while (X != stXN_1) */\n%s {\n", spc, spc);
331
mvTXbody(fpout, spc, lat, nY, 1, 1, 1, i, 1, nY, 0);
333
fprintf(fpout, "%s }\n%s while(X != stXN_1);\n", spc, spc);
335
fprintf(fpout, "%s }\n", spc);
337
mvTXbody(fpout, spc, lat, nY, 1, 1, 1, i, 1, nY, 1);
339
fprintf(fpout, "%s } /* finish cleanup */\n\n", spc);
343
static void Xloop(FILE *fpout, char *spc, char pre, int lat, int nY,
344
int ngap, int gap, int gapmul, int pfA, int pfX, int Yregs)
346
int i, nu = ngap * gap * gapmul;
348
fprintf(fpout, "%s if (N >= %d)\n%s {\n", spc, 2*nu, spc);
350
FetchAX(fpout, spc, nY, ngap, gap, gapmul, pfA, pfX);
351
if (lat) StartPipe(fpout, spc, lat, nY, ngap, gap, gapmul, pfX);
352
fprintf(fpout, "%s do /* while (X != stX) */\n%s {\n", spc, spc);
354
mvTXbody(fpout, spc, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs, 0);
356
fprintf(fpout, "%s }\n%s while(X != stX);\n", spc, spc);
357
mvTXbody(fpout, spc, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs, 1);
359
fprintf(fpout, "%s }\n\n", spc);
361
XCleanup(fpout, spc, pre, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs);
362
CombY(fpout, spc, nY, Yregs);
366
(FILE *fpout, char *rout, char pre, char *styp, char *typ, int lat,
367
int beta, int nY, int ngap, int gap, int gapmul, int pfA, int pfX,
373
char *bnam[3] = {"b0", "b1", "bX"};
374
int i, j, nu = ngap * gap * gapmul;
375
if (beta != 0 && beta != 1) beta = 2;
377
if (nY > 1) fprintf(fpout, "#include \"atlas_level1.h\"\n");
378
fprintf(fpout, "#include <stddef.h>\n");
382
sprintf(ln, "ATL_%cgemvT_a1_x1_%s_y1\n", pre, bnam[beta]);
384
fprintf(fpout, "void %s", rout);
385
fprintf(fpout, "/*\n * lat=%d, beta=%s, nY=%d, ngap=%d, gap=%d, gapmul=%d, pfA=%d, pfX=%d\n */\n",
386
lat, bnam[beta], nY, ngap, gap, gapmul, pfA, pfX);
387
fprintf(fpout, " (const int M, const int N, const %s alpha,\n", styp);
389
" const %s *A, const int lda, const %s *X, const int incX,\n",
391
fprintf(fpout, " const %s beta, %s *Y, const int incY)\n", styp, typ);
392
fprintf(fpout, "{\n");
393
fprintf(fpout, " register %s rX0", typ);
394
for (i=1; i < pfX; i++) fprintf(fpout, ", rX%d", i);
395
for (i=0; i < pfA; i++) fprintf(fpout, ", rA%d", i);
396
for (i=0; i < Yregs; i++) fprintf(fpout, ", rY%d", i);
397
fprintf(fpout, ";\n");
398
fprintf(fpout, " const %s *pA0=A", typ);
399
for (i=1; i < nY; i++) fprintf(fpout, ", *pA%d=pA%d+lda", i, i-1);
400
fprintf(fpout, ";\n");
401
fprintf(fpout, " const int n = %s;\n", GetInc(nu,GetDiv(nu, "N")));
402
fprintf(fpout, " const int m = %s;\n", GetInc(nY,GetDiv(nY, "M")));
404
" const %s *stX = X + n - %d, *stXN = X + N, *stXN_1 = stXN-1;\n",
406
fprintf(fpout, " %s *stY = Y + m;\n", typ);
407
fprintf(fpout, "%s const int incA = %s - N;\n", spc, GetInc(nY, "lda"));
408
fprintf(fpout, "\n");
412
fprintf(fpout, "%s if (m)\n%s {\n", spc, spc);
415
fprintf(fpout, "%s do /* while (Y != stY) */\n%s {\n", spc, spc);
418
FetchY(fpout, spc, pre, beta, nY, Yregs, "rX0");
420
Xloop(fpout, spc, pre, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs);
422
fprintf(fpout, "%s X -= N;\n", spc);
423
for (j=0; j < nY; j++)
425
fprintf(fpout, "%s pA%d += incA;\n", spc, j);
426
fprintf(fpout, "%s Y[%d] = rY%d;\n", spc, j, j);
428
if (nY > 1) fprintf(fpout, "%s Y += %d;\n", spc, nY);
430
if (nY == 1) fprintf(fpout, "%s }\n%s while(++Y != stY);\n", spc, spc);
431
else fprintf(fpout, "%s }\n%s while(Y != stY);\n", spc, spc);
435
fprintf(fpout, "%s }\n", spc);
436
fprintf(fpout, "%s if (m != M)\n%s {\n", spc, spc);
438
if (beta != 0 && beta != 1) fprintf(fpout, "%s rX0 = beta;\n", spc);
441
fprintf(fpout, "%s stY += M-m;\n", spc);
442
fprintf(fpout, "%s do /* while (Y != stY) */\n%s {\n", spc, spc);
445
if (beta != 0) fprintf(fpout, "%s rY0 = *Y;\n", spc);
446
if (beta != 0 && beta != 1) fprintf(fpout, "%s rY0 *= rX0;\n", spc);
447
if (beta == 0) fprintf(fpout, "%s rY0 = ", spc);
448
else fprintf(fpout, "%s rY0 += ", spc);
449
fprintf(fpout, "ATL_%cdot(N, pA0, 1, X, 1);\n", pre);
450
fprintf(fpout, "%s *Y = rY0;\n", spc);
453
fprintf(fpout, "%s pA0 += lda;\n", spc);
455
fprintf(fpout, "%s }\n%s while (++Y != stY);\n", spc, spc);
458
fprintf(fpout, "%s } /* end Y cleanup */;\n", spc);
460
fprintf(fpout, "}\n");
463
void PrintUsage(char *nam)
465
fprintf(stderr, "USAGE: %s -l <latency> -Y <nY> -G <ngap> -g <gap> -M <gapmul> -A <# of regs for A> -X <# of regs for X> -y <# regs for Y> -f <file> -R <rout> -b <beta>\n", nam);
469
void GetFlags(int nargs, char **args, int *lat, int *beta, int *nY,
470
int *ngap, int *gap, int *gapmul, int *pfA, int *pfX,
471
int *Yregs, char *pre, char *styp, char *typ,
472
FILE **fpout, char **rout)
477
*nY = *Yregs = *ngap = *gap = *gapmul = *pfA = *pfX = 1;
482
for (i=1; i < nargs; i++)
484
if (args[i][0] != '-') PrintUsage(args[0]);
491
*fpout = fopen(args[++i], "w");
495
*beta = atoi(args[++i]);
498
*lat = atoi(args[++i]);
501
*nY = atoi(args[++i]);
504
*ngap = atoi(args[++i]);
507
*gap = atoi(args[++i]);
510
*gapmul = atoi(args[++i]);
513
*pfA = atoi(args[++i]);
516
*pfX = atoi(args[++i]);
519
*Yregs = atoi(args[++i]);
523
*pre = Mlowcase(args[i][0]);
529
i = *nY * *ngap * *gap * *gapmul;
530
assert(i % *pfA == 0);
533
assert(i % *lat == 0);
534
assert(*pfX + *lat / *nY <= i);
537
assert((*ngap * *gap * *gapmul)% (*pfX) == 0);
538
assert(*Yregs >= *nY);
543
sprintf(styp, "double*");
544
sprintf(typ, "double");
547
sprintf(styp, "double");
548
sprintf(typ, "double");
551
sprintf(styp, "float*");
552
sprintf(typ, "float");
555
sprintf(styp, "float");
556
sprintf(typ, "float");
563
main(int nargs, char **args)
565
char pre, styp[32], typ[32], *rout;
566
int lat; /* 0: combined muladd inst; X: seperate multiply & add, lat=X */
567
int nY; /* number of columns of A to operate on (# of dot products) */
568
int Yregs; /* number of registers to use for Y */
569
int ngap; /* number of accumulaters to use within each dot prod */
570
int gap; /* gap size in column */
571
int gapmul; /* number of unrollings of whole mess to do */
572
int pfA; /* number of registers to use in prefetching A */
573
int pfX; /* number of registers to use in prefetching X */
577
GetFlags(nargs, args, &lat, &beta, &nY, &ngap, &gap, &gapmul, &pfA, &pfX,
578
&Yregs, &pre, styp, typ, &fpout, &rout);
580
if (beta == -3) /* generate all beta cases */
582
for (i=0; i < 3; i++)
584
if (i == 0 || i == 1) fprintf(fpout, "#ifdef BETA%d\n\n", i);
585
else fprintf(fpout, "#ifdef BETAX\n\n");
586
emit_mvT(fpout, rout, pre, styp, typ, lat, i, nY, ngap, gap, gapmul,
588
fprintf(fpout, "\n\n#endif\n");
591
else emit_mvT(fpout, rout, pre, styp, typ, lat, beta, nY, ngap, gap,
592
gapmul, pfA, pfX, Yregs);