~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to tune/blas/gemm/gmmsearch.c

  • Committer: Package Import Robot
  • Author(s): Sébastien Villemot
  • Date: 2013-06-11 15:58:16 UTC
  • mfrom: (1.1.3 upstream)
  • mto: (2.2.21 experimental)
  • mto: This revision was merged to the branch mainline in revision 26.
  • Revision ID: package-import@ubuntu.com-20130611155816-b72z8f621tuhbzn0
Tags: upstream-3.10.1
Import upstream version 3.10.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *             Automatically Tuned Linear Algebra Software v3.10.1
 
3
 * Copyright (C) 2010 R. Clint Whaley
 
4
 *
 
5
 * Redistribution and use in source and binary forms, with or without
 
6
 * modification, are permitted provided that the following conditions
 
7
 * are met:
 
8
 *   1. Redistributions of source code must retain the above copyright
 
9
 *      notice, this list of conditions and the following disclaimer.
 
10
 *   2. Redistributions in binary form must reproduce the above copyright
 
11
 *      notice, this list of conditions, and the following disclaimer in the
 
12
 *      documentation and/or other materials provided with the distribution.
 
13
 *   3. The name of the ATLAS group or the names of its contributers may
 
14
 *      not be used to endorse or promote products derived from this
 
15
 *      software without specific written permission.
 
16
 *
 
17
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 
18
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 
19
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 
20
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
 
21
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 
22
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 
23
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 
24
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 
25
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 
26
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 
27
 * POSSIBILITY OF SUCH DAMAGE.
 
28
 *
 
29
 */
 
30
#include <stdio.h>
 
31
#include <stdlib.h>
 
32
#include <assert.h>
 
33
#include "atlas_misc.h"
 
34
#include "atlas_mmtesttime.h"
 
35
 
 
36
#define MAXLAT 6
 
37
int GetGoodLat(int MULADD, int kb, int mu, int nu, int ku, int lat)
 
38
{
 
39
   int slat, blat, i, ii = mu*nu*ku;
 
40
   if (MULADD) return(lat);
 
41
   if ( (lat > 1) && (kb > ku) && ((ii/lat)*lat != ii) )  /* lat won't work */
 
42
   {
 
43
      for (i=lat; i; i--) if ( (ii/i) * i == ii ) break;
 
44
      slat = i;
 
45
      for (i=lat; i < MAXLAT; i++) if ( (ii/i) * i == ii ) break;
 
46
      blat = i;
 
47
      if ( (ii/blat)*blat != ii ) blat = slat;
 
48
      if (slat < 2) lat = blat;
 
49
      else if (lat-slat < blat-lat) lat = slat;
 
50
      else lat = blat;
 
51
   }
 
52
   return(lat);
 
53
}
 
54
 
 
55
int GetUniqueMuNus(int nregs, int muladd, int lat, int *mus, int *nus)
 
56
/*
 
57
 * RETURNS: number of unique MU,NU combos; always allow at least 1x1
 
58
 */
 
59
{
 
60
   int i, j, k, n=0;
 
61
 
 
62
   for (j=1; j <= nregs; j++)
 
63
   {
 
64
      for (i=1; i <= nregs; i++)
 
65
      {
 
66
         k = (muladd) ? 0 : lat;
 
67
         if ((i != 1 || j != 1) && i*j+i+1+k > nregs) continue;
 
68
         if (mus)
 
69
         {
 
70
            mus[n] = i;
 
71
            nus[n] = j;
 
72
         }
 
73
         n++;
 
74
      }
 
75
   }
 
76
   return(n);
 
77
}
 
78
 
 
79
#ifdef DEBUG
 
80
void PrintMUNUs(int N, int *mus, int *nus, double *fpls)
 
81
{
 
82
   int i;
 
83
   for (i=0; i < N; i++)
 
84
   {
 
85
      if (fpls)
 
86
         printf("%3d. MU=%d, NU=%d, fpl=%.3f\n", i, mus[i], nus[i], fpls[i]);
 
87
      else
 
88
         printf("%3d. MU=%d, NU=%d\n", i, mus[i], nus[i]);
 
89
   }
 
90
}
 
91
#endif
 
92
 
 
93
void SortByFlpLd(int N, int *mus, int *nus, double *FPL)
 
94
/*
 
95
 * Simple selection sort, sorting from best (greatest) flops/load to worst
 
96
 * ties in mflop are broken by taking the most square one, and if they
 
97
 * are equally square, then take the one with the bigger mu.
 
98
 */
 
99
{
 
100
   int i, j, imax, mindim, mindimB;
 
101
   double fpl, fplB;
 
102
 
 
103
   #ifdef DEBUG
 
104
      printf("\nUNSORTED:\n");
 
105
      PrintMUNUs(N, mus, nus, NULL);
 
106
   #endif
 
107
   for (i=0; i < N-1; i++)
 
108
   {
 
109
      imax = i;
 
110
      mindimB = (mus[i] <= nus[i]) ? mus[i] : nus[i];
 
111
      fplB = (2.0 * mus[i] * nus[i]) / (mus[i] + nus[i]);
 
112
      for (j=i+1; j < N; j++)
 
113
      {
 
114
          fpl = (2.0 * mus[j] * nus[j]) / (mus[j] + nus[j]);
 
115
          if (fpl > fplB)
 
116
          {
 
117
             imax = j;
 
118
             fplB = fpl;
 
119
             mindimB = (mus[j] <= nus[j]) ? mus[j] : nus[j];
 
120
          }
 
121
          else if (fpl == fplB)
 
122
          {
 
123
             mindim = (mus[j] <= nus[j]) ? mus[j] : nus[j];
 
124
             if (mindim > mindimB)
 
125
             {
 
126
                imax = j;
 
127
                mindimB = mindim;
 
128
             }
 
129
/*
 
130
 *           For symmetric shapes, choose the one with a bigger mu
 
131
 */
 
132
             else if (mindim == mindimB)
 
133
             {
 
134
                if (mus[j] > mindim)
 
135
                   imax = j;
 
136
             }
 
137
          }
 
138
      }
 
139
      if (imax != i)
 
140
      {
 
141
          j = mus[i];
 
142
          mus[i] = mus[imax];
 
143
          mus[imax] = j;
 
144
          j = nus[i];
 
145
          nus[i] = nus[imax];
 
146
          nus[imax] = j;
 
147
      }
 
148
      if (FPL)
 
149
         FPL[i] = fplB;
 
150
   }
 
151
   FPL[i] = (2.0 * mus[i] * nus[i]) / (mus[i] + nus[i]);
 
152
   #ifdef DEBUG
 
153
      printf("\n\nSORTED:\n");
 
154
      PrintMUNUs(N, mus, nus, FPL);
 
155
   #endif
 
156
}
 
157
 
 
158
#define LOWBOUND 0.6857
 
159
void GetMuNus(int nregs, int muladd, int lat, int *NGOOD, int *N0,
 
160
              int **mus, int **nus, double **fpls)
 
161
{
 
162
   int N, i;
 
163
   double fplB, *f;
 
164
 
 
165
   N = GetUniqueMuNus(nregs, muladd, lat, NULL, NULL);
 
166
   *mus = malloc(N*sizeof(int));
 
167
   *nus = malloc(N*sizeof(int));
 
168
   *fpls = malloc(N*sizeof(double));
 
169
   assert(*mus && *nus && *fpls);
 
170
   GetUniqueMuNus(nregs, muladd, lat, *mus, *nus);
 
171
   SortByFlpLd(N, *mus, *nus, *fpls);
 
172
   f = *fpls;
 
173
   fplB = LOWBOUND * f[0];
 
174
   for (i=1; i < N && f[i] >= fplB; i++);
 
175
   *NGOOD = i;
 
176
   *N0 = N;
 
177
}
 
178
 
 
179
int GetSafeGoodMuNu(int nreg, int muladd, int lat,
 
180
                    int N, int *mus, int *nus, double *fpls)
 
181
/*
 
182
 * Find the good value to compare agains the "bad" ones; should be safe on not
 
183
 * overflowing registers
 
184
 * NOTE : assumes mus/nus already sorted by flops/load
 
185
 */
 
186
{
 
187
   int k, i;
 
188
   k = (muladd) ? 0 : lat;
 
189
   for (i=0; i < N; i++)
 
190
      if (mus[i]*nus[i]+mus[i]+nus[i]+k+4 <= nreg)
 
191
         return(i);
 
192
   return(0);
 
193
}
 
194
 
 
195
void GetSafeMUNU(int nreg, int muladd, int lat, int *MU, int *NU)
 
196
{
 
197
   int N, Ng, i;
 
198
   int *mus, *nus;
 
199
   double *fpls;
 
200
 
 
201
   GetMuNus(nreg, muladd, lat, &Ng, &N, &mus, &nus, &fpls);
 
202
   i = GetSafeGoodMuNu(nreg, muladd, lat, N, mus, nus, fpls);
 
203
   *MU = mus[i];
 
204
   *NU = nus[i];
 
205
   free(mus);
 
206
   free(nus);
 
207
   free(fpls);
 
208
}
 
209
void GetMulAdd(char pre, int *MULADD, int *lat)
 
210
{
 
211
   char nam[64], ln[128];
 
212
   FILE *fp;
 
213
 
 
214
   sprintf(nam, "res/%cMULADD", pre);
 
215
   if (!FileExists(nam))
 
216
   {
 
217
      sprintf(ln, "make RunMulAdd pre=%c maxlat=%d mflop=%d\n", pre, 6, 200);
 
218
      assert(system(ln) == 0);
 
219
   }
 
220
   fp = fopen(nam, "r");
 
221
   assert(fp != NULL);
 
222
   fscanf(fp, "%d", MULADD);
 
223
   fscanf(fp, "%d", lat);
 
224
   fclose(fp);
 
225
}
 
226
 
 
227
void PrintUsage(char *name, int ierr, char *flag)
 
228
{
 
229
   fprintf(stderr,
 
230
           "%s searches for the best kernel that emit_mm.c can produce\n",
 
231
           name);
 
232
   fprintf(stderr, "For all gemm parameters (eg., nb) if they are not specified or\nspecified as 0, then the search determines them,\notherwise they are forced to the commandline specification.\n\n");
 
233
 
 
234
   if (ierr > 0)
 
235
      fprintf(stderr, "Bad argument #%d: '%s'\n",
 
236
              ierr, flag ? flag : "Not enough arguments");
 
237
   else if (ierr < 0)
 
238
      fprintf(stderr, "ERROR: %s\n", flag);
 
239
   fprintf(stderr, "USAGE: %s [flags]:\n", name);
 
240
   fprintf(stderr, "   -v # : higher numbers print out more\n");
 
241
   fprintf(stderr, "   -p [s,d,c,z]: set precision prefix \n");
 
242
   fprintf(stderr, "   -b <nb> : blocking factor \n");
 
243
   fprintf(stderr, "   -r <nreg> : number of registers to assume\n");
 
244
   fprintf(stderr, "   -k <ku> : K unrolling factor \n");
 
245
   fprintf(stderr, "   -l <lat> : multiply latency to assume\n");
 
246
   fprintf(stderr, "   -M <muladd> : -1: search 0: separate mul&add : else MACC\n");
 
247
   fprintf(stderr, "   -o <outfile> : defaults to res/<pre>gMMRES.sum\n");
 
248
   exit(ierr ? ierr : -1);
 
249
}
 
250
 
 
251
char GetFlags(int nargs, char **args, int *verb, int *nregs, int *nb,
 
252
              int *ku, int *MACC, int *lat, char **outfile)
 
253
{
 
254
   char pre, ch;
 
255
   int i;
 
256
 
 
257
   *outfile = NULL;
 
258
   *verb = 1;
 
259
   *MACC = -1;
 
260
   *lat = *nregs = *nb = *ku = 0;
 
261
   pre = 'd';
 
262
   for (i=1; i < nargs; i++)
 
263
   {
 
264
      if (args[i][0] != '-')
 
265
         PrintUsage(args[0], i, args[i]);
 
266
      switch(args[i][1])
 
267
      {
 
268
      case 'o':
 
269
         if (++i >= nargs)
 
270
            PrintUsage(args[0], i, NULL);
 
271
         *outfile = DupString(args[i]);
 
272
         break;
 
273
      case 'p':  /* -p <pre> */
 
274
         if (++i >= nargs)
 
275
            PrintUsage(args[0], i, NULL);
 
276
 
 
277
         ch = tolower(args[i][0]);
 
278
         assert(ch == 's' || ch == 'd' || ch == 'c' || ch == 'z');
 
279
         pre = ch;
 
280
         break;
 
281
      case 'M':
 
282
         if (++i >= nargs)
 
283
            PrintUsage(args[0], i, NULL);
 
284
         *MACC = atoi(args[i]);
 
285
         break;
 
286
      case 'v':
 
287
         if (++i >= nargs)
 
288
            PrintUsage(args[0], i, NULL);
 
289
         *verb = atoi(args[i]);
 
290
         break;
 
291
      case 'b':
 
292
         if (++i >= nargs)
 
293
            PrintUsage(args[0], i, NULL);
 
294
         *nb = atoi(args[i]);
 
295
         break;
 
296
      case 'l':
 
297
         if (++i >= nargs)
 
298
            PrintUsage(args[0], i, NULL);
 
299
         *lat = atoi(args[i]);
 
300
         break;
 
301
      case 'r':
 
302
         if (++i >= nargs)
 
303
            PrintUsage(args[0], i, NULL);
 
304
         *nregs = atoi(args[i]);
 
305
         break;
 
306
      default:
 
307
         PrintUsage(args[0], i, args[i]);
 
308
      }
 
309
   }
 
310
   assert(*nb >= 0);
 
311
   if (*outfile == NULL)
 
312
   {
 
313
      *outfile = DupString("res/dgMMRES.sum");
 
314
      (*outfile)[4] = pre;
 
315
   }
 
316
   return(pre);
 
317
}
 
318
 
 
319
double TryKUs
 
320
(
 
321
   ATL_mmnode_t *mmp,
 
322
   char pre,                    /* precision */
 
323
   int verb,                    /* verbosity level */
 
324
   int MACC,                    /* 0 : separate mult&add, else MACC */
 
325
   int lat0,                    /* multiply latency */
 
326
   int beta,                    /* 0,1 beta, else beta=X */
 
327
   int nb,                      /* blocking factor */
 
328
   int mu, int nu, int ku,      /* unrolling factors */
 
329
   int fftch,                   /* do bogus fetch of C at top of loop? */
 
330
   int iftch,                   /* # of initial fetches to do */
 
331
   int nftch,                   /* # of fetches to do thereafter */
 
332
   int LDTOP,                   /* 1: load C at top, else at bottom */
 
333
   int pf                       /* prefetch strategy */
 
334
)
 
335
/*
 
336
 * If ku is set, times only that value, else tries both ku=1 & ku=nb
 
337
 * RETURNS: best performance of timed problems, with ku set correctly,
 
338
 *          but the generator flags may be bad!
 
339
 */
 
340
{
 
341
   double mf, mf1;
 
342
   int lat;
 
343
 
 
344
   assert(mu > 0 && nu > 0);
 
345
   if (ku)
 
346
   {
 
347
      mmp->ku = ku;
 
348
      lat = GetGoodLat(MACC, nb, mu, nu, ku, lat0);
 
349
      mf = TimeGMMKernel(verb, 0, pre, MACC, lat, beta, nb, mu, nu, ku,
 
350
                         fftch, iftch, nftch, LDTOP, pf, -1, -1);
 
351
   }
 
352
   else
 
353
   {
 
354
      lat = GetGoodLat(MACC, nb, mu, nu, 1, lat0);
 
355
      mf = TimeGMMKernel(verb, 0, pre, MACC, lat, beta, nb, mu, nu, nb,
 
356
                         fftch, iftch, nftch, LDTOP, pf, -1, -1);
 
357
      mf1 = TimeGMMKernel(verb, 0, pre, MACC, lat, beta, nb, mu, nu, 1,
 
358
                          fftch, iftch, nftch, LDTOP, pf, -1, -1);
 
359
      if (mf1 >= mf)
 
360
      {
 
361
         mmp->ku = 1;
 
362
         mf = mf1;
 
363
      }
 
364
   }
 
365
   return(mf);
 
366
}
 
367
 
 
368
double TryPFs
 
369
(
 
370
   ATL_mmnode_t *mmp,
 
371
   char pre,                    /* precision */
 
372
   int verb,                    /* verbosity level */
 
373
   int MACC,                    /* 0 : separate mult&add, else MACC */
 
374
   int lat,                     /* multiply latency */
 
375
   int beta,                    /* 0,1 beta, else beta=X */
 
376
   int nb,                      /* blocking factor */
 
377
   int mu, int nu, int ku,      /* unrolling factors */
 
378
   int fftch,                   /* do bogus fetch of C at top of loop? */
 
379
   int iftch,                   /* # of initial fetches to do */
 
380
   int nftch,                   /* # of fetches to do thereafter */
 
381
   int LDTOP                    /* 1: load C at top, else at bottom */
 
382
)
 
383
{
 
384
   double mf0, mf1;
 
385
 
 
386
   mf0 = TryKUs(mmp, pre, verb, MACC, lat, beta, nb, mu, nu, ku, fftch, iftch,
 
387
                nftch, LDTOP, 0);
 
388
   mf1 = TryKUs(mmp, pre, verb, MACC, lat, beta, nb, mu, nu, ku, fftch, iftch,
 
389
                nftch, LDTOP, 1);
 
390
   mmp->pref = (mf1 > mf0);
 
391
   return((mmp->pref) ? mf1 : mf0);
 
392
}
 
393
 
 
394
void ConfirmMACC(char pre, int verb, int nregs, int nb, int ku,
 
395
                 int *MACC, int *lat)
 
396
{
 
397
   char upr;
 
398
   int maccB, latB, latR, mu, nu;
 
399
   double mf, mfB;
 
400
   ATL_mmnode_t *mmp;
 
401
 
 
402
   mmp = GetMMNode();
 
403
 
 
404
   if (pre == 'd' || pre == 's')
 
405
      upr = pre;
 
406
   else if (pre == 'z')
 
407
      upr = 'd';
 
408
   else
 
409
      upr = 's';
 
410
   GetMulAdd(upr, MACC, lat);
 
411
   if (verb)
 
412
      printf("\nCONFIRMING MACC=%d AND LAT=%d WITH FORCED NREGS=%d\n",
 
413
             *MACC, *lat, nregs);
 
414
/*
 
415
 * Find performance of present MACC setting
 
416
 */
 
417
   maccB = *MACC;
 
418
   latB = *lat;
 
419
   GetSafeMUNU(nregs, maccB, latB, &mu, &nu);
 
420
   mfB = TryKUs(mmp, pre, verb, maccB, latB, 1, nb, mu, nu, ku, 0, mu+nu, 1,
 
421
                0, 0);
 
422
   if (verb)
 
423
      printf("   MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
 
424
             maccB, latB, mu, nu, mfB);
 
425
/*
 
426
 * If no MACC, see if latency = 1 is just as good (dynamically scheduled mach)
 
427
 */
 
428
   if (!maccB)
 
429
   {
 
430
      GetSafeMUNU(nregs, 0, 1, &mu, &nu);
 
431
      mf = TryKUs(mmp, pre, verb, 0, 1, 1, nb, mu, nu, ku, 0, mu+nu, 1, 0, 0);
 
432
      if (verb)
 
433
         printf("   MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
 
434
                0, 1, mu, nu, mf);
 
435
      if (mf > mfB)
 
436
      {
 
437
         mfB = mf;
 
438
         latB = 1;
 
439
      }
 
440
   }
 
441
/*
 
442
 * Find setting of reverse MACC setting, same latency
 
443
 */
 
444
   GetSafeMUNU(nregs, !maccB, *lat, &mu, &nu);
 
445
   mf = TryKUs(mmp, pre, verb, maccB, *lat, 1, nb, mu, nu, ku, 0, mu+nu, 1,
 
446
               0, 0);
 
447
   if (verb)
 
448
      printf("   MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
 
449
             !maccB, *lat, mu, nu, mf);
 
450
   if (mf > mfB || (!maccB && mf == mfB))
 
451
   {
 
452
      maccB = !maccB;
 
453
      latB = *lat;
 
454
      mfB = mf;
 
455
   }
 
456
/*
 
457
 * Try to reverse MACC to 0, lat=1 (dynamically scheduled machines)
 
458
 */
 
459
   if (*MACC == 1)
 
460
   {
 
461
      GetSafeMUNU(nregs, 0, 1, &mu, &nu);
 
462
      mf = TryKUs(mmp, pre, verb, 0, 1, 1, nb, mu, nu, ku, 0, mu+nu, 1, 0, 0);
 
463
      if (verb)
 
464
         printf("   MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
 
465
                0, 1, mu, nu, mf);
 
466
      if (mf > mfB)
 
467
      {
 
468
         mfB = mf;
 
469
         latB = 1;
 
470
         maccB = 0;
 
471
      }
 
472
   }
 
473
   if (verb)
 
474
      printf("CHOSE MACC=%d LAT=%d (%.2f)\n", maccB, latB, mfB);
 
475
   *MACC = maccB;
 
476
   *lat = latB;
 
477
}
 
478
 
 
479
double FindNumRegsByMACC(char pre, int verb, int nb, int ku, int MACC, int lat0,
 
480
                         int *NREGS, int *MU, int *NU)
 
481
{
 
482
   int i, mu, nu, muB, nuB, ForceMACC, lat;
 
483
   double mf, mfB;
 
484
   ATL_mmnode_t *mmp;
 
485
 
 
486
   mmp = GetMMNode();
 
487
   mfB = 0.0;
 
488
   for (i=8; i < 1024; i *= 2)
 
489
   {
 
490
      if (!MACC)
 
491
      {
 
492
         lat = i>>1;                       /* don't allow lat to take up */
 
493
         lat = (lat > lat0) ? lat0 : lat;  /* more than 1/2 the registers */
 
494
      }
 
495
      else
 
496
         lat = lat0;
 
497
      GetSafeMUNU(i, MACC, lat, &mu, &nu);
 
498
      mf = TryKUs(mmp, pre, verb, MACC, lat, 1, nb, mu, nu, ku, 0, mu+nu, 1,
 
499
                  0, 0);
 
500
      if (verb)
 
501
         printf(
 
502
 "   nreg=%3d: nb=%2d, mu=%2d, nu=%2d, ku=%2d, MACC=%1d, lat=%2d, mf=%.2f\n",
 
503
                i, nb, mu, nu, mmp->ku, MACC, lat, mf);
 
504
      if (mf > mfB)
 
505
      {
 
506
         mfB = mf;
 
507
         muB = mu;
 
508
         nuB = nu;
 
509
      }
 
510
/*
 
511
 *    Call a 8% decline in performance evidence of register overflow
 
512
 */
 
513
 
 
514
      else if (1.08*mf < mfB)
 
515
         break;
 
516
   }
 
517
   *NREGS = i>>1;
 
518
   *MU = muB;
 
519
   *NU = nuB;
 
520
   return(mfB);
 
521
}
 
522
 
 
523
int FindNumRegs(char pre, int verb, int nb, int ku, int *MACC, int *lat)
 
524
/*
 
525
 * Finds an estimate for the number of registers the compiler will let
 
526
 * you use in a generated matmul
 
527
 */
 
528
{
 
529
   int nregs, nr, ForceMACC, mu, nu, lat0;
 
530
   double mf, mf1, mfmacc;
 
531
   FILE *fp;
 
532
   char ln[128];
 
533
 
 
534
   sprintf(ln, "res/%cfpuMM", pre);
 
535
   fp = fopen(ln, "r");
 
536
   if (fp)
 
537
   {
 
538
      fgets(ln, 128, fp);  /* skip header */
 
539
      if (fscanf(fp, " %d %d %d", &nregs, MACC, lat) == 3)
 
540
      {
 
541
         fclose(fp);
 
542
         if (verb)
 
543
            printf("READ IN NUMBER OF GEMM REGISTERS = %d, MACC=%d, lat=%d:\n",
 
544
             nregs, *MACC, *lat);
 
545
         sprintf(ln, "res/%cnreg", pre);
 
546
         fp = fopen(ln, "w");
 
547
         fprintf(fp, "%d\n", nregs);
 
548
         fclose(fp);
 
549
         return(nregs);
 
550
      }
 
551
      fclose(fp);
 
552
   }
 
553
 
 
554
   ForceMACC = (*MACC >= 0);
 
555
   if (ForceMACC && *MACC && !(*lat))
 
556
   {
 
557
      fprintf(stderr,
 
558
              "If you force no MACC, then you must also force latency!\n");
 
559
      exit(-1);
 
560
   }
 
561
   if (pre == 'z')
 
562
      return(FindNumRegs('d', verb, nb, ku, MACC, lat));
 
563
   else if (pre == 'c')
 
564
      return(FindNumRegs('s', verb, nb, ku, MACC, lat));
 
565
   if (verb)
 
566
      printf("\nESTIMATING THE NUMBER OF USEABLE REGISTERS FOR GEMM:\n");
 
567
   if (!ForceMACC)
 
568
      GetMulAdd(pre, MACC, lat);
 
569
   lat0 = *lat;
 
570
   mf = FindNumRegsByMACC(pre, verb, nb, ku, *MACC, *lat, &nregs, &mu, &nu);
 
571
/*
 
572
 * Using separate multiply and add is expensive in terms of registers,
 
573
 * and is often messed up by compilers, so let's try lat=1 (for dynamically
 
574
 * scheduled machines), and using a MACC, and see what happens
 
575
 */
 
576
   if (!ForceMACC && *MACC == 0)
 
577
   {
 
578
      if (*lat > 1)
 
579
      {
 
580
         printf("\n");
 
581
         mf1 = FindNumRegsByMACC(pre, verb, nb, ku, 0, 1, &nr, &mu, &nu);
 
582
         if (mf1 >= mf)  /* latency of 1 just as good as longer latency */
 
583
         {
 
584
            nregs = nr;
 
585
            *lat = 1;
 
586
         }
 
587
      }
 
588
      printf("\n");
 
589
      mfmacc = FindNumRegsByMACC(pre, verb, nb, ku, 1, lat0, &nr, &mu, &nu);
 
590
      if (mfmacc > mf && mfmacc >= mf1) /* MACC is better */
 
591
      {
 
592
         nregs = nr;
 
593
         *MACC = 1;
 
594
         *lat = lat0;
 
595
      }
 
596
   }
 
597
 
 
598
   fp = fopen(ln, "w");
 
599
   assert(fp);
 
600
   fprintf(fp, "NREG  MACC  LAT\n%4d %5d %4d\n", nregs, *MACC, *lat);
 
601
   fclose(fp);
 
602
/*
 
603
 * Write # of registers to <pre>nreg for use by sysinfo tuning
 
604
 */
 
605
   sprintf(ln, "res/%cnreg", pre);
 
606
   fp = fopen(ln, "w");
 
607
   fprintf(fp, "%d\n", nregs);
 
608
   fclose(fp);
 
609
 
 
610
   if (verb)
 
611
      printf("NUMBER OF ESTIMATED GEMM REGISTERS = %d, MACC=%d, lat=%d:\n",
 
612
             nregs, *MACC, *lat);
 
613
   return(nregs);
 
614
}
 
615
 
 
616
int GetBigNB(char pre)
 
617
{
 
618
   int i, L1Elts;
 
619
   if (pre == 'd' || pre == 'z')
 
620
      L1Elts = 1024/8;
 
621
   else
 
622
      L1Elts = 1024/4;
 
623
   L1Elts *= GetL1CacheSize();
 
624
   for (i=16; i*i < L1Elts; i += 4);
 
625
   if (i*i > L1Elts)
 
626
      i -= 4;
 
627
   if (i > 80)
 
628
      i = 80;
 
629
   return(i);
 
630
}
 
631
 
 
632
int GetSmallNB(char pre)
 
633
{
 
634
   int i, L1Elts;
 
635
   const int imul = (pre == 'c' || pre == 'z') ? 6 : 3;
 
636
   if (pre == 'd' || pre == 'z')
 
637
      L1Elts = 1024/8;
 
638
   else
 
639
      L1Elts = 1024/4;
 
640
   L1Elts *= GetL1CacheSize();
 
641
   for (i=16; imul*i*i < L1Elts; i += 4);
 
642
   if (imul*i*i > L1Elts)
 
643
      i -= 4;
 
644
   i = Mmin(i,80);
 
645
   i = Mmax(i,16);
 
646
   return(i);
 
647
}
 
648
 
 
649
ATL_mmnode_t *FindBestNB
 
650
(
 
651
   char pre,   /* precision, one of s,d,c,z */
 
652
   int verb,   /* verbosity */
 
653
   ATL_mmnode_t *mmp,  /* input/output struct for best case found so far */
 
654
   int ku      /* 0: tune ku, else we must use this ku */
 
655
)
 
656
/*
 
657
 * This function tries to find the NB to use.  It varies NB, prefetch,
 
658
 * and ku (if allowed, but only between 1 and full unrolling)
 
659
 * RETURNS: matmul struct of best found case
 
660
 */
 
661
{
 
662
   int bN, b0, binc, nbB, muB, nuB, pfB, MACC, lat, KUISKB=0, i;
 
663
   double mf, mfB, mf1;
 
664
 
 
665
   nbB = mmp->nbB;
 
666
   mfB = mmp->mflop[0];
 
667
   muB = mmp->mu;
 
668
   nuB = mmp->nu;
 
669
   pfB = mmp->pref;
 
670
   MACC = mmp->muladd;
 
671
   lat = mmp->lat;
 
672
/*
 
673
 * Find largest block factor to tune; Since L1 estimate may be wrong,
 
674
 * make sure that larger block factors aren't competitive, but max
 
675
 * NB will be 80 regardless to avoid cleanup nightmare
 
676
 */
 
677
   if (verb)
 
678
      printf("\nFINDING UPPER BOUND ON NB:\n");
 
679
   bN = GetBigNB(pre);  /* our guess for largest useful NB */
 
680
   while (bN < 80)
 
681
   {
 
682
      mf = TryKUs(mmp, pre, verb, MACC, lat, 1, bN+4, muB, nuB, ku,
 
683
                  0, muB+nuB, 1, 0, 0);
 
684
      printf("   nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
 
685
             bN+4, muB, nuB, mmp->ku, MACC, lat, mf);
 
686
      if (mf > mfB)
 
687
      {
 
688
         mfB = mf;
 
689
         nbB = bN+4;
 
690
      }
 
691
      else
 
692
         break;
 
693
      bN += 4;
 
694
   }
 
695
   if (bN > 80)
 
696
      bN = 80;
 
697
   if (verb)
 
698
      printf("NB UPPER BOUND CHOSEN AS : %d (%.2f)\n", bN, mfB);
 
699
/*
 
700
 * See if lowering NB past when all matrices should fit is useful
 
701
 * (again, L1 detection could be wrong)
 
702
 */
 
703
   if (verb)
 
704
      printf("\nFINDING LOWER BOUND ON NB:\n");
 
705
   b0 = GetSmallNB(pre);
 
706
   mf1 = TryKUs(mmp, pre, verb, MACC, lat, 1, b0, muB, nuB, ku,
 
707
                0, muB+nuB, 1, 0, 0);
 
708
   printf("   nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
 
709
          b0, muB, nuB, mmp->ku, MACC, lat, mf1);
 
710
   while(b0 > 20)
 
711
   {
 
712
      mf = TryKUs(mmp, pre, verb, MACC, lat, 1, b0-4, muB, nuB, ku,
 
713
                  0, muB+nuB, 1, 0, 0);
 
714
      printf("   nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
 
715
             b0-4, muB, nuB, mmp->ku, MACC, lat, mf);
 
716
      if (mf < mf1)
 
717
         break;
 
718
      else if (mf > mfB)
 
719
      {
 
720
         mfB = mf;
 
721
         nbB = b0-4;
 
722
      }
 
723
      b0 -= 4;
 
724
   }
 
725
   if (verb)
 
726
      printf("NB LOWER BOUND CHOSEN AS : %d\n", b0);
 
727
 
 
728
/*
 
729
 * Now try all NBs with varying prefetch
 
730
 */
 
731
   binc = (pre == 's' || pre == 'c') ? 4 : 2;
 
732
   KUISKB = (!ku && mmp->ku == mmp->nbB);
 
733
   b0 = (b0/binc)*binc;
 
734
   bN = (bN/binc)*binc;
 
735
   if (verb)
 
736
      printf("\nFINDING BEST NB AND PREFETCH SETTING IN RANGE [%d,%d,%d]:\n",
 
737
             b0, bN, binc);
 
738
 
 
739
   for (i=b0; i <= bN; i += binc)
 
740
   {
 
741
      mf = TryPFs(mmp, pre, verb, MACC, lat, 1, i, muB, nuB, KUISKB ? i:ku,
 
742
                  0, muB+nuB, 1, 0);
 
743
      printf(
 
744
      "   nb=%3d, pf=%d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
 
745
             i, mmp->pref, muB, nuB, mmp->ku, MACC, lat, mf);
 
746
      if (mf > mfB)
 
747
      {
 
748
         mfB = mf;
 
749
         nbB = i;
 
750
         pfB = mmp->pref;
 
751
      }
 
752
   }
 
753
   if (verb)
 
754
      printf("BEST NB=%d, BEST PREFETCH=%d (%.2f)\n", nbB, pfB, mfB);
 
755
   mmp->mflop[0] = mfB;
 
756
   mmp->mbB = mmp->nbB = mmp->kbB = nbB;
 
757
   mmp->pref = pfB;
 
758
   return(mmp);
 
759
}
 
760
ATL_mmnode_t *FindBestKU
 
761
(
 
762
   char pre,   /* precision, one of s,d,c,z */
 
763
   int verb,   /* verbosity */
 
764
   ATL_mmnode_t *mmp   /* input/output struct for best case found so far */
 
765
)
 
766
/*
 
767
 * Find best K unrolling.  There is no data cache dependence here, so time
 
768
 * with in-cache operands for increases speed and accuracy
 
769
 */
 
770
{
 
771
   int k, kuB, latB, kN, incK, lat;
 
772
   int nb, LAT, MACC, mu, nu, pf;
 
773
   double mf, mfB;
 
774
 
 
775
   LAT = mmp->lat;  /* canonical latency */
 
776
   MACC = mmp->muladd;
 
777
   mu = mmp->mu;
 
778
   nu = mmp->nu;
 
779
   nb = mmp->nbB;
 
780
   pf = mmp->pref;
 
781
   if (verb)
 
782
      printf("TRYING KUs FOR NB=%d, PF=%d, MU=%d, NU=%d MACC=%d, LAT=%d:\n",
 
783
             nb, pf, mu, nu, MACC, LAT);
 
784
/*
 
785
 * Try ku=1 as default
 
786
 */
 
787
   kuB = 1;
 
788
   latB = lat = GetGoodLat(MACC, nb, mu, nu, 1, LAT);
 
789
   mfB = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, 1,
 
790
                       0, mu+nu, 1, 0, pf, -1, -1);
 
791
   if (verb)
 
792
      printf("   KU=%d, lat=%d, mf=%.2f\n", 1, lat, mfB);
 
793
/*
 
794
 * Try NB/2 as maximal unrolling that actually has a loop
 
795
 */
 
796
   k = nb>>1;
 
797
   lat = GetGoodLat(MACC, nb, mu, nu, k, LAT);
 
798
   mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, k,
 
799
                      0, mu+nu, 1, 0, pf, -1, -1);
 
800
   if (verb)
 
801
      printf("   KU=%d, lat=%d, mf=%.2f\n", k, LAT, mf);
 
802
   if (mf > mfB)
 
803
   {
 
804
      mfB = mf;
 
805
      kuB = nb;
 
806
      latB = lat;
 
807
   }
 
808
/*
 
809
 * Try fully unrolled, give it .5% bonus since it is easier on lat, etc.
 
810
 */
 
811
   mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, nb,
 
812
                      0, mu+nu, 1, 0, pf, -1, -1);
 
813
   mf *= 1.005;
 
814
   if (verb)
 
815
      printf("   KU=%d, lat=%d, mf=%.2f\n", nb, LAT, mf);
 
816
   if (mf > mfB)
 
817
   {
 
818
      mfB = mf;
 
819
      kuB = nb;
 
820
      latB = LAT;
 
821
   }
 
822
/*
 
823
 * Have already tried 1 & KB, so now try 2, 4, 6, 8
 
824
 */
 
825
   for (k=2; k <= 8; k += 2)
 
826
   {
 
827
      lat = GetGoodLat(MACC, nb, mu, nu, k, LAT);
 
828
      mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, k,
 
829
                         0, mu+nu, 1, 0, pf, -1, -1);
 
830
      if (verb)
 
831
         printf("   KU=%d, lat=%d, mf=%.2f\n", k, LAT, mf);
 
832
      if (mf > mfB)
 
833
      {
 
834
         mfB = mf;
 
835
         kuB = k;
 
836
         latB = lat;
 
837
      }
 
838
   }
 
839
/*
 
840
 * Try all cases in range [8,nb/2,4]
 
841
 */
 
842
   kN = nb>>1;
 
843
   if (!mmp->muladd && mmp->lat > 2)
 
844
   {
 
845
      incK = mmp->lat;
 
846
      k = (incK >= 8) ? incK : (8/incK)*incK;
 
847
   }
 
848
   else
 
849
   {
 
850
      incK = 4;
 
851
      k = 8;
 
852
   }
 
853
   for (; k < kN; k += incK)
 
854
   {
 
855
      lat = GetGoodLat(MACC, nb, mu, nu, k, LAT);
 
856
      mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, k,
 
857
                         0, mu+nu, 1, 0, pf, -1, -1);
 
858
      if (verb)
 
859
         printf("   KU=%d, lat=%d, mf=%.2f\n", k, LAT, mf);
 
860
      if (mf > mfB)
 
861
      {
 
862
         mfB = mf;
 
863
         kuB = k;
 
864
         latB = LAT;
 
865
      }
 
866
   }
 
867
/*
 
868
 * Time the best found case out-of-cache so we it can be compared to others
 
869
 */
 
870
   lat = GetGoodLat(MACC, nb, mu, nu, kuB, LAT);
 
871
   mfB = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, kuB,
 
872
                       0, mu+nu, 1, 0, pf, -1, -1);
 
873
   mmp->ku = kuB;
 
874
   mmp->lat = latB;
 
875
   mmp->mflop[0] = mfB;
 
876
   if (verb)
 
877
      printf("BEST KU=%d, lat=%d (%.2f)\n", kuB, latB, mfB);
 
878
   return(mmp);
 
879
}
 
880
 
 
881
ATL_mmnode_t *FindBestRest
 
882
(
 
883
   char pre,   /* precision, one of s,d,c,z */
 
884
   int verb,   /* verbosity */
 
885
   ATL_mmnode_t *mmp   /* input/output struct for best case found so far */
 
886
)  /* tunes iftch, nftch, fftch, LDTOP, tries opposite muladd  */
 
887
{
 
888
   int i, j, n, nelts, nb, mu, nu, pf, MACC, lat, ku;
 
889
   int ifB, nfB, ffB, ldtopB;
 
890
   double mfB, mf, mf0;
 
891
 
 
892
   nb = mmp->nbB;
 
893
   pf = mmp->pref;
 
894
   mu = mmp->mu;
 
895
   nu = mmp->nu;
 
896
   ldtopB = 0;
 
897
   ifB = nelts = mu + nu;
 
898
   nfB = 1;
 
899
   ffB = 0;
 
900
   mfB = 0;
 
901
   MACC = mmp->muladd;
 
902
   lat = mmp->lat;
 
903
   ku = mmp->ku;
 
904
   if (verb)
 
905
      printf( "FINDING BEST FETCH PATTERN FOR nb=%d, mu=%d, nu=%d, ku=%d\n",
 
906
             nb, mu , nu, ku);
 
907
   for (i=2; i <= nelts; i++)
 
908
   {
 
909
      n = nelts - i;
 
910
      n = Mmin(i, n);
 
911
      if (!n)
 
912
         n = 1;
 
913
      for (j=1; j <= n; j++)
 
914
      {
 
915
         mf = TimeGMMKernel(verb, 1, pre, MACC, lat, 1, nb, mu, nu, ku,
 
916
                            0, i, j, 0, pf, -1, -1);
 
917
         if (verb)
 
918
            printf ("   ifetch=%2d, nfetch=%2d, mf=%.2f\n", i, j, mf);
 
919
         if (mf > mfB)
 
920
         {
 
921
            ifB = i;
 
922
            nfB = j;
 
923
            mfB = mf;
 
924
         }
 
925
      }
 
926
   }
 
927
/*
 
928
 * overwrite bad ifetch value output file with selected one
 
929
 */
 
930
   mfB = TimeGMMKernel(verb, 1, pre, MACC, lat, 1, nb, mu, nu, ku,
 
931
                       0, ifB, nfB, 0, pf, -1, -1);
 
932
   if (verb)
 
933
      printf("   best case retimed as mf=%.2f\n", mfB);
 
934
   mmp->mflop[0] = mfB;
 
935
   mmp->iftch = ifB;
 
936
   mmp->nftch = nfB;
 
937
   if (verb)
 
938
      printf("BEST IFETCH=%d, NFETCH=%d (%.2f)\n", ifB, nfB, mfB);
 
939
/*
 
940
 * Try force fetch for beta=0
 
941
 */
 
942
   if (verb > 1)
 
943
      printf("TRYING FALSE FETCH FOR BETA=0 CASES:\n");
 
944
   mf0 = TimeGMMKernel(verb, 1, pre, MACC, lat, 0, nb, mu, nu, ku,
 
945
                       0, ifB, nfB, 0, pf, -1, -1);
 
946
   if (verb > 1)
 
947
      printf("   noFF=%.2f\n", mf0);
 
948
   mf = TimeGMMKernel(verb, 1, pre, MACC, lat, 0, nb, mu, nu, ku,
 
949
                      1, ifB, nfB, 0, pf, -1, -1);
 
950
   if (verb > 1)
 
951
      printf("   yesFF=%.2f\n", mf);
 
952
   if (mf > mf0)
 
953
      ffB = 1;
 
954
   mmp->fftch = ffB;
 
955
/*
 
956
 * If loading C at top is 2% faster, take it despite error bound hit
 
957
 */
 
958
   if (verb)
 
959
      printf("TRYING LOAD-AT-TOP (load-at-bottom %.2f)\n", mfB);
 
960
   mf  = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, ku,
 
961
                       ffB, ifB, nfB, 1, pf, -1, -1);
 
962
   if (verb)
 
963
      printf("   load-at-top, mf=%.2f\n", mf);
 
964
   if (mfB*1.02 > mf)
 
965
   {
 
966
      if (verb)
 
967
         printf("STICKING WITH LOAD-AT-BOTTOM\n");
 
968
   }
 
969
   else
 
970
   {
 
971
      ldtopB = 1;
 
972
      if (verb)
 
973
         printf("SWITCHING TO LOAD-AT-TOP\n");
 
974
      mmp->flag |= (1<<MMF_LDCTOP);
 
975
      mfB = mf;
 
976
   }
 
977
 
 
978
/*
 
979
 * See if reversing muladd setting is helpful
 
980
 */
 
981
   if (verb)
 
982
      printf("TRYING SWAP OF MACC (present, madd=%d, lat=%d, mf=%.2f)\n",
 
983
             MACC, lat, mfB);
 
984
   if (MACC)
 
985
   {
 
986
      i = Mmax(mmp->lat, 4);
 
987
      i = GetGoodLat(0, nb, mu, nu, ku, i);
 
988
   }
 
989
   else
 
990
      i = mmp->lat;
 
991
   mf  = TimeGMMKernel(verb, 0, pre, !MACC, i, 1, nb, mu, nu, ku,
 
992
                       ffB, ifB, nfB, ldtopB, pf, -1, -1);
 
993
   if (verb)
 
994
      printf("   macc=%d, lat=%d, mf=%.2f\n", !MACC, i, mf);
 
995
   if (mf > mfB)
 
996
   {
 
997
      mmp->muladd = !MACC;
 
998
      mmp->lat = i;
 
999
      mfB = mf;
 
1000
      if (verb)
 
1001
         printf("SWITCHING TO NEW MACC SETTING!\n");
 
1002
   }
 
1003
   else if (verb)
 
1004
      printf("KEEPING MACC SETTING.\n");
 
1005
   mmp->mflop[0] = mfB;
 
1006
   return(mmp);
 
1007
}
 
1008
 
 
1009
ATL_mmnode_t *FindBestGenGemm
 
1010
(
 
1011
   char pre,   /* precision, one of s,d,c,z */
 
1012
   int verb,   /* verbosity */
 
1013
   int nregs,  /* max # of registers to use */
 
1014
   int MACC,   /* 1: machine has multiply & accumulate, else separate */
 
1015
   int lat,    /* latency on multiply */
 
1016
   int FNB,    /* is it required to use NB, or can we tune? */
 
1017
   int NB,     /* suggested nb */
 
1018
   int ku      /* 0: tune ku, else we must use this ku */
 
1019
)
 
1020
/*
 
1021
 * This routine finds the best copy matmul case that can be generated by
 
1022
 * emit_mm.c.  It will search over the following parameters:
 
1023
 *    (nb,pf), (mu,nu), ku, nftch, iftch, fftch, LDTOP
 
1024
 *
 
1025
 * pf is currently 1 or 0, and it controls whether the next block of A is
 
1026
 * prefetched or not.
 
1027
 *
 
1028
 * LDTOP determines if we load C values before entering the K loop (TOP)
 
1029
 * or after.  After gives better error bound, so give it slight advantage
 
1030
 *
 
1031
 * nftch,iftch are crude load scheduling parameters, and they tend to
 
1032
 * have little affect on most machines (the compiler usually reschedules
 
1033
 * the loads on its own).
 
1034
 *
 
1035
 * fftch causes the generator to load C at the top of the loop even
 
1036
 * when we are don't need the values there, so that C is in cache at
 
1037
 * the bottom of the loop when we need it.
 
1038
 *
 
1039
 * RETURNS: filled structure with best gemm case found
 
1040
 */
 
1041
{
 
1042
   ATL_mmnode_t *mmp;
 
1043
   int nb, N, Ng, i, j, mu, nu, nbB, muB, nuB;
 
1044
   int *mus, *nus, *ip;
 
1045
   double mf, mfB, mf1;
 
1046
   double *fpls;
 
1047
   #ifdef ATL_GAS_x8664
 
1048
      #define NEXMU 5
 
1049
      int exmu[NEXMU] = {4, 6, 8, 10, 12};
 
1050
      int exnu[NEXMU] = {1, 1, 1, 1,  1};
 
1051
   #elif defined(ATL_GAS_x8632)
 
1052
      #define NEXMU 4
 
1053
      int exmu[NEXMU] = {3, 4, 6, 2};
 
1054
      int exnu[NEXMU] = {1, 1, 1, 2};
 
1055
   #endif
 
1056
   char upr;
 
1057
   char ln[128];
 
1058
 
 
1059
/*
 
1060
 * Use either required nb, or one that is a multiple of a lot of our
 
1061
 * unrolling factors;  Use a big block factor so that our register blocking
 
1062
 * matters more (cache is covering less of costs)
 
1063
 */
 
1064
   if (FNB)
 
1065
      nb = NB;
 
1066
   else
 
1067
   {
 
1068
      nb = (GetBigNB(pre)/12)*12;
 
1069
      if (nb < 24)
 
1070
         nb = 24;
 
1071
   }
 
1072
   if (pre == 'd' || pre == 's')
 
1073
   {
 
1074
      mmp = GetMMNode();
 
1075
      FillInGMMNode(verb, mmp, pre, MACC, lat, 1, nb, 1, 1, 1, 0, 2, 1, 0, 0);
 
1076
 
 
1077
/*
 
1078
 *    Get all MU/NU unrollings, Ng of them are competitive on flops/load ratio.
 
1079
 *    For x86, always include extra 1-D blockings in mix, even if they
 
1080
 *    are not judged competive (because if reg-reg moves aren't free, which
 
1081
 *    is true for older x86 machines, 2-D register blocks don't really work
 
1082
 *    due to 2-operand assembly)
 
1083
 */
 
1084
      GetMuNus(nregs, MACC, lat, &Ng, &N, &mus, &nus, &fpls);
 
1085
      free(fpls);
 
1086
      #ifdef NEXMU
 
1087
         for (j=0; j < NEXMU; j++)
 
1088
         {
 
1089
            mu = exmu[j];
 
1090
            nu = exnu[j];
 
1091
            for (i=0; i < Ng; i++)
 
1092
               if (mus[i] == mu && nus[i] == nu) break;
 
1093
            if (i == Ng)
 
1094
            {
 
1095
               if (Ng >= N)
 
1096
               {
 
1097
                  ip = malloc((Ng+1)*sizeof(int));
 
1098
                  assert(ip);
 
1099
                  for (i=0; i < Ng; i++)
 
1100
                     ip[i] = mus[i];
 
1101
                  free(mus);
 
1102
                  mus = ip;
 
1103
                  ip = malloc((Ng+1)*sizeof(int));
 
1104
                  assert(ip);
 
1105
                  for (i=0; i < Ng; i++)
 
1106
                     ip[i] = nus[i];
 
1107
                  free(nus);
 
1108
                  nus = ip;
 
1109
               }
 
1110
               mus[Ng] = mu;
 
1111
               nus[Ng] = nu;
 
1112
               Ng++;
 
1113
            }
 
1114
         }
 
1115
      #endif
 
1116
      if (verb)
 
1117
         printf("PROBING FOR M AND N UNROLLING FACTORS:\n");
 
1118
/*
 
1119
 *    Try all competitive unrolling factors
 
1120
 */
 
1121
      mfB = 0;
 
1122
      muB = nuB = 1;
 
1123
      for (i=0; i < Ng; i++)
 
1124
      {
 
1125
         mf = TryKUs(mmp, pre, verb, MACC, lat, 1, nb, mus[i], nus[i], ku,
 
1126
                     0, mus[i]+nus[i], 1, 0, 0);
 
1127
 
 
1128
         printf("   nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
 
1129
                nb, mus[i], nus[i], mmp->ku, MACC, lat, mf);
 
1130
         if (mf > mfB)
 
1131
         {
 
1132
            muB = mus[i];
 
1133
            nuB = nus[i];
 
1134
            mfB = mf;
 
1135
         }
 
1136
      }
 
1137
      mmp->mu = muB;
 
1138
      mmp->nu = nuB;
 
1139
      mmp->iftch = muB+nuB;
 
1140
      mmp->mflop[0] = mfB;
 
1141
      printf("SELECTED MU=%d, NU=%d (%.2f)\n", muB, nuB, mfB);
 
1142
      free(mus);
 
1143
      free(nus);
 
1144
      nbB = nb;
 
1145
   }
 
1146
   else /* complex types gets their MU & NU from real cases */
 
1147
   {
 
1148
      upr = (pre == 'z') ? 'd' : 's';
 
1149
      mmp = ReadMMFileWithPath(upr, "res", "gMMRES.sum");
 
1150
      if (!mmp)
 
1151
      {
 
1152
         sprintf(ln, "make res/%cgMMRES.sum > /dev/null 2>&1", upr);
 
1153
         assert(system(ln) == 0);
 
1154
         mmp = ReadMMFileWithPath(upr, "res", "gMMRES.sum");
 
1155
         assert(mmp);
 
1156
      }
 
1157
      muB = mmp->mu;
 
1158
      nuB = mmp->nu;
 
1159
      nbB = nb;
 
1160
      mfB = TryKUs(mmp, pre, verb, MACC, lat, 1, nb, muB, nuB, ku,
 
1161
                   0, muB+nuB, 1, 0, 0);
 
1162
      mmp->mflop[0] = mfB;
 
1163
      printf("READ IN MU=%d, NU=%d FROM REAL, nb=%d, mf=%.2f\n",
 
1164
             muB, nuB, nb, mfB);
 
1165
   }
 
1166
/*
 
1167
 * If we are allowed, try to tune NB
 
1168
 */
 
1169
   if (!FNB)
 
1170
      mmp = FindBestNB(pre, verb, mmp, ku);
 
1171
   else  /* still need to scope prefetch settings with required NB */
 
1172
   {
 
1173
      mmp->nbB = mmp->mbB = mmp->kbB = nb;
 
1174
      mf = TryPFs(mmp, pre, verb, MACC, lat, 1, nb, muB, nuB, ku,
 
1175
                  0, muB+nuB, 1, 0);
 
1176
   }
 
1177
/*
 
1178
 * If we are allowed, tune ku
 
1179
 */
 
1180
   if (!ku)
 
1181
      mmp = FindBestKU(pre, verb, mmp); /* tunes ku */
 
1182
   mmp = FindBestRest(pre, verb, mmp);  /* tunes iftch, nftch, fftch, LDTOP */
 
1183
   return(mmp);
 
1184
}
 
1185
 
 
1186
int main(int nargs, char **args)
 
1187
{
 
1188
   char pre, *outfile;
 
1189
   int verb, nregs, FNB, nb, ku, MACC, lat, mu, nu;
 
1190
   ATL_mmnode_t *mmp, *mm;
 
1191
 
 
1192
   pre = GetFlags(nargs, args, &verb, &nregs, &nb, &ku, &MACC, &lat, &outfile);
 
1193
   if (nregs == -1)  /* run # register probe only */
 
1194
   {
 
1195
      nb = GetBigNB(pre);
 
1196
      nregs = FindNumRegs(pre, verb, nb, ku, &MACC, &lat);
 
1197
      exit(0);
 
1198
   }
 
1199
   mmp = ReadMMFile(outfile);
 
1200
   if (mmp)
 
1201
   {
 
1202
      if (mmp->mflop[0] <= 0)  /* need to retime */
 
1203
      {
 
1204
         for (mm=mmp; mm; mm = mm->next)
 
1205
         {
 
1206
            mm->mflop[0] = TimeGMMKernel(verb, 0, pre, mm->muladd, mm->lat,
 
1207
                                         1, mm->nbB, mm->mu, mm->nu, mm->ku,
 
1208
                                         mm->fftch, mm->iftch, mm->nftch,
 
1209
                                         FLAG_IS_SET(mm->flag, MMF_LDCTOP),
 
1210
                                         mm->pref, -1, -1);
 
1211
         }
 
1212
         WriteMMFile(outfile, mmp);
 
1213
      }
 
1214
      printf("\nSEARCH OUTPUT READ IN AS:\n");
 
1215
      PrintMMNodes(stdout, mmp);
 
1216
      exit(0);
 
1217
   }
 
1218
   if (nb > 0)
 
1219
      FNB = 1;
 
1220
   else
 
1221
   {
 
1222
      nb = GetBigNB(pre);
 
1223
      FNB = 0;
 
1224
   }
 
1225
   if (!nregs)
 
1226
      nregs = FindNumRegs(pre, verb, nb, ku, &MACC, &lat);
 
1227
   else if (MACC < 0)
 
1228
      ConfirmMACC(pre, verb, nregs, nb, ku, &MACC, &lat);
 
1229
   mmp = FindBestGenGemm(pre, verb, nregs, MACC, lat, FNB, nb, ku);
 
1230
   WriteMMFile(outfile, mmp);
 
1231
   printf("\nSELECTED GENERATED KERNEL:\n");
 
1232
   PrintMMNodes(stdout, mmp);
 
1233
   KillMMNode(mmp);
 
1234
   exit(0);
 
1235
}