~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to tune/blas/gemv/emit_rmvT.c

  • Committer: Bazaar Package Importer
  • Author(s): Camm Maguire
  • Date: 2002-04-13 10:07:52 UTC
  • Revision ID: james.westby@ubuntu.com-20020413100752-va9zm0rd4gpurdkq
Tags: upstream-3.2.1ln
ImportĀ upstreamĀ versionĀ 3.2.1ln

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *             Automatically Tuned Linear Algebra Software v3.2
 
3
 *                    (C) Copyright 2000 R. Clint Whaley                     
 
4
 *
 
5
 * Redistribution and use in source and binary forms, with or without
 
6
 * modification, are permitted provided that the following conditions
 
7
 * are met:
 
8
 *   1. Redistributions of source code must retain the above copyright
 
9
 *      notice, this list of conditions and the following disclaimer.
 
10
 *   2. Redistributions in binary form must reproduce the above copyright
 
11
 *      notice, this list of conditions, and the following disclaimer in the
 
12
 *      documentation and/or other materials provided with the distribution.
 
13
 *   3. The name of the University of Tennessee, the ATLAS group,
 
14
 *      or the names of its contributers may not be used to endorse
 
15
 *      or promote products derived from this software without specific
 
16
 *      written permission.
 
17
 *
 
18
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 
19
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 
20
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 
21
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 
22
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 
23
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 
24
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 
25
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 
26
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 
27
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 
28
 * POSSIBILITY OF SUCH DAMAGE. 
 
29
 *
 
30
 */
 
31
 
 
32
 
 
33
#include <stdio.h>
 
34
#include <stdlib.h>
 
35
#include <string.h>
 
36
#include <assert.h>
 
37
#include <ctype.h>
 
38
 
 
39
#define Mlowcase(C) ( ((C) > 64 && (C) < 91) ? (C) | 32 : (C) )
 
40
#define Mmin(x, y) ( (x) > (y) ? (y) : (x) )
 
41
#define Mmax(x, y) ( (x) > (y) ? (x) : (y) )
 
42
int GetPower2(int n)
 
43
{
 
44
   int pwr2, i;
 
45
 
 
46
   if (n == 1) return(0);
 
47
   for (pwr2=0, i=1; i < n; i <<= 1, pwr2++);
 
48
   if (i != n) pwr2 = 0;
 
49
   return(pwr2);
 
50
}
 
51
 
 
52
#define ShiftThresh 2
 
53
char *GetDiv(int N, char *inc)
 
54
{
 
55
   static char ln[256];
 
56
   int pwr2 = GetPower2(N);
 
57
   if (N == 1) sprintf(ln, "%s", inc);
 
58
   else if (pwr2) sprintf(ln, "((%s) >> %d)", inc, pwr2);
 
59
   else sprintf(ln, "((%s) / %d)", inc, N);
 
60
   return(ln);
 
61
}
 
62
 
 
63
char *GetInc(int N, char *inc)
 
64
{
 
65
   static char ln0[256];
 
66
   char ln[256];
 
67
   char *p=ln;
 
68
   int i, n=N, iPLUS=0;
 
69
 
 
70
   if (n == 0)
 
71
   {
 
72
      ln[0] = '0';
 
73
      ln[1] = '\0';
 
74
   }
 
75
   while(n > 1)
 
76
   {
 
77
      for (i=0; n >= (1<<i); i++);
 
78
      if ( (1 << i) > n) i--;
 
79
      if (iPLUS++) *p++ = '+';
 
80
      sprintf(p, "((%s) << %d)", inc, i);
 
81
      p += strlen(p);
 
82
      n -= (1 << i);
 
83
   }
 
84
   if (n == 1)
 
85
   {
 
86
      if (iPLUS++) *p++ = '+';
 
87
      sprintf(p, "%s", inc);
 
88
   }
 
89
   if (iPLUS > ShiftThresh) sprintf(ln0, "(%d*(%s))", N, inc);
 
90
   else if (iPLUS) sprintf(ln0, "(%s)", ln);
 
91
   else sprintf(ln0, "%s", ln);
 
92
   return(ln0);
 
93
}
 
94
 
 
95
static void GetAij(int nY, int ngap, int gap, int gapmul, int ia,
 
96
                   int *icol, int *igap, int *gapi, int *imul)
 
97
 
98
   ia = ia % (nY * ngap * gap * gapmul);
 
99
   *imul = ia / (nY * ngap * gap);
 
100
   ia -= *imul * nY * ngap * gap;
 
101
   *gapi = ia / (nY * ngap);
 
102
   ia -= *gapi * nY * ngap;
 
103
   *igap = ia / nY;
 
104
   ia -= *igap * nY;
 
105
   *icol = ia;
 
106
}
 
107
 
 
108
static int GetAi(int ngap, int gap, int igap, int gapi, int imul)
 
109
{
 
110
   return(imul*gap*ngap+igap*gap+gapi);
 
111
}
 
112
 
 
113
static void mvTXbody(FILE *fpout, char *spc, int lat, int nY, int ngap, int gap,
 
114
                     int gapmul, int pfA, int pfX, int Yregs, int CLEANUP)
 
115
{
 
116
   int i, j, k, nops, xelts, op=0;
 
117
   int FETCHA=1, FETCHX=1;
 
118
   int ia=pfA, ix=pfX, iy=0, myregs;
 
119
   int icol, iacc, igap, imul, Acol, Aacc, agap, Amul, xcol, xacc, xgap, xmul;
 
120
 
 
121
   nops = gapmul * gap * ngap * nY;
 
122
   xelts = gapmul * gap * ngap;
 
123
   if (lat)
 
124
   {
 
125
      op = lat;
 
126
      ix += lat/nY;
 
127
      if (ix > xelts) ix = ix % xelts;
 
128
   }
 
129
 
 
130
   for (i=0; i < nops; i++, ia++, iy++)
 
131
   {
 
132
      if (ia == nops)
 
133
      {
 
134
         ia = 0;
 
135
         for (j=0; j < nY; j++)
 
136
         {
 
137
            if (xelts > 1) fprintf(fpout, "%s   pA%d += %d;\n", spc, j, xelts);
 
138
            else fprintf(fpout, "%s   pA%d++;\n", spc, j);
 
139
         }
 
140
         if (CLEANUP) FETCHA = 0;
 
141
      }
 
142
      if (ix == xelts)
 
143
      {
 
144
         ix = 0;
 
145
         if (xelts > 1) fprintf(fpout, "%s   X += %d;\n", spc, xelts);
 
146
         else fprintf(fpout, "%s   X++;\n", spc);
 
147
         if (CLEANUP) FETCHX = 0;
 
148
      }
 
149
      GetAij(nY, ngap, gap, gapmul, i, &icol, &iacc, &igap, &imul);
 
150
      GetAij(nY, ngap, gap, gapmul, ia, &Acol, &Aacc, &agap, &Amul);
 
151
      myregs = Yregs / nY;
 
152
      if (icol < Yregs % nY) myregs++;
 
153
      if (lat)  /* seperate multiply & add code */
 
154
      {
 
155
         fprintf(fpout, "%s   rY%d += rA%d;\n", spc, 
 
156
                 icol+((iy/nY)%myregs)*nY, i %pfA);
 
157
         if (!CLEANUP || op < nops)
 
158
            fprintf(fpout, "%s   rA%d *= rX%d;\n", spc, op % pfA,
 
159
                    ((i+lat)/nY)%pfX);
 
160
      }
 
161
      else fprintf(fpout, "%s   rY%d += rA%d * rX%d;\n", 
 
162
                   spc, icol+((iy/nY)%myregs)*nY, i%pfA, (i/nY)%pfX);
 
163
      op++;
 
164
      if (FETCHA)
 
165
      {
 
166
         k = GetAi(ngap, gap, Aacc, agap, Amul);
 
167
         if (k) fprintf(fpout, "%s   rA%d = pA%d[%d];\n", spc, i%pfA, Acol, k);
 
168
         else fprintf(fpout, "%s   rA%d = *pA%d;\n", spc, i%pfA, Acol);
 
169
      }
 
170
      if (op%nY==0 && FETCHX)
 
171
      {
 
172
         GetAij(1, ngap, gap, gapmul, ix, &xcol, &xacc, &xgap, &xmul);
 
173
         k = GetAi(ngap, gap, xacc, xgap, xmul);
 
174
         if (k) fprintf(fpout, "%s   rX%d = X[%d];\n",spc, ((i+lat)/nY)%pfX, k);
 
175
         else fprintf(fpout, "%s   rX%d = *X;\n", spc, ((i+lat)/nY)%pfX);
 
176
         ix++;
 
177
      }
 
178
   }
 
179
}
 
180
 
 
181
static void FetchY(FILE *fpout, char *spc, int pre, int beta, int nY, int Yregs,
 
182
                   char *breg)
 
183
{
 
184
   int i, j;
 
185
 
 
186
   if (beta != 0)
 
187
   {
 
188
      fprintf(fpout, "%s   rY0 = *Y;\n", spc);
 
189
      for (j=1; j < nY; j++) fprintf(fpout, "%s   rY%d = Y[%d];\n", spc, j, j);
 
190
   }
 
191
   if (beta != 0 && beta != 1)
 
192
   {
 
193
      fprintf(fpout, "%s   %s = beta;\n", spc, breg);
 
194
      for (j=0; j < nY; j++) fprintf(fpout, "%s   rY%d *= %s;\n", spc,j, breg);
 
195
   }
 
196
   if (Yregs > nY || beta == 0)
 
197
   {
 
198
      fprintf(fpout, "%s   ", spc);
 
199
      for (i=(beta!=0)*nY; i < Yregs; i++) fprintf(fpout, "rY%d = ", i);
 
200
      if (pre == 's' || pre == 'c') fprintf(fpout, "0.0f;\n");
 
201
      else fprintf(fpout, "0.0;\n");
 
202
   }
 
203
}
 
204
 
 
205
static void FetchAX(FILE *fpout, char *spc, int nY, int ngap, int gap,
 
206
                    int gapmul, int pfA, int pfX)
 
207
{
 
208
   int ix, ia, icol, iacc, igap, imul, k;
 
209
 
 
210
   for (ix=0; ix < pfX; ix++) 
 
211
   {
 
212
      GetAij(1, ngap, gap, gapmul, ix, &icol, &iacc, &igap, &imul);
 
213
      k = GetAi(ngap, gap, iacc, igap, imul);
 
214
      if (k) fprintf(fpout, "%s   rX%d = X[%d];\n", spc, ix, k);
 
215
      else fprintf(fpout, "%s   rX%d = *X;\n", spc, ix);
 
216
   }
 
217
   for (ia = 0; ia < pfA; ia++)
 
218
   {
 
219
      GetAij(nY, ngap, gap, gapmul, ia, &icol, &iacc, &igap, &imul);
 
220
      k = GetAi(ngap, gap, iacc, igap, imul);
 
221
      if (k) fprintf(fpout, "%s   rA%d = pA%d[%d];\n", spc, ia, icol, k);
 
222
      else fprintf(fpout, "%s   rA%d = *pA%d;\n", spc, ia, icol);
 
223
   }
 
224
}
 
225
 
 
226
static void StartPipe(FILE *fpout, char *spc, int lat, int nY,
 
227
                      int ngap, int gap, int gapmul, int pfX)
 
228
{
 
229
   int i, k, icol, wgap, gapi, imul;
 
230
 
 
231
   for (i=0; i < lat; i++)
 
232
   {
 
233
      fprintf(fpout, "%s   rA%d *= rX%d;\n", spc, i, (i/nY)%pfX);
 
234
      if ((i+1) % nY == 0)
 
235
      {
 
236
         GetAij(1, ngap, gap, gapmul, pfX+i/nY, &icol, &wgap, &gapi, &imul);
 
237
         k = GetAi(ngap, gap, wgap, gapi, imul);
 
238
         if (k) fprintf(fpout, "%s   rX%d = X[%d];\n", spc, (i/nY)%pfX, k);
 
239
         else fprintf(fpout, "%s   rX%d = *X;\n", spc, (i/nY)%pfX);
 
240
      }
 
241
   }
 
242
}
 
243
 
 
244
static void CombY(FILE *fpout, char *spc, int nY, int Yregs)
 
245
/*
 
246
 * use binary tree for adding up multiple accumulators
 
247
 */
 
248
{
 
249
   int i, j, d, n=(Yregs+nY-1)/nY;
 
250
   for (d=1; d < n; d <<= 1)
 
251
      for (i=0; i < n; i += (d<<1))
 
252
         for (j=0; j < nY; j++)
 
253
            if (j+(i+d)*nY < Yregs)
 
254
               fprintf(fpout, "%s   rY%d += rY%d;\n", spc, j+i*nY, j+(i+d)*nY);
 
255
}
 
256
 
 
257
static void XCleanup(FILE *fpout, char *spc, char pre, int lat, int nY,
 
258
                     int ngap, int gap, int gapmul, int pfA, int pfX, int Yregs)
 
259
{
 
260
#if 1
 
261
   int mingap, j, ia, ix, iy, il, FirstTime=1;
 
262
 
 
263
   fprintf(fpout, "%s   if (X != stXN)\n%s   {\n", spc, spc);
 
264
   spc -= 3;
 
265
   if (pre == 's') mingap = 8;
 
266
   else mingap = 4;
 
267
   for (j=1; j < ngap*gap*gapmul; j <<= 1);
 
268
   if (j > ngap*gap*gapmul) /* not power of two */
 
269
   {
 
270
      if (GetPower2(gap))
 
271
      {
 
272
         if (!GetPower2(ngap)) ngap = 2;
 
273
         gapmul = j / (ngap*gap);
 
274
      }
 
275
      else
 
276
      {
 
277
         ngap = gapmul = 1;
 
278
         gap = j;
 
279
      }
 
280
   }
 
281
   if (j <= 2) FirstTime=0;
 
282
   for (; j; j >>= 1)
 
283
   {
 
284
      if (FirstTime)
 
285
         fprintf(fpout, "%s   if ( (ptrdiff_t)(stXN-X) < %d ) goto cu%d;\n", 
 
286
                 spc, j, j>>1);
 
287
      else if (j == 1)
 
288
         fprintf(fpout, "%s   if ( (ptrdiff_t)(stXN-X) == 1 )\n", spc);
 
289
      else fprintf(fpout, "%s   if ( (ptrdiff_t)(stXN-X) >= %d )\n", spc, j);
 
290
      for (ia=pfA; (nY*ngap*gap*gapmul) % ia; ia--);
 
291
      for (ix=pfX; (ngap * gap * gapmul)% ix; ix--);
 
292
      if (lat)
 
293
      {
 
294
         for (il=Mmin(ia,lat); (nY*ngap*gap*gapmul) % il; il--);
 
295
         if (lat - il > 2 || il < 2 || il < pfA) il = 0;
 
296
      }
 
297
      else il = 0;
 
298
      iy = Mmin(Yregs, nY*ngap*gap*gapmul);
 
299
      fprintf(fpout, 
 
300
"%s   { /* lat=%d, ngap=%d, gap=%d, gapmul=%d, pfA=%d, pfX=%d, yregs=%d */\n", 
 
301
              spc, il, ngap, gap, gapmul, ia, ix, iy);
 
302
      spc -= 3;
 
303
      FetchAX(fpout, spc, nY, ngap, gap, gapmul, ia, ix);
 
304
      if (il) StartPipe(fpout, spc, il, nY, ngap, gap, gapmul, pfX);
 
305
      mvTXbody(fpout, spc, il, nY, ngap, gap, gapmul, ia, ix, iy, 1);
 
306
      spc += 3;
 
307
      fprintf(fpout, "%s   }\n", spc);
 
308
      if (FirstTime) fprintf(fpout, "cu%d:\n", j>>1);
 
309
      if (gapmul > 1) gapmul >>= 1;
 
310
      else if (ngap > 2) ngap >>= 1;
 
311
      else if (gap > mingap) gap >>= 1;
 
312
      else if (ngap > 1) ngap >>= 1;
 
313
      else if (gap > 1) gap >>= 1;
 
314
      FirstTime = 0;
 
315
   }
 
316
   spc += 3;
 
317
   fprintf(fpout, "%s   } /* done X cleanup */\n\n", spc);
 
318
#else
 
319
   int i;
 
320
 
 
321
   fprintf(fpout, "%s   if (X != stXN)\n%s   {\n", spc, spc);
 
322
   spc -= 3;
 
323
   if (nY == 1) i = 1;
 
324
   else for (i=pfA; nY % i; i--);
 
325
   FetchAX(fpout, spc, nY, 1, 1, 1, i, 1);
 
326
 
 
327
   fprintf(fpout, "%s   if (X != stXN_1)\n%s   {\n", spc, spc);
 
328
   spc -= 3;
 
329
   fprintf(fpout, "%s   do /* while (X != stXN_1) */\n%s   {\n", spc, spc);
 
330
   spc -= 3;
 
331
   mvTXbody(fpout, spc, lat, nY, 1, 1, 1, i, 1, nY, 0);
 
332
   spc += 3;
 
333
   fprintf(fpout, "%s   }\n%s   while(X != stXN_1);\n", spc, spc);
 
334
   spc += 3;
 
335
   fprintf(fpout, "%s   }\n", spc);
 
336
 
 
337
   mvTXbody(fpout, spc, lat, nY, 1, 1, 1, i, 1, nY, 1);
 
338
   spc += 3;
 
339
   fprintf(fpout, "%s   } /* finish cleanup */\n\n", spc);
 
340
#endif
 
341
}
 
342
 
 
343
static void Xloop(FILE *fpout, char *spc, char pre, int lat, int nY,
 
344
                  int ngap, int gap, int gapmul, int pfA, int pfX, int Yregs)
 
345
{
 
346
   int i, nu = ngap * gap * gapmul;
 
347
 
 
348
   fprintf(fpout, "%s   if (N >= %d)\n%s   {\n", spc, 2*nu, spc);
 
349
   spc -= 3;
 
350
   FetchAX(fpout, spc, nY, ngap, gap, gapmul, pfA, pfX);
 
351
   if (lat) StartPipe(fpout, spc, lat, nY, ngap, gap, gapmul, pfX);
 
352
   fprintf(fpout, "%s   do /* while (X != stX) */\n%s   {\n", spc, spc);
 
353
   spc -= 3;
 
354
   mvTXbody(fpout, spc, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs, 0);
 
355
   spc += 3;
 
356
   fprintf(fpout, "%s   }\n%s   while(X != stX);\n", spc, spc);
 
357
   mvTXbody(fpout, spc, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs, 1);
 
358
   spc += 3;
 
359
   fprintf(fpout, "%s   }\n\n", spc);
 
360
 
 
361
   XCleanup(fpout, spc, pre, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs);
 
362
   CombY(fpout, spc, nY, Yregs);
 
363
}
 
364
 
 
365
static void emit_mvT
 
366
   (FILE *fpout, char *rout, char pre, char *styp, char *typ, int lat,
 
367
    int beta, int nY, int ngap, int gap, int gapmul, int pfA, int pfX,
 
368
    int Yregs)
 
369
{
 
370
   char *spcs = "                                        ";
 
371
   char *spc = spcs+40;
 
372
   char ln[128];
 
373
   char *bnam[3] = {"b0", "b1", "bX"};
 
374
   int i, j, nu = ngap * gap * gapmul;
 
375
   if (beta != 0 && beta != 1) beta = 2;
 
376
 
 
377
   if (nY > 1) fprintf(fpout, "#include \"atlas_level1.h\"\n");
 
378
   fprintf(fpout, "#include <stddef.h>\n");
 
379
   if (rout == NULL)
 
380
   {
 
381
      rout = ln;
 
382
      sprintf(ln, "ATL_%cgemvT_a1_x1_%s_y1\n", pre, bnam[beta]);
 
383
   }
 
384
   fprintf(fpout, "void %s", rout);
 
385
   fprintf(fpout, "/*\n * lat=%d, beta=%s, nY=%d, ngap=%d, gap=%d, gapmul=%d, pfA=%d, pfX=%d\n */\n",
 
386
           lat, bnam[beta], nY, ngap, gap, gapmul, pfA, pfX);
 
387
   fprintf(fpout, "   (const int M, const int N, const %s alpha,\n", styp);
 
388
   fprintf(fpout, 
 
389
           "   const %s *A, const int lda, const %s *X, const int incX,\n", 
 
390
           typ, typ);
 
391
   fprintf(fpout, "   const %s beta, %s *Y, const int incY)\n", styp, typ);
 
392
   fprintf(fpout, "{\n");
 
393
   fprintf(fpout, "   register %s rX0", typ);
 
394
   for (i=1; i < pfX; i++) fprintf(fpout, ", rX%d", i);
 
395
   for (i=0; i < pfA; i++) fprintf(fpout, ", rA%d", i);
 
396
   for (i=0; i < Yregs; i++) fprintf(fpout, ", rY%d", i);
 
397
   fprintf(fpout, ";\n");
 
398
   fprintf(fpout, "   const %s *pA0=A", typ);
 
399
   for (i=1; i < nY; i++) fprintf(fpout, ", *pA%d=pA%d+lda", i, i-1);
 
400
   fprintf(fpout, ";\n");
 
401
   fprintf(fpout, "   const int n = %s;\n", GetInc(nu,GetDiv(nu, "N")));
 
402
   fprintf(fpout, "   const int m = %s;\n", GetInc(nY,GetDiv(nY, "M")));
 
403
   fprintf(fpout, 
 
404
           "   const %s *stX = X + n - %d, *stXN = X + N, *stXN_1 = stXN-1;\n",
 
405
           typ, nu);
 
406
   fprintf(fpout, "   %s *stY = Y + m;\n", typ);
 
407
   fprintf(fpout, "%s   const int incA = %s - N;\n", spc, GetInc(nY, "lda"));
 
408
   fprintf(fpout, "\n");
 
409
 
 
410
   if (nY > 1)
 
411
   {
 
412
      fprintf(fpout, "%s   if (m)\n%s   {\n", spc, spc);
 
413
      spc -= 3;
 
414
   }
 
415
   fprintf(fpout, "%s   do /* while (Y != stY) */\n%s   {\n", spc, spc);
 
416
   spc -= 3;
 
417
 
 
418
   FetchY(fpout, spc, pre, beta, nY, Yregs, "rX0");
 
419
 
 
420
   Xloop(fpout, spc, pre, lat, nY, ngap, gap, gapmul, pfA, pfX, Yregs);
 
421
 
 
422
   fprintf(fpout, "%s   X -= N;\n", spc);
 
423
   for (j=0; j < nY; j++)
 
424
   {
 
425
      fprintf(fpout, "%s   pA%d += incA;\n", spc, j);
 
426
      fprintf(fpout, "%s   Y[%d] = rY%d;\n", spc, j, j);
 
427
   }
 
428
   if (nY > 1) fprintf(fpout, "%s   Y += %d;\n", spc, nY);
 
429
   spc += 3;
 
430
   if (nY == 1) fprintf(fpout, "%s   }\n%s   while(++Y != stY);\n", spc, spc);
 
431
   else fprintf(fpout, "%s   }\n%s   while(Y != stY);\n", spc, spc);
 
432
   if (nY > 1)
 
433
   {
 
434
      spc += 3;
 
435
      fprintf(fpout, "%s   }\n", spc);
 
436
      fprintf(fpout, "%s   if (m != M)\n%s   {\n", spc, spc);
 
437
      spc -= 3;
 
438
      if (beta != 0 && beta != 1) fprintf(fpout, "%s   rX0 = beta;\n", spc);
 
439
      if (nY > 2)
 
440
      {
 
441
         fprintf(fpout, "%s   stY += M-m;\n", spc);
 
442
         fprintf(fpout, "%s   do /* while (Y != stY) */\n%s   {\n", spc, spc);
 
443
         spc -= 3;
 
444
      }
 
445
      if (beta != 0) fprintf(fpout, "%s   rY0 = *Y;\n", spc);
 
446
      if (beta != 0 && beta != 1) fprintf(fpout, "%s   rY0 *= rX0;\n", spc);
 
447
      if (beta == 0) fprintf(fpout, "%s   rY0 = ", spc);
 
448
      else fprintf(fpout, "%s   rY0 += ", spc);
 
449
      fprintf(fpout, "ATL_%cdot(N, pA0, 1, X, 1);\n", pre);
 
450
      fprintf(fpout, "%s   *Y = rY0;\n", spc);
 
451
      if (nY > 2)
 
452
      {
 
453
         fprintf(fpout, "%s   pA0 += lda;\n", spc);
 
454
         spc += 3;
 
455
         fprintf(fpout, "%s   }\n%s   while (++Y != stY);\n", spc, spc);
 
456
      }
 
457
      spc += 3;
 
458
      fprintf(fpout, "%s   } /* end Y cleanup */;\n", spc);
 
459
   }
 
460
   fprintf(fpout, "}\n");
 
461
}
 
462
 
 
463
void PrintUsage(char *nam)
 
464
{
 
465
   fprintf(stderr, "USAGE: %s -l <latency> -Y <nY> -G <ngap> -g <gap> -M <gapmul> -A <# of regs for A> -X <# of regs for X> -y <# regs for Y> -f <file> -R <rout> -b <beta>\n", nam);
 
466
   exit(-1);
 
467
}
 
468
 
 
469
void GetFlags(int nargs, char **args, int *lat, int *beta, int *nY,
 
470
              int *ngap, int *gap, int *gapmul, int *pfA, int *pfX, 
 
471
              int *Yregs, char *pre, char *styp, char *typ, 
 
472
              FILE **fpout, char **rout)
 
473
{
 
474
   int i;
 
475
   *beta = -3;
 
476
   *lat = 0;
 
477
   *nY = *Yregs = *ngap = *gap = *gapmul = *pfA = *pfX = 1;
 
478
   *pre = 'd'; 
 
479
   *rout = NULL;
 
480
   *fpout = stdout;
 
481
 
 
482
   for (i=1; i < nargs; i++)
 
483
   {
 
484
      if (args[i][0] != '-') PrintUsage(args[0]);
 
485
      switch(args[i][1])
 
486
      {
 
487
      case 'R':
 
488
         *rout = args[++i];
 
489
         break;
 
490
      case 'f':
 
491
         *fpout = fopen(args[++i], "w");
 
492
         assert(*fpout);
 
493
         break;
 
494
      case 'b':
 
495
         *beta = atoi(args[++i]);
 
496
         break;
 
497
      case 'l':
 
498
         *lat = atoi(args[++i]);
 
499
         break;
 
500
      case 'y':
 
501
         *nY = atoi(args[++i]);
 
502
         break;
 
503
      case 'G':
 
504
         *ngap = atoi(args[++i]);
 
505
         break;
 
506
      case 'g':
 
507
         *gap = atoi(args[++i]);
 
508
         break;
 
509
      case 'M':
 
510
         *gapmul = atoi(args[++i]);
 
511
         break;
 
512
      case 'A':
 
513
         *pfA = atoi(args[++i]);
 
514
         break;
 
515
      case 'X':
 
516
         *pfX = atoi(args[++i]);
 
517
         break;
 
518
      case 'Y':
 
519
         *Yregs = atoi(args[++i]);
 
520
         break;
 
521
      case 'p':
 
522
         i++;
 
523
         *pre = Mlowcase(args[i][0]);
 
524
         break;
 
525
      default:
 
526
         PrintUsage(args[0]);
 
527
      }
 
528
   }
 
529
   i = *nY * *ngap * *gap * *gapmul;
 
530
   assert(i % *pfA == 0);
 
531
   if (*lat)
 
532
   {
 
533
      assert(i % *lat == 0);
 
534
      assert(*pfX + *lat / *nY <= i);
 
535
      assert(*pfA > *lat);
 
536
   }
 
537
   assert((*ngap * *gap * *gapmul)% (*pfX)  == 0);
 
538
   assert(*Yregs >= *nY);
 
539
 
 
540
   switch (*pre)
 
541
   {
 
542
   case 'z':
 
543
      sprintf(styp, "double*");
 
544
      sprintf(typ, "double");
 
545
      break;
 
546
   case 'd':
 
547
      sprintf(styp, "double");
 
548
      sprintf(typ, "double");
 
549
      break;
 
550
   case 'c':
 
551
      sprintf(styp, "float*");
 
552
      sprintf(typ, "float");
 
553
      break;
 
554
   case 's':
 
555
      sprintf(styp, "float");
 
556
      sprintf(typ, "float");
 
557
      break;
 
558
   default:
 
559
      PrintUsage(args[0]);
 
560
   }
 
561
}
 
562
 
 
563
main(int nargs, char **args)
 
564
{
 
565
   char pre, styp[32], typ[32], *rout;
 
566
   int lat;     /* 0: combined muladd inst; X: seperate multiply & add, lat=X */
 
567
   int nY;      /* number of columns of A to operate on (# of dot products) */
 
568
   int Yregs;   /* number of registers to use for Y */
 
569
   int ngap;    /* number of accumulaters to use within each dot prod */
 
570
   int gap;     /* gap size in column */
 
571
   int gapmul;  /* number of unrollings of whole mess to do */
 
572
   int pfA;     /* number of registers to use in prefetching A */
 
573
   int pfX;     /* number of registers to use in prefetching X */
 
574
   int i, beta;
 
575
   FILE *fpout;
 
576
 
 
577
   GetFlags(nargs, args, &lat, &beta, &nY, &ngap, &gap, &gapmul, &pfA, &pfX,
 
578
            &Yregs, &pre, styp, typ, &fpout, &rout);
 
579
 
 
580
   if (beta == -3) /* generate all beta cases */
 
581
   {
 
582
      for (i=0; i < 3; i++)
 
583
      {
 
584
         if (i == 0 || i == 1) fprintf(fpout, "#ifdef BETA%d\n\n", i);
 
585
         else fprintf(fpout, "#ifdef BETAX\n\n");
 
586
         emit_mvT(fpout, rout, pre, styp, typ, lat, i, nY, ngap, gap, gapmul,
 
587
                  pfA, pfX, Yregs);
 
588
         fprintf(fpout, "\n\n#endif\n");
 
589
      }
 
590
   }
 
591
   else emit_mvT(fpout, rout, pre, styp, typ, lat, beta, nY, ngap, gap,
 
592
                 gapmul, pfA, pfX, Yregs);
 
593
   exit(0);
 
594
}