2
* Automatically Tuned Linear Algebra Software v3.10.1
3
* Copyright (C) 2010 R. Clint Whaley
5
* Redistribution and use in source and binary forms, with or without
6
* modification, are permitted provided that the following conditions
8
* 1. Redistributions of source code must retain the above copyright
9
* notice, this list of conditions and the following disclaimer.
10
* 2. Redistributions in binary form must reproduce the above copyright
11
* notice, this list of conditions, and the following disclaimer in the
12
* documentation and/or other materials provided with the distribution.
13
* 3. The name of the ATLAS group or the names of its contributers may
14
* not be used to endorse or promote products derived from this
15
* software without specific written permission.
17
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
19
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
21
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
* POSSIBILITY OF SUCH DAMAGE.
33
#include "atlas_misc.h"
34
#include "atlas_mmtesttime.h"
37
int GetGoodLat(int MULADD, int kb, int mu, int nu, int ku, int lat)
39
int slat, blat, i, ii = mu*nu*ku;
40
if (MULADD) return(lat);
41
if ( (lat > 1) && (kb > ku) && ((ii/lat)*lat != ii) ) /* lat won't work */
43
for (i=lat; i; i--) if ( (ii/i) * i == ii ) break;
45
for (i=lat; i < MAXLAT; i++) if ( (ii/i) * i == ii ) break;
47
if ( (ii/blat)*blat != ii ) blat = slat;
48
if (slat < 2) lat = blat;
49
else if (lat-slat < blat-lat) lat = slat;
55
int GetUniqueMuNus(int nregs, int muladd, int lat, int *mus, int *nus)
57
* RETURNS: number of unique MU,NU combos; always allow at least 1x1
62
for (j=1; j <= nregs; j++)
64
for (i=1; i <= nregs; i++)
66
k = (muladd) ? 0 : lat;
67
if ((i != 1 || j != 1) && i*j+i+1+k > nregs) continue;
80
void PrintMUNUs(int N, int *mus, int *nus, double *fpls)
86
printf("%3d. MU=%d, NU=%d, fpl=%.3f\n", i, mus[i], nus[i], fpls[i]);
88
printf("%3d. MU=%d, NU=%d\n", i, mus[i], nus[i]);
93
void SortByFlpLd(int N, int *mus, int *nus, double *FPL)
95
* Simple selection sort, sorting from best (greatest) flops/load to worst
96
* ties in flops/load are broken by taking the most square one, and if they
97
* are equally square, then take the one with the bigger mu.
100
int i, j, imax, mindim, mindimB;
104
printf("\nUNSORTED:\n");
105
PrintMUNUs(N, mus, nus, NULL);
107
for (i=0; i < N-1; i++)
110
mindimB = (mus[i] <= nus[i]) ? mus[i] : nus[i];
111
fplB = (2.0 * mus[i] * nus[i]) / (mus[i] + nus[i]);
112
for (j=i+1; j < N; j++)
114
fpl = (2.0 * mus[j] * nus[j]) / (mus[j] + nus[j]);
119
mindimB = (mus[j] <= nus[j]) ? mus[j] : nus[j];
121
else if (fpl == fplB)
123
mindim = (mus[j] <= nus[j]) ? mus[j] : nus[j];
124
if (mindim > mindimB)
130
* For symmetric shapes, choose the one with a bigger mu
132
else if (mindim == mindimB)
151
FPL[i] = (2.0 * mus[i] * nus[i]) / (mus[i] + nus[i]);
153
printf("\n\nSORTED:\n");
154
PrintMUNUs(N, mus, nus, FPL);
158
#define LOWBOUND 0.6857
159
void GetMuNus(int nregs, int muladd, int lat, int *NGOOD, int *N0,
160
int **mus, int **nus, double **fpls)
165
N = GetUniqueMuNus(nregs, muladd, lat, NULL, NULL);
166
*mus = malloc(N*sizeof(int));
167
*nus = malloc(N*sizeof(int));
168
*fpls = malloc(N*sizeof(double));
169
assert(*mus && *nus && *fpls);
170
GetUniqueMuNus(nregs, muladd, lat, *mus, *nus);
171
SortByFlpLd(N, *mus, *nus, *fpls);
173
fplB = LOWBOUND * f[0];
174
for (i=1; i < N && f[i] >= fplB; i++);
179
int GetSafeGoodMuNu(int nreg, int muladd, int lat,
180
int N, int *mus, int *nus, double *fpls)
182
* Find the good value to compare against the "bad" ones; should be safe on not
183
* overflowing registers
184
* NOTE : assumes mus/nus already sorted by flops/load
188
k = (muladd) ? 0 : lat;
189
for (i=0; i < N; i++)
190
if (mus[i]*nus[i]+mus[i]+nus[i]+k+4 <= nreg)
195
void GetSafeMUNU(int nreg, int muladd, int lat, int *MU, int *NU)
201
GetMuNus(nreg, muladd, lat, &Ng, &N, &mus, &nus, &fpls);
202
i = GetSafeGoodMuNu(nreg, muladd, lat, N, mus, nus, fpls);
209
void GetMulAdd(char pre, int *MULADD, int *lat)
211
char nam[64], ln[128];
214
sprintf(nam, "res/%cMULADD", pre);
215
if (!FileExists(nam))
217
sprintf(ln, "make RunMulAdd pre=%c maxlat=%d mflop=%d\n", pre, 6, 200);
218
assert(system(ln) == 0);
220
fp = fopen(nam, "r");
222
fscanf(fp, "%d", MULADD);
223
fscanf(fp, "%d", lat);
227
void PrintUsage(char *name, int ierr, char *flag)
230
"%s searches for the best kernel that emit_mm.c can produce\n",
232
fprintf(stderr, "For all gemm parameters (eg., nb) if they are not specified or\nspecified as 0, then the search determines them,\notherwise they are forced to the commandline specification.\n\n");
235
fprintf(stderr, "Bad argument #%d: '%s'\n",
236
ierr, flag ? flag : "Not enough arguments");
238
fprintf(stderr, "ERROR: %s\n", flag);
239
fprintf(stderr, "USAGE: %s [flags]:\n", name);
240
fprintf(stderr, " -v # : higher numbers print out more\n");
241
fprintf(stderr, " -p [s,d,c,z]: set precision prefix \n");
242
fprintf(stderr, " -b <nb> : blocking factor \n");
243
fprintf(stderr, " -r <nreg> : number of registers to assume\n");
244
fprintf(stderr, " -k <ku> : K unrolling factor \n");
245
fprintf(stderr, " -l <lat> : multiply latency to assume\n");
246
fprintf(stderr, " -M <muladd> : -1: search 0: separate mul&add : else MACC\n");
247
fprintf(stderr, " -o <outfile> : defaults to res/<pre>gMMRES.sum\n");
248
exit(ierr ? ierr : -1);
251
char GetFlags(int nargs, char **args, int *verb, int *nregs, int *nb,
252
int *ku, int *MACC, int *lat, char **outfile)
260
*lat = *nregs = *nb = *ku = 0;
262
for (i=1; i < nargs; i++)
264
if (args[i][0] != '-')
265
PrintUsage(args[0], i, args[i]);
270
PrintUsage(args[0], i, NULL);
271
*outfile = DupString(args[i]);
273
case 'p': /* -p <pre> */
275
PrintUsage(args[0], i, NULL);
277
ch = tolower(args[i][0]);
278
assert(ch == 's' || ch == 'd' || ch == 'c' || ch == 'z');
283
PrintUsage(args[0], i, NULL);
284
*MACC = atoi(args[i]);
288
PrintUsage(args[0], i, NULL);
289
*verb = atoi(args[i]);
293
PrintUsage(args[0], i, NULL);
298
PrintUsage(args[0], i, NULL);
299
*lat = atoi(args[i]);
303
PrintUsage(args[0], i, NULL);
304
*nregs = atoi(args[i]);
307
PrintUsage(args[0], i, args[i]);
311
if (*outfile == NULL)
313
*outfile = DupString("res/dgMMRES.sum");
322
char pre, /* precision */
323
int verb, /* verbosity level */
324
int MACC, /* 0 : separate mult&add, else MACC */
325
int lat0, /* multiply latency */
326
int beta, /* 0,1 beta, else beta=X */
327
int nb, /* blocking factor */
328
int mu, int nu, int ku, /* unrolling factors */
329
int fftch, /* do bogus fetch of C at top of loop? */
330
int iftch, /* # of initial fetches to do */
331
int nftch, /* # of fetches to do thereafter */
332
int LDTOP, /* 1: load C at top, else at bottom */
333
int pf /* prefetch strategy */
336
* If ku is set, times only that value, else tries both ku=1 & ku=nb
337
* RETURNS: best performance of timed problems, with ku set correctly,
338
* but the generator flags may be bad!
344
assert(mu > 0 && nu > 0);
348
lat = GetGoodLat(MACC, nb, mu, nu, ku, lat0);
349
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, beta, nb, mu, nu, ku,
350
fftch, iftch, nftch, LDTOP, pf, -1, -1);
354
lat = GetGoodLat(MACC, nb, mu, nu, 1, lat0);
355
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, beta, nb, mu, nu, nb,
356
fftch, iftch, nftch, LDTOP, pf, -1, -1);
357
mf1 = TimeGMMKernel(verb, 0, pre, MACC, lat, beta, nb, mu, nu, 1,
358
fftch, iftch, nftch, LDTOP, pf, -1, -1);
371
char pre, /* precision */
372
int verb, /* verbosity level */
373
int MACC, /* 0 : separate mult&add, else MACC */
374
int lat, /* multiply latency */
375
int beta, /* 0,1 beta, else beta=X */
376
int nb, /* blocking factor */
377
int mu, int nu, int ku, /* unrolling factors */
378
int fftch, /* do bogus fetch of C at top of loop? */
379
int iftch, /* # of initial fetches to do */
380
int nftch, /* # of fetches to do thereafter */
381
int LDTOP /* 1: load C at top, else at bottom */
386
mf0 = TryKUs(mmp, pre, verb, MACC, lat, beta, nb, mu, nu, ku, fftch, iftch,
388
mf1 = TryKUs(mmp, pre, verb, MACC, lat, beta, nb, mu, nu, ku, fftch, iftch,
390
mmp->pref = (mf1 > mf0);
391
return((mmp->pref) ? mf1 : mf0);
394
void ConfirmMACC(char pre, int verb, int nregs, int nb, int ku,
398
int maccB, latB, latR, mu, nu;
404
if (pre == 'd' || pre == 's')
410
GetMulAdd(upr, MACC, lat);
412
printf("\nCONFIRMING MACC=%d AND LAT=%d WITH FORCED NREGS=%d\n",
415
* Find performance of present MACC setting
419
GetSafeMUNU(nregs, maccB, latB, &mu, &nu);
420
mfB = TryKUs(mmp, pre, verb, maccB, latB, 1, nb, mu, nu, ku, 0, mu+nu, 1,
423
printf(" MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
424
maccB, latB, mu, nu, mfB);
426
* If no MACC, see if latency = 1 is just as good (dynamically scheduled mach)
430
GetSafeMUNU(nregs, 0, 1, &mu, &nu);
431
mf = TryKUs(mmp, pre, verb, 0, 1, 1, nb, mu, nu, ku, 0, mu+nu, 1, 0, 0);
433
printf(" MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
442
* Find setting of reverse MACC setting, same latency
444
GetSafeMUNU(nregs, !maccB, *lat, &mu, &nu);
445
mf = TryKUs(mmp, pre, verb, maccB, *lat, 1, nb, mu, nu, ku, 0, mu+nu, 1,
448
printf(" MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
449
!maccB, *lat, mu, nu, mf);
450
if (mf > mfB || (!maccB && mf == mfB))
457
* Try to reverse MACC to 0, lat=1 (dynamically scheduled machines)
461
GetSafeMUNU(nregs, 0, 1, &mu, &nu);
462
mf = TryKUs(mmp, pre, verb, 0, 1, 1, nb, mu, nu, ku, 0, mu+nu, 1, 0, 0);
464
printf(" MACC=%1d, lat=%2d, mu=%2d, nu=%2d, mf=%.2f\n",
474
printf("CHOSE MACC=%d LAT=%d (%.2f)\n", maccB, latB, mfB);
479
double FindNumRegsByMACC(char pre, int verb, int nb, int ku, int MACC, int lat0,
480
int *NREGS, int *MU, int *NU)
482
int i, mu, nu, muB, nuB, ForceMACC, lat;
488
for (i=8; i < 1024; i *= 2)
492
lat = i>>1; /* don't allow lat to take up */
493
lat = (lat > lat0) ? lat0 : lat; /* more than 1/2 the registers */
497
GetSafeMUNU(i, MACC, lat, &mu, &nu);
498
mf = TryKUs(mmp, pre, verb, MACC, lat, 1, nb, mu, nu, ku, 0, mu+nu, 1,
502
" nreg=%3d: nb=%2d, mu=%2d, nu=%2d, ku=%2d, MACC=%1d, lat=%2d, mf=%.2f\n",
503
i, nb, mu, nu, mmp->ku, MACC, lat, mf);
511
* Call an 8% decline in performance evidence of register overflow
514
else if (1.08*mf < mfB)
523
int FindNumRegs(char pre, int verb, int nb, int ku, int *MACC, int *lat)
525
* Finds an estimate for the number of registers the compiler will let
526
* you use in a generated matmul
529
int nregs, nr, ForceMACC, mu, nu, lat0;
530
double mf, mf1, mfmacc;
534
sprintf(ln, "res/%cfpuMM", pre);
538
fgets(ln, 128, fp); /* skip header */
539
if (fscanf(fp, " %d %d %d", &nregs, MACC, lat) == 3)
543
printf("READ IN NUMBER OF GEMM REGISTERS = %d, MACC=%d, lat=%d:\n",
545
sprintf(ln, "res/%cnreg", pre);
547
fprintf(fp, "%d\n", nregs);
554
ForceMACC = (*MACC >= 0);
555
if (ForceMACC && *MACC && !(*lat))
558
"If you force no MACC, then you must also force latency!\n");
562
return(FindNumRegs('d', verb, nb, ku, MACC, lat));
564
return(FindNumRegs('s', verb, nb, ku, MACC, lat));
566
printf("\nESTIMATING THE NUMBER OF USEABLE REGISTERS FOR GEMM:\n");
568
GetMulAdd(pre, MACC, lat);
570
mf = FindNumRegsByMACC(pre, verb, nb, ku, *MACC, *lat, &nregs, &mu, &nu);
572
* Using separate multiply and add is expensive in terms of registers,
573
* and is often messed up by compilers, so let's try lat=1 (for dynamically
574
* scheduled machines), and using a MACC, and see what happens
576
if (!ForceMACC && *MACC == 0)
581
mf1 = FindNumRegsByMACC(pre, verb, nb, ku, 0, 1, &nr, &mu, &nu);
582
if (mf1 >= mf) /* latency of 1 just as good as longer latency */
589
mfmacc = FindNumRegsByMACC(pre, verb, nb, ku, 1, lat0, &nr, &mu, &nu);
590
if (mfmacc > mf && mfmacc >= mf1) /* MACC is better */
600
fprintf(fp, "NREG MACC LAT\n%4d %5d %4d\n", nregs, *MACC, *lat);
603
* Write # of registers to <pre>nreg for use by sysinfo tuning
605
sprintf(ln, "res/%cnreg", pre);
607
fprintf(fp, "%d\n", nregs);
611
printf("NUMBER OF ESTIMATED GEMM REGISTERS = %d, MACC=%d, lat=%d:\n",
616
int GetBigNB(char pre)
619
if (pre == 'd' || pre == 'z')
623
L1Elts *= GetL1CacheSize();
624
for (i=16; i*i < L1Elts; i += 4);
632
int GetSmallNB(char pre)
635
const int imul = (pre == 'c' || pre == 'z') ? 6 : 3;
636
if (pre == 'd' || pre == 'z')
640
L1Elts *= GetL1CacheSize();
641
for (i=16; imul*i*i < L1Elts; i += 4);
642
if (imul*i*i > L1Elts)
649
ATL_mmnode_t *FindBestNB
651
char pre, /* precision, one of s,d,c,z */
652
int verb, /* verbosity */
653
ATL_mmnode_t *mmp, /* input/output struct for best case found so far */
654
int ku /* 0: tune ku, else we must use this ku */
657
* This function tries to find the NB to use. It varies NB, prefetch,
658
* and ku (if allowed, but only between 1 and full unrolling)
659
* RETURNS: matmul struct of best found case
662
int bN, b0, binc, nbB, muB, nuB, pfB, MACC, lat, KUISKB=0, i;
673
* Find largest block factor to tune; Since L1 estimate may be wrong,
674
* make sure that larger block factors aren't competitive, but max
675
* NB will be 80 regardless to avoid cleanup nightmare
678
printf("\nFINDING UPPER BOUND ON NB:\n");
679
bN = GetBigNB(pre); /* our guess for largest useful NB */
682
mf = TryKUs(mmp, pre, verb, MACC, lat, 1, bN+4, muB, nuB, ku,
683
0, muB+nuB, 1, 0, 0);
684
printf(" nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
685
bN+4, muB, nuB, mmp->ku, MACC, lat, mf);
698
printf("NB UPPER BOUND CHOSEN AS : %d (%.2f)\n", bN, mfB);
700
* See if lowering NB past when all matrices should fit is useful
701
* (again, L1 detection could be wrong)
704
printf("\nFINDING LOWER BOUND ON NB:\n");
705
b0 = GetSmallNB(pre);
706
mf1 = TryKUs(mmp, pre, verb, MACC, lat, 1, b0, muB, nuB, ku,
707
0, muB+nuB, 1, 0, 0);
708
printf(" nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
709
b0, muB, nuB, mmp->ku, MACC, lat, mf1);
712
mf = TryKUs(mmp, pre, verb, MACC, lat, 1, b0-4, muB, nuB, ku,
713
0, muB+nuB, 1, 0, 0);
714
printf(" nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
715
b0-4, muB, nuB, mmp->ku, MACC, lat, mf);
726
printf("NB LOWER BOUND CHOSEN AS : %d\n", b0);
729
* Now try all NBs with varying prefetch
731
binc = (pre == 's' || pre == 'c') ? 4 : 2;
732
KUISKB = (!ku && mmp->ku == mmp->nbB);
736
printf("\nFINDING BEST NB AND PREFETCH SETTING IN RANGE [%d,%d,%d]:\n",
739
for (i=b0; i <= bN; i += binc)
741
mf = TryPFs(mmp, pre, verb, MACC, lat, 1, i, muB, nuB, KUISKB ? i:ku,
744
" nb=%3d, pf=%d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
745
i, mmp->pref, muB, nuB, mmp->ku, MACC, lat, mf);
754
printf("BEST NB=%d, BEST PREFETCH=%d (%.2f)\n", nbB, pfB, mfB);
756
mmp->mbB = mmp->nbB = mmp->kbB = nbB;
760
ATL_mmnode_t *FindBestKU
762
char pre, /* precision, one of s,d,c,z */
763
int verb, /* verbosity */
764
ATL_mmnode_t *mmp /* input/output struct for best case found so far */
767
* Find best K unrolling. There is no data cache dependence here, so time
768
* with in-cache operands for increased speed and accuracy
771
int k, kuB, latB, kN, incK, lat;
772
int nb, LAT, MACC, mu, nu, pf;
775
LAT = mmp->lat; /* canonical latency */
782
printf("TRYING KUs FOR NB=%d, PF=%d, MU=%d, NU=%d MACC=%d, LAT=%d:\n",
783
nb, pf, mu, nu, MACC, LAT);
785
* Try ku=1 as default
788
latB = lat = GetGoodLat(MACC, nb, mu, nu, 1, LAT);
789
mfB = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, 1,
790
0, mu+nu, 1, 0, pf, -1, -1);
792
printf(" KU=%d, lat=%d, mf=%.2f\n", 1, lat, mfB);
794
* Try NB/2 as maximal unrolling that actually has a loop
797
lat = GetGoodLat(MACC, nb, mu, nu, k, LAT);
798
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, k,
799
0, mu+nu, 1, 0, pf, -1, -1);
801
printf(" KU=%d, lat=%d, mf=%.2f\n", k, LAT, mf);
809
* Try fully unrolled, give it .5% bonus since it is easier on lat, etc.
811
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, nb,
812
0, mu+nu, 1, 0, pf, -1, -1);
815
printf(" KU=%d, lat=%d, mf=%.2f\n", nb, LAT, mf);
823
* Have already tried 1 & KB, so now try 2, 4, 6, 8
825
for (k=2; k <= 8; k += 2)
827
lat = GetGoodLat(MACC, nb, mu, nu, k, LAT);
828
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, k,
829
0, mu+nu, 1, 0, pf, -1, -1);
831
printf(" KU=%d, lat=%d, mf=%.2f\n", k, LAT, mf);
840
* Try all cases in range [8,nb/2,4]
843
if (!mmp->muladd && mmp->lat > 2)
846
k = (incK >= 8) ? incK : (8/incK)*incK;
853
for (; k < kN; k += incK)
855
lat = GetGoodLat(MACC, nb, mu, nu, k, LAT);
856
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, k,
857
0, mu+nu, 1, 0, pf, -1, -1);
859
printf(" KU=%d, lat=%d, mf=%.2f\n", k, LAT, mf);
868
* Time the best found case out-of-cache so it can be compared to others
870
lat = GetGoodLat(MACC, nb, mu, nu, kuB, LAT);
871
mfB = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, kuB,
872
0, mu+nu, 1, 0, pf, -1, -1);
877
printf("BEST KU=%d, lat=%d (%.2f)\n", kuB, latB, mfB);
881
ATL_mmnode_t *FindBestRest
883
char pre, /* precision, one of s,d,c,z */
884
int verb, /* verbosity */
885
ATL_mmnode_t *mmp /* input/output struct for best case found so far */
886
) /* tunes iftch, nftch, fftch, LDTOP, tries opposite muladd */
888
int i, j, n, nelts, nb, mu, nu, pf, MACC, lat, ku;
889
int ifB, nfB, ffB, ldtopB;
897
ifB = nelts = mu + nu;
905
printf( "FINDING BEST FETCH PATTERN FOR nb=%d, mu=%d, nu=%d, ku=%d\n",
907
for (i=2; i <= nelts; i++)
913
for (j=1; j <= n; j++)
915
mf = TimeGMMKernel(verb, 1, pre, MACC, lat, 1, nb, mu, nu, ku,
916
0, i, j, 0, pf, -1, -1);
918
printf (" ifetch=%2d, nfetch=%2d, mf=%.2f\n", i, j, mf);
928
* overwrite bad ifetch value output file with selected one
930
mfB = TimeGMMKernel(verb, 1, pre, MACC, lat, 1, nb, mu, nu, ku,
931
0, ifB, nfB, 0, pf, -1, -1);
933
printf(" best case retimed as mf=%.2f\n", mfB);
938
printf("BEST IFETCH=%d, NFETCH=%d (%.2f)\n", ifB, nfB, mfB);
940
* Try force fetch for beta=0
943
printf("TRYING FALSE FETCH FOR BETA=0 CASES:\n");
944
mf0 = TimeGMMKernel(verb, 1, pre, MACC, lat, 0, nb, mu, nu, ku,
945
0, ifB, nfB, 0, pf, -1, -1);
947
printf(" noFF=%.2f\n", mf0);
948
mf = TimeGMMKernel(verb, 1, pre, MACC, lat, 0, nb, mu, nu, ku,
949
1, ifB, nfB, 0, pf, -1, -1);
951
printf(" yesFF=%.2f\n", mf);
956
* If loading C at top is 2% faster, take it despite error bound hit
959
printf("TRYING LOAD-AT-TOP (load-at-bottom %.2f)\n", mfB);
960
mf = TimeGMMKernel(verb, 0, pre, MACC, lat, 1, nb, mu, nu, ku,
961
ffB, ifB, nfB, 1, pf, -1, -1);
963
printf(" load-at-top, mf=%.2f\n", mf);
967
printf("STICKING WITH LOAD-AT-BOTTOM\n");
973
printf("SWITCHING TO LOAD-AT-TOP\n");
974
mmp->flag |= (1<<MMF_LDCTOP);
979
* See if reversing muladd setting is helpful
982
printf("TRYING SWAP OF MACC (present, madd=%d, lat=%d, mf=%.2f)\n",
986
i = Mmax(mmp->lat, 4);
987
i = GetGoodLat(0, nb, mu, nu, ku, i);
991
mf = TimeGMMKernel(verb, 0, pre, !MACC, i, 1, nb, mu, nu, ku,
992
ffB, ifB, nfB, ldtopB, pf, -1, -1);
994
printf(" macc=%d, lat=%d, mf=%.2f\n", !MACC, i, mf);
1001
printf("SWITCHING TO NEW MACC SETTING!\n");
1004
printf("KEEPING MACC SETTING.\n");
1005
mmp->mflop[0] = mfB;
1009
ATL_mmnode_t *FindBestGenGemm
1011
char pre, /* precision, one of s,d,c,z */
1012
int verb, /* verbosity */
1013
int nregs, /* max # of registers to use */
1014
int MACC, /* 1: machine has multiply & accumulate, else separate */
1015
int lat, /* latency on multiply */
1016
int FNB, /* is it required to use NB, or can we tune? */
1017
int NB, /* suggested nb */
1018
int ku /* 0: tune ku, else we must use this ku */
1021
* This routine finds the best copy matmul case that can be generated by
1022
* emit_mm.c. It will search over the following parameters:
1023
* (nb,pf), (mu,nu), ku, nftch, iftch, fftch, LDTOP
1025
* pf is currently 1 or 0, and it controls whether the next block of A is
1026
* prefetched or not.
1028
* LDTOP determines if we load C values before entering the K loop (TOP)
1029
* or after. After gives better error bound, so give it slight advantage
1031
* nftch,iftch are crude load scheduling parameters, and they tend to
1032
* have little effect on most machines (the compiler usually reschedules
1033
* the loads on its own).
1035
* fftch causes the generator to load C at the top of the loop even
1036
* when we don't need the values there, so that C is in cache at
1037
* the bottom of the loop when we need it.
1039
* RETURNS: filled structure with best gemm case found
1043
int nb, N, Ng, i, j, mu, nu, nbB, muB, nuB;
1044
int *mus, *nus, *ip;
1045
double mf, mfB, mf1;
1047
#ifdef ATL_GAS_x8664
1049
int exmu[NEXMU] = {4, 6, 8, 10, 12};
1050
int exnu[NEXMU] = {1, 1, 1, 1, 1};
1051
#elif defined(ATL_GAS_x8632)
1053
int exmu[NEXMU] = {3, 4, 6, 2};
1054
int exnu[NEXMU] = {1, 1, 1, 2};
1060
* Use either required nb, or one that is a multiple of a lot of our
1061
* unrolling factors; Use a big block factor so that our register blocking
1062
* matters more (cache is covering less of costs)
1068
nb = (GetBigNB(pre)/12)*12;
1072
if (pre == 'd' || pre == 's')
1075
FillInGMMNode(verb, mmp, pre, MACC, lat, 1, nb, 1, 1, 1, 0, 2, 1, 0, 0);
1078
* Get all MU/NU unrollings, Ng of them are competitive on flops/load ratio.
1079
* For x86, always include extra 1-D blockings in mix, even if they
1080
* are not judged competitive (because if reg-reg moves aren't free, which
1081
* is true for older x86 machines, 2-D register blocks don't really work
1082
* due to 2-operand assembly)
1084
GetMuNus(nregs, MACC, lat, &Ng, &N, &mus, &nus, &fpls);
1087
for (j=0; j < NEXMU; j++)
1091
for (i=0; i < Ng; i++)
1092
if (mus[i] == mu && nus[i] == nu) break;
1097
ip = malloc((Ng+1)*sizeof(int));
1099
for (i=0; i < Ng; i++)
1103
ip = malloc((Ng+1)*sizeof(int));
1105
for (i=0; i < Ng; i++)
1117
printf("PROBING FOR M AND N UNROLLING FACTORS:\n");
1119
* Try all competitive unrolling factors
1123
for (i=0; i < Ng; i++)
1125
mf = TryKUs(mmp, pre, verb, MACC, lat, 1, nb, mus[i], nus[i], ku,
1126
0, mus[i]+nus[i], 1, 0, 0);
1128
printf(" nb=%3d, mu=%3d, nu=%3d, ku=%3d, MACC=%d, lat=%d, mf=%.2f\n",
1129
nb, mus[i], nus[i], mmp->ku, MACC, lat, mf);
1139
mmp->iftch = muB+nuB;
1140
mmp->mflop[0] = mfB;
1141
printf("SELECTED MU=%d, NU=%d (%.2f)\n", muB, nuB, mfB);
1146
else /* complex types gets their MU & NU from real cases */
1148
upr = (pre == 'z') ? 'd' : 's';
1149
mmp = ReadMMFileWithPath(upr, "res", "gMMRES.sum");
1152
sprintf(ln, "make res/%cgMMRES.sum > /dev/null 2>&1", upr);
1153
assert(system(ln) == 0);
1154
mmp = ReadMMFileWithPath(upr, "res", "gMMRES.sum");
1160
mfB = TryKUs(mmp, pre, verb, MACC, lat, 1, nb, muB, nuB, ku,
1161
0, muB+nuB, 1, 0, 0);
1162
mmp->mflop[0] = mfB;
1163
printf("READ IN MU=%d, NU=%d FROM REAL, nb=%d, mf=%.2f\n",
1167
* If we are allowed, try to tune NB
1170
mmp = FindBestNB(pre, verb, mmp, ku);
1171
else /* still need to scope prefetch settings with required NB */
1173
mmp->nbB = mmp->mbB = mmp->kbB = nb;
1174
mf = TryPFs(mmp, pre, verb, MACC, lat, 1, nb, muB, nuB, ku,
1178
* If we are allowed, tune ku
1181
mmp = FindBestKU(pre, verb, mmp); /* tunes ku */
1182
mmp = FindBestRest(pre, verb, mmp); /* tunes iftch, nftch, fftch, LDTOP */
1186
int main(int nargs, char **args)
1189
int verb, nregs, FNB, nb, ku, MACC, lat, mu, nu;
1190
ATL_mmnode_t *mmp, *mm;
1192
pre = GetFlags(nargs, args, &verb, &nregs, &nb, &ku, &MACC, &lat, &outfile);
1193
if (nregs == -1) /* run # register probe only */
1196
nregs = FindNumRegs(pre, verb, nb, ku, &MACC, &lat);
1199
mmp = ReadMMFile(outfile);
1202
if (mmp->mflop[0] <= 0) /* need to retime */
1204
for (mm=mmp; mm; mm = mm->next)
1206
mm->mflop[0] = TimeGMMKernel(verb, 0, pre, mm->muladd, mm->lat,
1207
1, mm->nbB, mm->mu, mm->nu, mm->ku,
1208
mm->fftch, mm->iftch, mm->nftch,
1209
FLAG_IS_SET(mm->flag, MMF_LDCTOP),
1212
WriteMMFile(outfile, mmp);
1214
printf("\nSEARCH OUTPUT READ IN AS:\n");
1215
PrintMMNodes(stdout, mmp);
1226
nregs = FindNumRegs(pre, verb, nb, ku, &MACC, &lat);
1228
ConfirmMACC(pre, verb, nregs, nb, ku, &MACC, &lat);
1229
mmp = FindBestGenGemm(pre, verb, nregs, MACC, lat, FNB, nb, ku);
1230
WriteMMFile(outfile, mmp);
1231
printf("\nSELECTED GENERATED KERNEL:\n");
1232
PrintMMNodes(stdout, mmp);