2
* Automatically Tuned Linear Algebra Software v3.2
3
* (C) Copyright 1997 R. Clint Whaley
5
* Redistribution and use in source and binary forms, with or without
6
* modification, are permitted provided that the following conditions
8
* 1. Redistributions of source code must retain the above copyright
9
* notice, this list of conditions and the following disclaimer.
10
* 2. Redistributions in binary form must reproduce the above copyright
11
* notice, this list of conditions, and the following disclaimer in the
12
* documentation and/or other materials provided with the distribution.
13
* 3. The name of the University of Tennessee, the ATLAS group,
14
* or the names of its contributers may not be used to endorse
15
* or promote products derived from this software without specific
18
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
22
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
* POSSIBILITY OF SUCH DAMAGE.
36
#include "atlas_misc.h"
37
#include "atlas_fopen.h"
39
#define Mmin(x, y) ( (x) > (y) ? (y) : (x) )
43
#define L1FNAME "L1CacheSize"
49
void PrintUsage(char *xnam)
51
fprintf(stderr, "\n\nUsage: %s [-r #][-h][-f][-l #][-p s/d/c/z][-m #]\n",
53
fprintf(stderr, "-h : Print this help screen\n");
54
fprintf(stderr, "-f : Force complete search over given parameters\n");
55
fprintf(stderr, "-p s/d/c/z : set the precision to search for\n");
56
fprintf(stderr, "-r # : Set max number of registers to use to # (default 32)\n");
57
fprintf(stderr, "-m # : Set max L1 cache size (kilobytes) to #\n");
58
fprintf(stderr, "-L <c/f> : Select what language to use (C or Fortran77)\n");
59
fprintf(stderr, "-K # : Set K-loop unrolling to # (-1 = K).\n");
60
fprintf(stderr, "-l # : Use latency factor #. If set to 0,\n");
62
" do not do latency checking. By default, latency checking is\n");
64
" done only if initial timings show it is a win.\n");
68
void GetSettings(int nargs, char *args[], char *pre, char *lang, int *ku,
69
int *LAT, int *FRC, int *nreg, int *MaxL1Size, int *ROUT)
81
for (i=1; i < nargs; i++)
83
if (*args[i] != '-') PrintUsage(args[0]);
87
*ku = atoi(args[++i]);
91
if ( (*args[i] == 'F') || (*args[i] == 'f') ) *lang = 'F';
94
*MaxL1Size = atoi(args[++i]);
97
*nreg = atoi(args[++i]);
100
*FRC = atoi(args[++i]);
103
*LAT = atoi(args[++i]);
110
*ROUT = atoi(args[++i]);
118
int L1Elts(char pre, int MaxL1Size)
124
if (!FileExists("res/L1CacheSize"))
126
sprintf(ln, "make RunL1 MaxL1=%d\n",MaxL1Size);
129
remove("res/L1CacheSize");
130
fprintf(stderr, "Error in command: %s", ln);
134
L1f = fopen("res/L1CacheSize", "r");
136
fscanf(L1f, "%d", &L1Size);
141
tsize = sizeof(float);
144
tsize = sizeof(double);
147
tsize = sizeof(long double);
150
tsize = sizeof(float);
153
tsize = sizeof(double);
156
return( (L1Size*1024) / tsize);
159
int GetCacheSize(int MaxL1Size)
161
* Returns L1 size in kilobytes
168
if (!FileExists("res/L1CacheSize"))
170
sprintf(ln, "make RunL1 MaxL1=%d\n",MaxL1Size);
173
remove("res/L1CacheSize");
174
fprintf(stderr, "Error in command: %s", ln);
178
L1f = fopen("res/L1CacheSize", "r");
180
fscanf(L1f, "%d", &L1Size);
182
fprintf(stderr, "\n Read in L1 Cache size as = %dKB.\n",L1Size);
186
int GetTypeSize(char pre)
189
if (pre == 'c' || pre == 's') tsize = ATL_ssize;
190
else tsize = ATL_dsize;
193
void findNBs(char prec, char *NBnam, int MaxL1Size)
197
int i, L1Size, tmp, tsize, tL1Size, CL, nNB;
200
fprintf(stderr, "NB setting not supplied; calculating:\n");
202
L1Size = GetCacheSize(MaxL1Size);
203
tsize = GetTypeSize(prec);
205
tL1Size = L1Size * (1024 / tsize);
206
tmp = CL = ATL_Cachelen / tsize;
209
fprintf(stderr, "tmp=%d, tL1size=%d\n",tmp, tL1Size);
210
while (tmp*tmp <= tL1Size)
212
if (tmp >= 16) /* no block sizes smaller than 16 */
214
if (tmp >= 80) break; /* no block sizes bigger than 80 */
217
if (!nNB) /* this should never happen */
224
else if (nNB > 2) /* put second biggest blocking factor first in list */
231
NBf = fopen(NBnam, "w");
232
fprintf(NBf, "%d\n", nNB);
233
for (i=0; i != nNB; i++) fprintf(NBf, "%d\n", NB[i]);
237
int GetSafeNB(char pre, int MaxL1)
239
int i, L1, tsize, inc;
241
tsize = GetTypeSize(pre);
242
inc = ATL_MinMMAlign / tsize;
243
if (inc < 4) inc = 4;
244
L1 = (GetCacheSize(MaxL1) * 1024) / tsize;
245
for (i=inc; i*i < L1; i += inc);
246
if (i*i > L1) i -= inc;
247
if (pre == 'd' || pre == 's')
249
if (i*i == L1) i -= inc;
253
if (i*i == L1) i -= 2*inc;
261
double GetAvg(int n, double tolerance, double *mflop)
266
* Sort results, largest first
268
for (i=0; i != n; i++)
270
for (j=i+1; j < n; j++)
272
if (mflop[i] < mflop[j])
282
* Throw out result if it is outside tolerance; rerun if two mflop not within
283
* tolerance; this code assumes n == 3
285
if (tolerance*mflop[1] < mflop[0]) /* too big a range in results */
287
if (tolerance*mflop[2] < mflop[1]) return(-1.0);
288
tavg = (mflop[1] + mflop[2]) / 2.0;
290
else if (tolerance*mflop[2] < mflop[0]) tavg = (mflop[0] + mflop[1]) / 2.0;
291
else tavg = (mflop[0] + mflop[1] + mflop[2]) / 3.0;
296
double mms_case(char pre, int MULADD, int NB, int mu, int nu, int ku, int lat)
298
char fnam[128], ln[256];
300
double mflop[NTIM], t0;
303
if (ku > NB) ku = NB;
304
else if (ku == -1) ku = NB;
305
sprintf(fnam, "res/%c%smm%c%c%d_%dx%dx%d_%dx%dx%d_%dx%dx%d%s%s_%dx%d_%d",
306
pre, "JIK", 'T', 'N', NB, NB, NB, NB, NB, NB, 0, mu, nu, ku,
307
"_a1", "_b1", MULADD, lat, 1);
309
sprintf(fnam, "res/%c%cNB%d_%dx%dx%d_%d-%d.mflop", LANG, pre, NB, mu, nu,
312
if (!FileExists(fnam))
314
if (pre == 'c' || pre == 'z')
316
" make mmcase pre=%c loopO=%s ta=%c tb=%c mb=%d nb=%d kb=%d lda=%d ldb=%d ldc=%d mu=%d nu=%d ku=%d alpha=%d beta=%d muladd=%d lat=%d csA=1 csB=1 csC=2 cleanup=%d\n",
317
pre, "JIK", 'T', 'N', NB, NB, NB, NB, NB, 0, mu, nu, ku,
318
1, 1, MULADD, lat, 1);
320
" make mmcase pre=%c loopO=%s ta=%c tb=%c mb=%d nb=%d kb=%d lda=%d ldb=%d ldc=%d mu=%d nu=%d ku=%d alpha=%d beta=%d muladd=%d lat=%d cleanup=%d\n",
321
pre, "JIK", 'T', 'N', NB, NB, NB, NB, NB, 0, mu, nu, ku,
322
1, 1, MULADD, lat, 1);
323
fprintf(stderr, "%s:\n",ln);
326
fprintf(stderr, "Error in command: %s", ln);
327
sprintf(ln, "rm -f %s\n", fnam);
332
assert( (fp = fopen(fnam, "r")) != NULL );
333
for (i=0; i != NTIM; i++)
335
assert( fscanf(fp, "%lf", &mflop[i]) == 1 );
339
t0 = GetAvg(NTIM, TOLERANCE, mflop);
342
fprintf(stderr, "NB=%d, MU=%d, NU=%d, KU=%d: rerun with higher reps; variation exceeds tolerence\n", NB, mu, nu, ku);
343
sprintf(ln, "rm -f res/%s\n", fnam);
348
"\npre=%c, muladd=%d, lat=%d, nb=%d, mu=%d, nu=%d, ku=%d, mflop=%.2f\n",
349
pre, MULADD, lat, NB, mu, nu, ku, t0);
353
double mmcase0(char *nam, char pre, char *loopO, char ta, char tb,
354
int M, int N, int K, int mb, int nb, int kb,
355
int lda, int ldb, int ldc, int mu, int nu, int ku,
356
int muladd, int lat, int beta, int csA, int csB, int csC,
357
int FFetch, int ifetch, int nfetch, char *mmnam)
359
char fnam[128], ln[512], bnam[16], casnam[128], mmcase[128];
360
int i, N0, lda2=lda, ldb2=ldb, ldc2=ldc;
361
double mflop[NTIM], t0;
364
if (lda < 0) { lda2 = -lda; lda = 0; }
365
if (ldb < 0) { ldb2 = -ldb; ldb = 0; }
366
if (ldc < 0) { ldc2 = -ldc; ldc = 0; }
367
if (mmnam) sprintf(mmcase, "mmucase mmrout=%s", mmnam);
368
else sprintf(mmcase, "mmcase");
369
if (ifetch == -1 || nfetch == -1) { ifetch = mu+nu; nfetch = 1; }
370
if (beta == 1) sprintf(bnam, "_b1");
371
else if (beta == -1) sprintf(bnam, "_bn1");
372
else if (beta == 0) sprintf(bnam, "_b0");
373
else sprintf(bnam, "_bX");
377
else if (ku == -1) ku = K;
381
sprintf(casnam, "casnam=%s", nam);
385
sprintf(fnam, "res/%c%smm%c%c%d_%dx%dx%d_%dx%dx%d_%dx%dx%d%s%s_%dx%d_%d",
386
pre, loopO, ta, tb, N0, mb, nb, kb, lda, ldb, ldc, mu, nu, ku,
387
"_a1", bnam, muladd, lat, 1);
390
if (!FileExists(fnam))
392
if (pre == 'c' || pre == 'z')
394
" make %s pre=%c loopO=%s ta=%c tb=%c M=%d N=%d K=%d mb=%d nb=%d kb=%d lda=%d ldb=%d ldc=%d lda2=%d ldb2=%d ldc2=%d mu=%d nu=%d ku=%d alpha=%d beta=%d muladd=%d lat=%d cleanup=%d csA=%d csB=%d csC=%d ff=%d if=%d nf=%d %s\n",
395
mmcase,pre, loopO, ta, tb, M, N, K, mb, nb, kb, lda, ldb, ldc,
396
lda2, ldb2, ldc2, mu, nu, ku, 1, beta, muladd, lat, 1,
397
csA, csB, csC, FFetch, ifetch, nfetch, casnam);
399
" make %s pre=%c loopO=%s ta=%c tb=%c M=%d N=%d K=%d mb=%d nb=%d kb=%d lda=%d ldb=%d ldc=%d lda2=%d ldb2=%d ldc2=%d mu=%d nu=%d ku=%d alpha=%d beta=%d muladd=%d lat=%d cleanup=%d ff=%d if=%d nf=%d %s\n",
400
mmcase, pre, loopO, ta, tb, M, N, K, mb, nb, kb, lda, ldb,
401
ldc, lda2, ldb2, ldc2, mu, nu, ku, 1, beta, muladd, lat, 1,
402
FFetch, ifetch, nfetch, casnam);
403
fprintf(stderr, "%s:\n",ln);
407
* User cases, and large leading dimensions can fail to run
409
if (mmnam) return(-1.0); /* user cases can fail to compile */
410
if (lda2 != lda || ldb2 != ldb || ldc2 != ldc) return(-1);
411
fprintf(stderr, "Error in command: %s", ln);
412
sprintf(ln, "rm -f %s\n", fnam);
417
fp = fopen(fnam, "r");
418
if (!fp) fprintf(stderr, "ERROR: can't find file=%s\n", fnam);
420
for (i=0; i != NTIM; i++)
422
assert(fscanf(fp, "%lf", &mflop[i]) == 1);
426
t0 = GetAvg(NTIM, TOLERANCE, mflop);
430
"case=%s: rerun with higher reps; variation exceeds tolerence\n", fnam);
431
sprintf(ln, "rm -f %s\n", fnam);
436
"\n pre=%c, loopO=%s, ta=%c tb=%c, mb=%d, nb=%d, kb=%d, lda=%d, ldb=%d, ldc=%d\n",
437
pre, loopO, ta, tb, mb, nb, kb, lda, ldb, ldc);
438
fprintf(stdout, " mu=%d, nu=%d, ku=%d, muladd=%d, lat=%d ====> mflop=%f\n",
439
mu, nu, ku, muladd, lat, t0);
443
double mmucase(int ifile, char pre, int nb, int muladd, int lat,
444
int mu, int nu, int ku, char *fnam)
449
sprintf(fout, "res/%cuser%d", pre, ifile);
450
if (mu == 1 && nu == 1) iff = 1;
452
return(mmcase0(fout, pre, "JIK", 'T', 'N', nb, nb, nb, nb, nb, nb,
453
nb, nb, 0, mu, nu, ku, muladd, lat, 1, 1, 1, 2, 0, iff, 1,
457
enum CW {CleanM=0, CleanN=1, CleanK=2, CleanNot=3};
458
double mmclean(char pre, enum CW which, char *loopO, char ta, char tb,
459
int M, int N, int K, int mb, int nb, int kb,
460
int lda, int ldb, int ldc, int mu, int nu, int ku,
461
int muladd, int lat, int beta, int csA, int csB, int csC,
462
int FFetch, int ifetch, int nfetch)
465
char cwh[3] = {'M', 'N', 'K'};
466
sprintf(nam, "res/%cClean%c_%dx%dx%d", pre, cwh[which], M, N, K);
467
return(mmcase0(nam, pre, loopO, ta, tb, M, N, K, mb, nb, kb, lda, ldb, ldc,
468
mu, nu, ku, muladd, lat, beta, csA, csB, csC,
469
FFetch, ifetch, nfetch, NULL));
472
double mmcase(char *nam, char pre, char *loopO, char ta, char tb,
473
int M, int N, int K, int mb, int nb, int kb,
474
int lda, int ldb, int ldc, int mu, int nu, int ku,
475
int muladd, int lat, int beta, int csA, int csB, int csC,
476
int FFetch, int ifetch, int nfetch)
478
return(mmcase0(nam, pre, loopO, ta, tb, M, N, K, mb, nb, kb, lda, ldb, ldc,
479
mu, nu, ku, muladd, lat, beta, csA, csB, csC,
480
FFetch, ifetch, nfetch, NULL));
483
int GetGoodLat(int MULADD, int kb, int mu, int nu, int ku, int lat)
485
int slat, blat, i, ii = mu*nu*ku;
486
if (MULADD) return(lat);
487
if ( (lat > 1) && (kb > ku) && ((ii/lat)*lat != ii) ) /* lat won't work */
489
for (i=lat; i; i--) if ( (ii/i) * i == ii ) break;
491
for (i=lat; i < MAXLAT; i++) if ( (ii/i) * i == ii ) break;
493
if ( (ii/blat)*blat != ii ) blat = slat;
494
if (slat < 2) lat = blat;
495
else if (lat-slat < blat-lat) lat = slat;
501
void FindMUNU(int muladd, int lat, int nr, int *MU, int *NU)
503
* Find near-square muxnu using nr registers or less
516
if (j < 3) mu = nu = 1;
520
for (nu=1; nu*nu < mu; nu++);
521
if (nu*nu > mu) nu -= 2;
523
if (nu < 1) mu = nu = 1;
526
mu = (nr-nu) / (1+nu);
540
void PutInstLogLine(FILE *fp, int muladd, int lat, int nb,
541
int mu, int nu, int ku, int ForceFetch,
542
int ifetch, int nfetch, double mflop)
544
fprintf(fp, "%6d %3d %3d %3d %3d %3d %5d %5d %5d %7.2lf\n",
545
muladd, lat, nb, mu, nu, ku, ForceFetch, ifetch, nfetch, mflop);
547
void PutInstLogFile(FILE *fp, int muladd, int lat, int nb,
548
int mu, int nu, int ku, int ForceFetch,
549
int ifetch, int nfetch, double mflop)
551
fprintf(fp, "MULADD LAT NB MU NU KU FFTCH IFTCH NFTCH MFLOP\n");
552
PutInstLogLine(fp, muladd, lat, nb, mu, nu, ku, ForceFetch,
553
ifetch, nfetch, mflop);
555
void PutInstLogFile1(char *fnam, char pre, int muladd, int lat, int nb,
556
int mu, int nu, int ku, int ForceFetch,
557
int ifetch, int nfetch, double mflop)
561
fp = fopen(fnam, "w");
563
PutInstLogFile(fp, muladd, lat, nb, mu, nu, ku, ForceFetch, ifetch, nfetch,
568
void GetInstLogLine(FILE *fp, int *muladd, int *lat, int *nb,
569
int *mu, int *nu, int *ku, int *ForceFetch,
570
int *ifetch, int *nfetch, double *mflop)
572
assert(fscanf(fp, " %d %d %d %d %d %d %d %d %d %lf\n",
573
muladd, lat, nb, mu, nu, ku, ForceFetch,
574
ifetch, nfetch, mflop) == 10);
577
void GetInstLogFile(char *nam, char pre, int *muladd, int *lat, int *nb,
578
int *mu, int *nu, int *ku, int *ForceFetch,
579
int *ifetch, int *nfetch, double *mflop)
584
fp = fopen(nam, "r");
585
if (fp == NULL) fprintf(stderr, "file %s not found!!\n\n", nam);
588
GetInstLogLine(fp, muladd, lat, nb, mu, nu, ku, ForceFetch,
589
ifetch, nfetch, mflop);
595
void CreateFinalSumm(char pre, int muladd, int lat, int nb, int mu, int nu,
596
int ku, int Ff, int If, int Nf, double gmf)
598
char ln[64], auth[65];
603
sprintf(ln, "res/%cMMRES", pre);
605
PutInstLogFile(fp, muladd, lat, nb, mu, nu, ku, Ff, If, Nf, gmf);
606
sprintf(ln, "res/%cuMMRES", pre);
607
fp0 = fopen(ln, "r");
609
assert(fgets(ln, 64, fp0));
610
assert(fscanf(fp0, " %d %d %lf \"%[^\"]\" \"%[^\"]", &icase, &unb, &umf,
613
fprintf(fp, "\nICASE NB MFLOP ROUT AUTHOR\n");
614
fprintf(fp, "%5d %3d %8.2f \"%.63s\" \"%.63s\"\n", icase, unb, umf,
620
void FindFetch(char ta, char tb, char pre, int mb, int nb, int kb,
621
int mu, int nu, int ku, int muladd, int lat,
622
int *FFetch0, int *ifetch0, int *nfetch0)
624
* See what fetch patterns are appropriate
628
const int nelts = mu+nu;
629
int csA=1, csB=1, csC=1, nleft, i, j;
630
int ifetch = mu+nu, nfetch = 1;
633
if (pre == 'c' || pre == 'z') csC = 2;
635
mf0 = mmcase(NULL, pre, "JIK", ta, tb, mb, nb, kb, mb, nb, kb,
636
kb, kb, 0, mu, nu, ku, muladd, lat, 0, csA, csB, csC,
639
for (i=2; i < nelts; i++)
642
for (j=1; j <= nleft; j++)
644
sprintf(fnam, "res/%cMMfetch%d_%d", pre, i, j);
645
mf = mmcase(fnam, pre, "JIK", ta, tb, mb, nb, kb, mb, nb, kb,
646
kb, kb, 0, mu, nu, ku, muladd, lat, 0, csA, csB, csC,
657
* See if prefetching good idea for beta=0 case
659
sprintf(fnam, "res/%cMM_b0", pre);
660
mf0 = mmcase(fnam, pre, "JIK", ta, tb, mb, nb, kb, mb, nb, kb,
661
kb, kb, 0, mu, nu, ku, muladd, lat, 0, csA, csB, csC,
664
sprintf(fnam, "res/%cMM_b0_pref", pre);
665
mf = mmcase(fnam, pre, "JIK", ta, tb, mb, nb, kb, mb, nb, kb,
666
kb, kb, 0, mu, nu, ku, muladd, lat, 0, csA, csB, csC,
669
*FFetch0 = (mf > mf0);
672
fprintf(stdout, "\n\nFORCEFETCH=%d, IFETCH = %d, NFETCH = %d\n\n",
673
*FFetch0, *ifetch0, *nfetch0);
677
void searchmu_nu(char pre, int nb, int maxreg, int Fku, int muladd, int LAT,
678
int NO1D, double *mfB, int *nbB, int *muB, int *nuB, int *kuB,
681
int i, j, lat, ku, nr2, nreg=Mmin(nb, maxreg);
684
for (i=1; i <= nreg; i++)
687
if (nr2 > nb) nr2 = nb;
688
for (j=1; j <= nreg; j++)
690
if ( (((i==1) && (j > 4)) || ((j==1) && (i > 4))) && NO1D) continue;
691
if (Fku == -1 || (!Fku) ) ku = nb;
692
else if (Fku) ku = Fku;
693
if (ku != nb) lat = GetGoodLat(muladd, nb, i, j, ku, LAT);
695
if (j*i+j+i+(!muladd)*lat > maxreg) continue; /* not enough regs */
696
mf = mms_case(pre, muladd, nb, i, j, ku, lat);
708
lat = GetGoodLat(muladd, nb, i, j, 1, LAT);
709
mf = mms_case(pre, muladd, nb, i, j, 1, lat);
724
void FindKU(char pre, int muladd, int LAT, int nb, int mu, int nu,
725
double *mfB, int *kuB, int *latB)
727
* For best case, try various ku's
733
fprintf(stderr, "Confirming K-loop unrollings for chosen NB:\n");
734
mf = mms_case(pre, muladd, nb, mu, nu, nb, LAT);
741
for (k=1; k < nb; k += 4)
744
if (k > nb/2) k = nb;
745
lat = GetGoodLat(muladd, nb, mu, nu, k, *latB);
746
mf = mms_case(pre, muladd, nb, mu, nu, k, lat);
756
void FindLAT(char pre, int maxlat, int nb, int muladd, int mu, int nu, int ku,
757
double *mfB, int *latB)
763
fprintf(stderr, "\nConfirming latency factors for chosen parameters:\n");
764
for (i=1; i <= maxlat; i++)
766
lat = GetGoodLat(muladd, nb, mu, nu, ku, i);
769
mf = mms_case(pre, muladd, nb, mu, nu, ku, lat);
777
fprintf(stderr, "\n\n Best latency factor=%d\n\n", *latB);
780
int CheckUser(char pre, double adv, double gmf, int gnb, double *umf)
782
* Checks if user case is better than generated, and if so, return umb
789
sprintf(fnam, "res/%cuMMRES", pre);
790
if (!FileExists(fnam)) /* need to run user search */
792
sprintf(fnam, "make RunUMMSearch pre=%c nb=%d\n", pre, gnb);
793
assert(system(fnam) == 0);
794
sprintf(fnam, "res/%cuMMRES", pre);
796
fp = fopen(fnam, "r");
798
assert(fgets(fnam, 128, fp));
799
assert(fgets(fnam, 128, fp));
801
sscanf(fnam, " %d %d %lf", &i, &unb, &umflop);
802
if (i >= 0 && umflop < 0.0) /* need to retime */
804
sprintf(fnam, "make RunUMMSearch pre=%c nb=0\n", pre);
805
assert(system(fnam) == 0);
806
sprintf(fnam, "res/%cuMMRES", pre);
807
fp = fopen(fnam, "r");
809
assert(fgets(fnam, 128, fp));
810
assert(fgets(fnam, 128, fp));
812
sscanf(fnam, " %d %d %lf", &i, &unb, &umflop);
814
fprintf(stdout, "\nBEST USER CASE: NB=%d, MFLOP=%.2f\n", unb, umflop);
815
if (umf) *umf = umflop;
816
if (adv*gmf > umflop) return(gnb);
820
int GetNO1D(char pre, int nreg, int nb, int MULADD, int LAT)
825
if (pre == 'z') pre = 'd';
826
else if (pre == 'c') pre = 's';
828
lat = GetGoodLat(MULADD, nb, 3, 3, 1, LAT);
829
if (nreg >= 15+(!MULADD)*Mmax(LAT,lat))
831
mf0 = mms_case(pre, MULADD, nb, 3, 3, 1, lat);
832
mf1 = mms_case(pre, MULADD, nb, 3, 3, nb, LAT);
834
mf0 = mms_case(pre, MULADD, nb, 9, 1, 1, lat);
835
if (mf0 > mf) NO1D = 0;
836
else if (mms_case(pre, MULADD, nb, 9, 1, nb, LAT) > mf) NO1D = 0;
837
else if (mms_case(pre, MULADD, nb, 1, 9, nb, LAT) > mf) NO1D = 0;
838
else if (mms_case(pre, MULADD, nb, 1, 9, 1, lat) > mf) NO1D = 0;
844
void gmmsearch(char pre, int MULADD, int Fku, int nNBs, int *NBs, int nreg,
847
* Does real generated mmsearch
850
int latB, muB, nuB, kuB, nbB;
851
int i, j, k, NB, TEST_MU, TEST_NU, ku, nb, lat=LAT, nNB=nNBs, NO1D=0;
852
int FFetch, ifetch, nfetch, muladd;
857
sprintf(ln, "res/%cgMMRES", pre);
858
if (FileExists(ln)) /* already have needed result */
860
GetInstLogFile(ln, pre, &muladd, &lat, &nb, &muB, &nuB, &kuB, &FFetch,
861
&ifetch, &nfetch, &mf);
864
mf = mmcase(NULL, pre, "JIK", 'T', 'N', nb, nb, nb, nb, nb, nb,
865
nb, nb, 0, muB, nuB, kuB, muladd, lat, 1, 1, 1, 2,
866
FFetch, ifetch, nfetch);
867
PutInstLogFile1(ln, pre, muladd, lat, nb, muB, nuB, kuB, FFetch,
873
* Try not to tempt fate by using all registers
875
if (nreg > 16) i = nreg-2;
877
FindMUNU(MULADD, lat, i, &TEST_MU, &TEST_NU);
879
* First, find a good NB
882
fprintf(stderr, "Doing initial NB search:\n");
884
for (k=0; k != nNBs; k++)
888
mf = mms_case(pre, MULADD, NB, TEST_MU, TEST_NU, ku, lat);
899
if (Fku == 0) /* try no K-loop unrolling */
901
lat = GetGoodLat(MULADD, NB, TEST_MU, TEST_NU, 1, LAT);
902
mf = mms_case(pre, MULADD, NB, TEST_MU, TEST_NU, 1, lat);
920
fprintf(stderr, "NB=%d selected:\n", NBs[0]);
925
fprintf(stderr, "\nCombined multiply add, latency factor=%d, NB=%d ku=%d, chosen; initial MFLOP=%f. Beginning unroll search:\n", latB, NBs[0], kuB, mfB);
927
fprintf(stderr, "\nSeparate multiply and add, latency factor=%d, NB=%d ku=%d, chosen; initial MFLOP=%f. Beginning unroll search:\n", latB, NBs[0], kuB, mfB);
929
NO1D = GetNO1D(pre, nreg, NBs[0], MULADD, LAT);
930
if (NO1D) fprintf(stderr, "\n\nSkipping most 1D cases\n\n");
931
else fprintf(stderr, "\n\nTiming 1D cases\n\n");
932
for (k=0; k != nNB; k++)
935
searchmu_nu(pre, NB, nreg, Fku, MULADD, LAT, NO1D,
936
&mfB, &nbB, &muB, &nuB, &kuB, &latB);
938
fprintf(stderr, "\n\nBest case so far: nb=%d, mu=%d, nu=%d, ku=%d, lat=%d; MFLOPS=%f.\n",
939
nbB, muB, nuB, kuB, latB, mfB);
940
fprintf(stderr, "Trying various other NB and KU settings:\n\n");
942
* If we haven't checked all permutations, try other blocking factors
947
if (nNBs > 1) fprintf(stderr, "Trying various blocking factors:\n");
948
mf = mms_case(pre, MULADD, NBs[0], muB, nuB, kuB, latB);
949
for (k=0; k < nNBs; k++)
952
if (Fku == -1) ku = NB;
953
else if (Fku) ku = Fku;
954
else if (kuB == nbB) ku = NB;
956
if (ku != NB) lat = GetGoodLat(MULADD, NB, muB, nuB, ku, latB);
958
mf = mms_case(pre, MULADD, NB, muB, nuB, ku, lat);
968
if (nb != nbB) fprintf(stderr, "\nNew block factor of %d chosen!!\n\n", nbB);
972
* Try all ku's, and then valid latencies
974
FindKU(pre, MULADD, LAT, nbB, muB, nuB, &mfB, &kuB, &latB);
975
FindLAT(pre, MAXLAT, nbB, MULADD, muB, nuB, kuB, &mfB, &latB);
978
* Make sure MULADD is correct
980
lat = GetGoodLat(!MULADD, nbB, muB, nuB, kuB, latB);
981
mf = mms_case(pre, !MULADD, nbB, muB, nuB, kuB, lat);
984
fprintf(stderr, "\n\nMULADD MAY BE WRONG!!, old=%f, new=%f\n", mfB, mf);
987
* Try various fetch patterns
989
FindFetch('T', 'N', pre, nbB, nbB, nbB, muB, nuB, kuB, MULADD, latB,
990
&FFetch, &ifetch, &nfetch);
992
"BEST GENERATED CASE: nb=%d, ma=%d, lat=%d mu=%d, nu=%d, ku=%d -- %.2f\n",
993
nbB, MULADD, latB, muB, nuB, kuB, mfB);
994
sprintf(ln, "res/%cgMMRES", pre);
995
PutInstLogFile1(ln, pre, MULADD, latB, nbB, muB, nuB, kuB,
996
FFetch, ifetch, nfetch, mfB);
999
void mmsearch(char pre, int MULADD, int Fku, int nNBs, int *NBs, int nreg,
1002
int latB, muB, nuB, kuB, nbB;
1003
int muladd, nb, ifetch, nfetch, FFetch;
1006
int umb, unb, ukb, ma;
1011
sprintf(fnam, "res/%cMMRES", pre);
1012
if (FileExists(fnam)) /* already have result */
1014
GetInstLogFile(fnam, pre, &muladd, &latB, &nb, &muB, &nuB, &kuB, &FFetch,
1015
&ifetch, &nfetch, &mfB);
1018
mfB = mmcase(NULL, pre, "JIK", 'T', 'N', nb, nb, nb, nb, nb, nb,
1019
nb, nb, 0, muB, nuB, kuB, muladd, latB, 1, 1, 1, 2,
1020
FFetch, ifetch, nfetch);
1021
gmmsearch(pre, muladd, Fku, nNBs, NBs, nreg, latB, Fnb);
1022
nb = CheckUser(pre, 1.02, mfB, nb, NULL);
1023
CreateFinalSumm(pre, muladd, latB, nb, muB, nuB, kuB, FFetch, ifetch,
1026
sprintf(fnam, "res/%cNB", pre);
1027
fp = fopen(fnam, "w");
1028
fprintf(fp, "%d\n%d\n", 1, nbB);
1032
gmmsearch(pre, MULADD, Fku, nNBs, NBs, nreg, LAT, Fnb);
1033
sprintf(fnam, "res/%cgMMRES", pre);
1034
GetInstLogFile(fnam, pre, &muladd, &latB, &nbB, &muB, &nuB, &kuB, &FFetch,
1035
&ifetch, &nfetch, &mfB);
1037
nb = CheckUser(pre, 1.02, mfB, nbB, NULL);
1040
if (kuB == nbB) kuB = nb;
1042
if (nb % muB || nb % nuB)
1044
NO1D = GetNO1D(pre, nreg, nb, MULADD, LAT);
1045
searchmu_nu(pre, nb, nreg, Fku, MULADD, LAT, NO1D,
1046
&mfB, &nb, &muB, &nuB, &kuB, &latB);
1048
FindKU(pre, MULADD, LAT, nbB, muB, nuB, &mfB, &kuB, &latB);
1049
FindLAT(pre, MAXLAT, nbB, MULADD, muB, nuB, kuB, &mfB, &latB);
1050
FindFetch('T', 'N', pre, nbB, nbB, nbB, muB, nuB, kuB, MULADD, latB,
1051
&FFetch, &ifetch, &nfetch);
1054
* Save NB we've found
1056
sprintf(fnam, "res/%cNB", pre);
1057
fp = fopen(fnam, "w");
1058
fprintf(fp, "%d\n%d\n", 1, nbB);
1061
* Save best case parameters we have found
1063
CreateFinalSumm(pre, MULADD, latB, nbB, muB, nuB, kuB, FFetch, ifetch,
1067
void FindNC_0(char ta, char tb, char pre, int N, int mb, int nb, int kb,
1068
int mu, int nu, int ku, int muladd, int lat,
1069
int FFetch, int ifetch, int nfetch)
1071
int kuB=ku, latB=lat, lat0=lat, kb0=kb;
1072
int i, j, k, csA=1, csB=1, csC=1, kmax;
1077
sprintf(fnam, "res/%cbest%c%c_%dx%dx%d", pre, ta, tb, mb, nb, kb);
1078
if (FileExists(fnam)) /* default already exists */
1080
GetInstLogFile(fnam, pre, &muladd, &lat, &nb, &mu, &nu, &ku,
1081
&FFetch, &ifetch, &nfetch, &mf);
1082
if (mf < 0.0) /* need to retime */
1084
mf = mmcase(NULL, pre, "JIK", ta, tb, nb, nb, nb,
1085
nb, nb, nb, 0, 0, 0, mu, nu, ku, muladd, lat, 1,
1086
1, 1, csC, FFetch, ifetch, nfetch);
1087
PutInstLogFile1(fnam, pre, muladd, lat, nb, mu, nu, ku,
1088
FFetch, ifetch, nfetch, mf);
1092
if (pre == 'c' || pre == 'z') csA = csB = csC = 2;
1097
if ((mb*nb)/lat != lat) lat0 = GetGoodLat(muladd, kb0, mu, nu, 1, lat);
1100
for (kmax=4; kmax*kmax < k; kmax += 4);
1101
if (pre == 'd' || pre == 's') kmax *= 2;
1102
if (kmax >= N) kmax = N;
1103
else if (kmax > N/2) kmax = N/2;
1104
if (kb == 0) kuB = k = Mmin(ku,kmax);
1107
* Find best non-cleanup case
1109
mf0 = mmcase(NULL, pre, "JIK", ta, tb, N, N, N, mb, nb, kb, 0, 0, 0,
1110
mu, nu, k, muladd, lat0, 1, csA, csB, csC,
1111
FFetch, ifetch, nfetch);
1114
* If kb is not known, try all available K unrollings; for large mu*nu*N
1115
* combinations, don't try maximal unrollings in order to avoid having
1116
* the compiler run out of space trying to optimize
1120
for (k=1; k < kmax; k += 4)
1123
if (k > N/2) k = kmax;
1126
i = GetGoodLat(muladd, kb0, mu, nu, j, lat);
1127
mf = mmcase(NULL, pre, "JIK", ta, tb, N, N, N, mb, nb, kb, 0, 0, 0,
1128
mu, nu, k, muladd, i, 1, csA, csB, csC,
1129
FFetch, ifetch, nfetch);
1139
* If K is known, try only the most common unrollings
1143
i = GetGoodLat(muladd, kb0, mu, nu, 1, lat);
1144
mf = mmcase(NULL, pre, "JIK", ta, tb, N, N, N, mb, nb, kb, 0, 0, 0,
1145
mu, nu, 1, muladd, i, 1, csA, csB, csC,
1146
FFetch, ifetch, nfetch);
1153
i = GetGoodLat(muladd, kb0, mu, nu, 4, lat);
1154
mf = mmcase(NULL, pre, "JIK", ta, tb, N, N, N, mb, nb, kb, 0, 0, 0,
1155
mu, nu, 4, muladd, i, 1, csA, csB, csC,
1156
FFetch, ifetch, nfetch);
1163
mf = mmcase(NULL, pre, "JIK", ta, tb, N, N, N, mb, nb, kb, 0, 0, 0,
1164
mu, nu, kb, muladd, lat, 1, csA, csB, csC,
1165
FFetch, ifetch, nfetch);
1174
* Try various latencies
1178
for (k=2; k < 9; k++)
1180
if (((mu*nu*i)/k)*k == mu*nu*i)
1182
mf = mmcase(NULL, pre, "JIK", ta, tb, N, N, N, mb, nb, kb, 0, 0, 0,
1183
mu, nu, kuB, muladd, k, 1, csA, csB, csC,
1184
FFetch, ifetch, nfetch);
1192
fprintf(stdout, "BEST for %c%c_%dx%dx%d: mflop=%.2f\n",
1193
ta, tb, mb, nb, kb, mf0);
1195
"pre=%c ta=%c tb=%c nb=%d mu=%d nu=%d ku=%d muladd=%d lat=%d\n",
1196
pre, ta, tb, nb, mu, nu, kuB, muladd, latB);
1197
sprintf(fnam, "res/%cbest%c%c_%dx%dx%d", pre, ta, tb, mb, nb, kb);
1198
fp = fopen(fnam, "w");
1200
PutInstLogFile(fp,muladd, latB, N, mu, nu, kuB, FFetch, ifetch, nfetch, mf0);
1204
void FindNC0(char ta, char tb, char pre, int nb, int mu, int nu, int ku,
1205
int muladd, int lat, int FFetch, int ifetch, int nfetch)
1207
FindNC_0(ta, tb, pre, nb, nb, nb, nb, mu, nu, ku, muladd, lat, FFetch,
1209
FindNC_0(ta, tb, pre, nb, 0, 0, nb, mu, nu, ku, muladd, lat, FFetch,
1211
FindNC_0(ta, tb, pre, nb, 0, 0, 0, mu, nu, ku, muladd, lat, FFetch,
1215
double NCcase(char pre, int nb, int mu, int nu, int ku, int ma, int lat,
1216
int ffetch, int ifetch, int nfetch)
1219
int ld=Mmax(1000,nb), cs=1;
1222
if (pre == 'c' || pre == 'z') cs = 2;
1225
sprintf(fnam, "res/%cNCNB%d_%d", pre, nb, ld);
1226
mf = mmcase(fnam, pre, "JIK", 'N', 'N', nb, nb, nb, nb, nb, nb,
1227
-ld, nb, nb, mu, nu, ku, ma, lat, 1, cs, cs, cs,
1228
ffetch, ifetch, nfetch);
1231
while (mf <= 0.0 && ld >= nb);
1236
int FindNoCopyNB(char pre, int nb, int mu, int nu, int ku0, int muladd, int lat,
1237
int FFetch, int ifetch, int nfetch)
1239
* See if a smaller blocking factor is needed for no-copy
1243
int i, ku, nbB=nb, csA=2, csB=2, csC=2;
1244
double mf, mfB, mf0;
1245
const double dmul = 1.02;
1247
sprintf(fnam, "res/%cNCNB", pre);
1248
if (!FileExists(fnam))
1250
mfB = NCcase(pre, nb, mu, nu, ku0, muladd, lat, FFetch, ifetch, nfetch);
1253
for (i=nb-4; i >= 16; i -= 4)
1256
mf = NCcase(pre, i, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1257
if (1.2*mf < mfB) break; /* stop search after 20% slowdown */
1258
if (nb%i == 0) mf *= dmul; /* give modest bonus to mults of nb */
1265
fp = fopen(fnam, "w");
1267
fprintf(fp, "%d\n", nbB);
1271
fp = fopen(fnam, "r");
1273
fscanf(fp, "%d\n", &nbB);
1276
fprintf(stdout, "\n%cNB = %d (%.2f), No copy %cNB = %d (%.2f)\n\n",
1277
pre, nb, mf0, pre, nbB, mfB);
1281
void FindNoCopy(char pre)
1284
int nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch, i;
1288
sprintf(ln, "res/%cMMRES", pre);
1289
GetInstLogFile(ln, pre, &muladd, &lat, &nb, &mu, &nu, &ku,
1290
&FFetch, &ifetch, &nfetch, &mf);
1291
sprintf(ln, "res/%cgMMRES", pre);
1292
GetInstLogFile(ln, pre, &muladd, &lat, &i, &mu, &nu, &ku,
1293
&FFetch, &ifetch, &nfetch, &mf);
1294
nb = FindNoCopyNB(pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1296
FindNC0('N', 'N', pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1297
FindNC0('N', 'T', pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1298
FindNC0('T', 'N', pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1299
FindNC0('T', 'T', pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1302
void FindCleanupK(char pre, int nb, int mu, int nu, int ku0, int muladd,
1303
int lat0, int FFetch, int ifetch, int nfetch)
1306
int genlat, genku, speclat, ku, kumax;
1308
double mf, genmf, specmf;
1313
for (kumax=4; kumax*kumax < i; kumax += 4);
1314
if (pre == 'd' || pre == 's') kumax *= 2;
1315
if (kumax >= nb) kumax = nb;
1316
else if (kumax > nb/2) kumax = nb/2;
1317
if (ifetch == -1 || nfetch == -1) { ifetch = mu+nu; nfetch = 1; }
1318
if (pre == 's' || pre == 'd')
1328
sprintf(fnam, "res/%cCleanK", pre);
1329
if (FileExists(fnam)) /* file already there */
1331
fp = fopen(fnam, "r");
1332
assert(fgets(fnam, 256, fp));
1333
assert(fscanf(fp, " %d", &kb) == 1);
1335
if (kb > 0 && kb != nb) TimeIt = 1;
1336
sprintf(fnam, "res/%cCleanK", pre);
1341
fp = fopen(fnam, "w");
1343
fprintf(fp, " KB MULADD LAT NB MU NU KU FFTCH IFTCH NFTCH GEN-MFLOP SPC-MFLOP\n");
1345
for (kb = nb; kb; kb--)
1348
sprintf(fnam, "res/%cKB_%d", pre, kb);
1349
speclat = GetGoodLat(muladd, kb, mu, nu, ku, lat0);
1350
specmf = mmcase(fnam, pre, "JIK", 'T', 'N', nb, nb, kb, 0, 0,
1351
kb, kb, kb, 0, mu, nu, ku, muladd, speclat, beta,
1352
1, 1, csC, FFetch, ifetch, nfetch);
1354
sprintf(fnam, "res/%cKB_0_%d", pre, ku);
1355
genlat = GetGoodLat(muladd, 8000, mu, nu, 1, lat0);
1356
genku = Mmin(kumax, ku);
1357
genmf = mmcase(fnam,pre, "JIK", 'T', 'N', nb, nb, kb, 0, 0, 0, 0, 0, 0,
1358
mu, nu, genku, muladd, genlat, beta, 1, 1, csC,
1359
FFetch, ifetch, nfetch);
1360
if (ku != 1) /* always try ku == 1 for general case */
1362
sprintf(fnam, "res/%cKB_0_1", pre);
1363
mf = mmcase(fnam,pre, "JIK", 'T', 'N', nb, nb, kb, 0, 0, 0, 0, 0, 0,
1364
mu, nu, 1, muladd, genlat, beta, 1, 1, csC,
1365
FFetch, ifetch, nfetch);
1366
if (mf > genmf) { genku = 1; genmf = mf; }
1368
if (1.01 * genmf > specmf) break;
1370
"%3d %6d %3d %3d %3d %3d %3d %5d %5d %5d %9.2lf %9.2lf\n",
1371
kb, muladd, speclat, nb, mu, nu, ku, FFetch, ifetch, nfetch,
1376
"%3d %6d %3d %3d %3d %3d %3d %5d %5d %5d %9.2lf %9.2lf\n",
1377
0, muladd, genlat, nb, mu, nu, genku, FFetch, ifetch, nfetch,
1383
FindCleanupMN(char pre, char cwh, int nb, int mu, int nu, int ku,
1384
int muladd, int lat, int FFetch, int ifetch, int nfetch)
1387
int nnb=nb, beta=1, csC=1, TimeIt=0;
1389
int mu0, nu0, ku0, ma0, lat0, ff0, if0, nf0;
1393
if (cwh == 'M') Mb = 0;
1395
if (ifetch == -1 || nfetch == -1) { ifetch = mu+nu; nfetch = 1; }
1396
if (pre == 'c' || pre == 'z')
1401
sprintf(fnam, "res/%cClean%c", pre, cwh);
1402
if (FileExists(fnam))
1404
GetInstLogFile(fnam, pre, &ma0, &lat0, &nnb, &mu0, &nu0, &ku0, &ff0,
1406
if (nnb != nb || mf <= 0.0) TimeIt = 1;
1411
mf = mmcase(NULL, pre, "JIK", 'T', 'N', nb, nb, nb, Mb, Nb, nb, nb, nb, 0,
1412
mu, nu, ku, muladd, lat, beta, 1, 1, csC,
1413
FFetch, ifetch, nfetch);
1414
fp = fopen(fnam, "w");
1416
PutInstLogFile(fp, muladd, lat, nb, mu, nu, ku,
1417
FFetch, ifetch, nfetch, mf);
1422
typedef struct CleanCase CLEANCASE;
1427
int imult, icase, fixed, nb, nb0, nb1, nb2;
1430
void PrintCleanCases(CLEANCASE *cp)
1432
for (; cp; cp = cp->next)
1435
"imult=%d, icase=%d, fixed=%d, nb=%d, %d,%d,%d, mflop=%.2f\n",
1436
cp->imult, cp->icase, cp->fixed, cp->nb, cp->nb0, cp->nb1,
1437
cp->nb2, cp->mflop);
1439
fprintf(stdout, "\n");
1441
CLEANCASE *GetUserCleanup(char pre, int nb, enum CW which)
1443
* Read in user clean file
1447
CLEANCASE *cp, *cp0;
1449
char cwh[3] = {'M', 'N', 'K'};
1452
sprintf(ln, "res/%cuClean%c", pre, cwh[which]);
1453
if (!FileExists(ln))
1455
sprintf(ln, "make RunUMMClean pre=%c nb=%d which=%c\n",
1456
pre, nb, tolower(cwh[which]));
1457
assert(system(ln) == 0);
1458
sprintf(ln, "res/%cuClean%c", pre, cwh[which]);
1460
fp = fopen(ln, "r");
1462
assert(fgets(ln, 128, fp));
1463
assert(fgets(ln, 128, fp));
1464
sscanf(ln, " %d", &n);
1465
if (n < 1) return(NULL);
1466
cp0 = cp = malloc(sizeof(CLEANCASE));
1468
for (i=0; i < n; i++)
1470
assert(fgets(ln, 128, fp));
1471
sscanf(ln, " %d %d %d %d %d %d %d %lf", &cp->imult, &cp->icase,&cp->fixed,
1472
&cp->nb, &cp->nb0, &cp->nb1, &cp->nb2, &cp->mflop);
1475
cp->next = malloc(sizeof(CLEANCASE));
1479
else cp->next = NULL;
1485
int *GetKBs(char pre, int nb)
1487
* returns nb+1 length vector, KB[i] is KB & lda of KB Cleanup; 0 means var
1494
sprintf(ln, "res/%cCleanK", pre);
1495
fp = fopen(ln, "r");
1497
assert(fgets(ln, 128, fp)); /* skip titles */
1498
KB = malloc((nb+1)*sizeof(int));
1502
if (fgets(ln, 128, fp)) { assert(sscanf(ln, " %d", KB+k)==1); }
1505
for(; k; k--) KB[k] = 0;
1509
double RebuttUserKCase(char pre, int nb, int mu, int nu, int ku, int ma,
1510
int lat, int FF, int iff, int nf, int *KBs, int *NBs)
1513
int K, csC, i, ld, iku, ilat;
1515
if (pre == 'c' || pre == 'z') csC = 2;
1518
for(i=0; i < 3 && NBs[i]; i++)
1523
ilat = GetGoodLat(ma, K, mu, nu, 1, lat);
1524
mf = mmclean(pre, CleanK, "JIK", 'T', 'N', nb, nb, K,
1525
nb, nb, ld, ld, ld, 0,
1526
mu, nu, iku, ma, ilat, 1, 1, 1, csC, FF, iff, nf);
1527
fprintf(stdout, " CleanK: %dx%dx%d : %.2f\n", nb, nb, K, mf);
1534
CLEANCASE *RebuttUserKClean(char pre, int nb, int mu, int nu, int ku,
1535
int muladd, int lat, int FF, int iff, int nf)
1539
CLEANCASE *cp0, *cp;
1541
KBs = GetKBs(pre, nb);
1543
cp0 = GetUserCleanup(pre, nb, CleanK);
1544
for (cp=cp0; cp; cp = cp->next)
1546
NB[0] = cp->nb0; NB[1] = cp->nb1; NB[2] = cp->nb2;
1547
gmf = RebuttUserKCase(pre, nb, mu, nu, ku, muladd, lat, FF, iff, nf,
1549
fprintf(stdout, " pKBmm_%d: user=%.2f generated=%.2f\n",
1550
cp->imult, cp->mflop, gmf);
1551
if (1.02*gmf > cp->mflop) cp->icase = -1;
1557
double RebuttUserCase(char pre, int nb, enum CW which, int mu, int nu, int ku,
1558
int ma, int lat, int FF, int iff, int nf, int *NBs)
1561
int NB[3], M[3], NU[3], csC, i, j, NUmax, ilat;
1562
char cwh[3] = {'M', 'N', 'K'};
1564
if (pre == 'c' || pre == 'z') csC = 2;
1566
NB[0] = NB[1] = NB[2] = M[0] = M[1] = M[2] = nb;
1568
NU[0] = mu; NU[1] = nu; NU[2] = ku;
1571
for(i=0; i < 3 && NBs[i]; i++)
1573
j = M[which] = NBs[i];
1574
NU[which] = Mmin(j, NUmax);
1575
ilat = GetGoodLat(ma, M[2], NU[0], NU[1], NU[2], lat);
1576
mf = mmclean(pre, which, "JIK", 'T', 'N', M[0], M[1], M[2],
1577
NB[0], NB[1], NB[2], nb, nb, 0,
1578
mu, nu, ku, ma, lat, 1, 1, 1, csC, FF, iff, nf);
1579
fprintf(stdout, " Clean%c: %dx%dx%d : %.2f\n", cwh[which],
1580
M[0], M[1], M[2], mf);
1587
CLEANCASE *RebuttUserCases(char pre, int nb, enum CW which,
1588
int mu, int nu, int ku, int muladd, int lat,
1589
int FF, int iff, int nf)
1593
CLEANCASE *cp0, *cp;
1594
char cwh[3] = {'M', 'N', 'K'};
1596
if (which == CleanK)
1597
return(RebuttUserKClean(pre, nb, mu, nu, ku, muladd, lat, FF, iff, nf));
1598
cp0 = GetUserCleanup(pre, nb, which);
1599
for (cp=cp0; cp; cp = cp->next)
1601
NB[0] = cp->nb0; NB[1] = cp->nb1; NB[2] = cp->nb2;
1602
gmf = RebuttUserCase(pre, nb, which, mu, nu, ku, muladd, lat,
1604
fprintf(stdout, " p%cBmm_%d: user=%.2f generated=%.2f\n",
1605
cwh[which], cp->imult, cp->mflop, gmf);
1606
if (1.02*gmf > cp->mflop) cp->icase = -1;
1611
CLEANCASE *WeedOutLosers(CLEANCASE *cp0)
1613
CLEANCASE *cp, *cp1;
1615
while(cp0 && cp0->icase == -1)
1621
if (cp0 && cp0->next)
1623
for (cp=cp0; cp->next; cp = cp->next)
1626
if (cp1->icase == -1)
1628
cp->next = cp1->next;
1630
if (cp->next == NULL) break;
1637
void KillAllCleans(CLEANCASE *cp)
1648
int NumUserCleans(CLEANCASE *cp)
1651
for (i=0; cp; cp = cp->next) if (cp->icase != -1) i++;
1655
void FindUserCleanup(char pre, int nb, enum CW which, int mu, int nu, int ku,
1656
int ma, int lat, int FF, int iff, int nf)
1658
CLEANCASE *cp, *cp0;
1661
char cwh[3] = {'M', 'N', 'K'};
1663
sprintf(ln, "res/%cuClean%cF", pre, cwh[which]);
1664
if (FileExists(ln)) return;/* already done */
1665
cp = RebuttUserCases(pre, nb, which, mu, nu, ku, ma, lat, FF, iff, nf);
1666
cp = WeedOutLosers(cp);
1667
fp = fopen (ln, "w");
1669
fprintf(fp, "MULT ICASE FIXED NB\n");
1670
fprintf(fp, "%d\n", NumUserCleans(cp));
1671
for(cp0=cp; cp; cp = cp->next)
1672
fprintf(fp, "%4d %5d %5d %3d\n",
1673
cp->imult, cp->icase, cp->fixed, cp->nb);
1678
void FindAllUserClean(char pre, int nb, int mu, int nu, int ku,
1679
int ma, int lat, int FF, int iff, int nf)
1681
FindUserCleanup(pre, nb, CleanM, mu, nu, ku, ma, lat, FF, iff, nf);
1682
FindUserCleanup(pre, nb, CleanN, mu, nu, ku, ma, lat, FF, iff, nf);
1683
FindUserCleanup(pre, nb, CleanK, mu, nu, ku, ma, lat, FF, iff, nf);
1686
void FindAllUserClean0(char pre)
1689
int nb, mu, nu, ku, muladd, lat, FF, iff, nf;
1692
sprintf(ln, "res/%cMMRES", pre);
1693
GetInstLogFile(ln, pre, &muladd, &lat, &nb, &mu, &nu, &ku, &FF,
1695
FindAllUserClean(pre, nb, mu, nu, ku, muladd, lat, FF, iff, nf);
1698
void FindCleanup(char pre, int nb, int mu, int nu, int ku, int muladd, int lat,
1699
int FFetch, int ifetch, int nfetch)
1701
FindCleanupMN(pre, 'M', nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1702
FindCleanupMN(pre, 'N', nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1703
FindCleanupK(pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1704
FindAllUserClean(pre, nb, mu, nu, ku, muladd, lat, FFetch, ifetch, nfetch);
1706
void FindAllClean(char pre)
1709
int nb, mu, nu, ku, muladd, lat, FF, iff, nf;
1712
sprintf(ln, "res/%cMMRES", pre);
1713
fprintf(stderr, "\n\nSTARTING CLEANUP SEARCH\n\n");
1714
GetInstLogFile(ln, pre, &muladd, &lat, &nb, &mu, &nu, &ku, &FF,
1716
FindCleanup(pre, nb, mu, nu, ku, muladd, lat, FF, iff, nf);
1717
fprintf(stderr, "\n\nDONE CLEANUP SEARCH\n\n");
1720
int GetNumRegs0(char pre, int muladd, int nb, int lat,
1721
int nr0, int nrN, int incN)
1723
int n, nr, i, imax, nu, mu;
1724
double *rates, mf, mmf=0.0;
1730
if (incN == -2) i <<= 1;
1734
rates = malloc(n * sizeof(double));
1736
for (i=0; i < n; i++)
1738
FindMUNU(muladd, lat, nr, &mu, &nu);
1739
mf = rates[i] = mms_case(pre, muladd, nb, mu, nu, nb, lat);
1742
if (mf > mmf) mmf = mf;
1745
if (incN == -2) nr <<= 1;
1750
for (i=imax+1; i < n && 1.10*rates[i] < mmf; i++);
1752
else if (incN == -2) i = (nr0 << imax);
1753
else i = nr0 + imax*incN;
1760
int RefineNumRegs(char pre, int muladd, int nb, int lat, int nr0, int nrN)
1762
* recursively halves gap until true number is found
1767
i = (nrN - nr0) / 2;
1768
if (i < 1) return(nr0);
1769
nr = GetNumRegs0(pre, muladd, nb, lat, nr0, nr0+i, i);
1770
if (nr != nr0) /* positive or no difference in two points, so go larger */
1772
else /* difference, point is between */
1774
return(RefineNumRegs(pre, muladd, nb, lat, nr0, nrN));
1777
int GetNumRegs00(char pre, int muladd, int nb, int lat, int maxnr)
1781
fprintf(stderr, "\n\nFINDING ROUGHLY HOW MANY REGISTERS TO USE:\n\n");
1783
nr = GetNumRegs0(pre, muladd, nb, lat, 4, maxnr, -2);
1785
* Refine number of regs
1787
if (nr != -1) i = RefineNumRegs(pre, muladd, nb, lat, nr, nr<<1);
1789
fprintf(stderr, "\n\nAPPROXIMATE NUMBER OF USABLE REGISTERS=%d\n\n", i);
1793
int GetNumRegs(char pre, int MaxL1, int maxnreg)
1796
int nreg, muladd, lat;
1798
void GetMulAdd(char pre, int *MULADD, int *lat);
1800
if (pre == 'z') pre = 'd';
1801
else if (pre == 'c') pre = 's';
1803
sprintf(nam, "res/%cnreg", pre);
1804
if (!FileExists(nam))
1806
GetMulAdd(pre, &muladd, &lat);
1807
nreg = GetNumRegs00(pre, muladd, GetSafeNB(pre, MaxL1), lat, maxnreg);
1808
fp = fopen(nam, "w");
1809
fprintf(fp, "%d\n", nreg);
1813
fp = fopen(nam, "r");
1814
fscanf(fp, " %d", &nreg);
1817
fprintf(stdout, "mmnreg = %d\n", nreg);
1821
void RunTimes(char pre)
1823
const char TR[2] = {'N', 'T'};
1824
char fnam[128], fnam2[128], ln[128];
1825
const int COMPLEX = (pre == 'c' || pre == 'z');
1826
int csC = (COMPLEX ? 2 : 1);
1827
int NB, muladd, lat, nb, mu, nu, ku, ffetch, ifetch, nfetch, ia, ib;
1828
int uma, ulat, unb=0, umu, unu, uku, uff, uif, unf;
1833
fprintf(stderr, "\n\nStart RunTimes\n");
1834
sprintf(fnam, "res/%cgMMRES", pre);
1835
fp = fopen(fnam, "r");
1838
GetInstLogLine(fp, &muladd, &lat, &nb, &mu, &nu, &ku, &ffetch,
1839
&ifetch, &nfetch, &mf);
1841
if (mf < 0.0) /* need to retime */
1843
sprintf(ln, "make RunUMMSearch pre=%c nb=-1\n", pre);
1844
assert(system(ln) == 0);
1845
mf = mmcase(NULL, pre, "JIK", 'T', 'N', nb, nb, nb, nb, nb, nb,
1846
nb, nb, 0, mu, nu, ku, muladd, lat, 1, 1, 1, csC,
1847
ffetch, ifetch, nfetch);
1848
PutInstLogFile1(fnam, pre, muladd, lat, nb, mu, nu, ku,
1849
ffetch, ifetch, nfetch, mf);
1850
sprintf(fnam, "res/%cMMRES", pre);
1851
if (!FileExists(fnam))
1853
sprintf(ln, "make res/%cMMRES pre=%c\n", pre, pre);
1854
assert(system(ln) == 0);
1858
GetInstLogFile(fnam, pre, &muladd, &lat, &nb, &mu, &nu, &ku, &ffetch,
1859
&ifetch, &nfetch, &mf);
1860
mf = mmcase(NULL, pre, "JIK", 'T', 'N', nb, nb, nb, nb, nb, nb,
1861
nb, nb, 0, mu, nu, ku, muladd, lat, 1, 1, 1, csC,
1862
ffetch, ifetch, nfetch);
1863
CreateFinalSumm(pre, muladd, lat, nb, mu, nu, ku, ffetch, ifetch,
1867
sprintf(fnam, "res/%cNCNB", pre);
1868
if (!FileExists(fnam)) return;
1869
fp = fopen(fnam, "r");
1871
assert(fscanf(fp, " %d", &NB) == 1);
1874
for (ia=0; ia < 2; ia++)
1876
for (ib=0; ib < 2; ib++)
1878
sprintf(fnam, "res/%cbest%c%c_%dx%dx%d", pre, TR[ia], TR[ib],
1880
if (FileExists(fnam))
1882
GetInstLogFile(fnam, pre, &muladd, &lat, &nb, &mu, &nu, &ku,
1883
&ffetch, &ifetch, &nfetch, &mf);
1884
if (mf < 0.0) /* need to retime */
1886
mf = mmcase(NULL, pre, "JIK", TR[ia], TR[ib], nb, nb, nb,
1887
nb, nb, nb, 0, 0, 0, mu, nu, ku, muladd, lat, 1,
1888
1, 1, csC, ffetch, ifetch, nfetch);
1889
PutInstLogFile1(fnam, pre, muladd, lat, nb, mu, nu, ku,
1890
ffetch, ifetch, nfetch, mf);
1893
sprintf(fnam, "res/%cbest%c%c_%dx%dx%d", pre, TR[ia], TR[ib],
1895
if (FileExists(fnam))
1897
GetInstLogFile(fnam, pre, &muladd, &lat, &nb, &mu, &nu, &ku,
1898
&ffetch, &ifetch, &nfetch, &mf);
1899
if (mf < 0.0) /* need to retime */
1901
mf = mmcase(NULL, pre, "JIK", TR[ia], TR[ib], nb, nb, nb,
1902
0, 0, nb, 0, 0, 0, mu, nu, ku, muladd, lat, 1,
1903
1, 1, csC, ffetch, ifetch, nfetch);
1904
PutInstLogFile1(fnam, pre, muladd, lat, nb, mu, nu, ku,
1905
ffetch, ifetch, nfetch, mf);
1908
sprintf(fnam, "res/%cbest%c%c_%dx%dx%d", pre, TR[ia], TR[ib],
1910
if (FileExists(fnam))
1912
GetInstLogFile(fnam, pre, &muladd, &lat, &nb, &mu, &nu, &ku,
1913
&ffetch, &ifetch, &nfetch, &mf);
1914
if (mf < 0.0) /* need to retime */
1916
mf = mmcase(NULL, pre, "JIK", TR[ia], TR[ib], nb, nb, nb,
1917
0, 0, 0, 0, 0, 0, mu, nu, ku, muladd, lat, 1,
1918
1, 1, csC, ffetch, ifetch, nfetch);
1919
PutInstLogFile1(fnam, pre, muladd, lat, nb, mu, nu, ku,
1920
ffetch, ifetch, nfetch, mf);
1925
fprintf(stderr, "\nDone RunTimes\n\n");
1928
void cmmsearch(char pre, int MULADD, int Fku, int nNBs, int *NBs, int nreg,
1931
* With all other parameters set by real search, find good complex NB
1934
char *typ, ln[64], upre;
1935
int i, k, mnb=0, muladd, lat, nb, mu, nu, ku, ffetch, ifetch, nfetch;
1937
double mf, mmf=0.0, umf;
1940
if (pre == 'c') upre = 's';
1943
sprintf(ln, "res/%cMMRES", pre);
1944
if (FileExists(ln)) /* already have result */
1946
GetInstLogFile(ln, pre, &muladd, &lat, &nb, &mu, &nu, &ku, &ffetch,
1947
&ifetch, &nfetch, &mf);
1950
mf = mmcase(NULL, pre, "JIK", 'T', 'N', nb, nb, nb, nb, nb, nb,
1951
nb, nb, 0, mu, nu, ku, muladd, lat, 1, 1, 1, 2,
1952
ffetch, ifetch, nfetch);
1953
gmmsearch(upre, muladd, Fku, nNBs, NBs, nreg, lat, Fnb);
1954
nb = CheckUser(pre, 1.02, mf, nb, NULL);
1955
CreateFinalSumm(pre, muladd, lat, nb, mu, nu, ku, ffetch, ifetch,
1958
sprintf(ln, "res/%cNB", pre);
1959
fp = fopen(ln, "w");
1960
fprintf(fp, "%d\n%d\n", 1, nb);
1964
gmmsearch(upre, MULADD, Fku, nNBs, NBs, nreg, LAT, Fnb);
1965
sprintf(ln, "res/%cgMMRES", upre);
1966
GetInstLogFile(ln, upre, &muladd, &lat, &nb, &mu, &nu, &ku, &ffetch,
1967
&ifetch, &nfetch, &mf);
1968
KUisNB = (nb <= ku);
1970
for (i=0; i < nNBs; i++)
1972
if (KUisNB) k = NBs[i];
1973
else k = Mmin(ku, NBs[i]);
1974
mf = mms_case(pre, muladd, NBs[i], mu, nu, k, lat);
1981
if (KUisNB) ku = mnb;
1982
else ku = Mmin(ku, mnb);
1983
sprintf(ln, "res/%cgMMRES", pre);
1984
PutInstLogFile1(ln, pre, muladd, lat, mnb, mu, nu, ku,
1985
ffetch, ifetch, nfetch, mmf);
1987
nb = CheckUser(pre, 1.02, mmf, mnb, &umf);
1990
if (ku == mnb) ku = nb;
1992
if (nb % mu || nb % nu)
1994
NO1D = GetNO1D(upre, nreg, nb, MULADD, LAT);
1995
searchmu_nu(upre, nb, nreg, Fku, MULADD, LAT, NO1D,
1996
&mmf, &nb, &mu, &nu, &ku, &lat);
1998
FindKU(upre, muladd, LAT, nb, mu, nu, &mmf, &ku, &lat);
1999
FindLAT(upre, MAXLAT, nb, muladd, mu, nu, ku, &mmf, &lat);
2000
FindFetch('T', 'N', upre, nb, nb, nb, mu, nu, ku, muladd, lat,
2001
&ffetch, &ifetch, &nfetch);
2004
* Save NB we've found
2006
sprintf(ln, "res/%cNB", pre);
2007
fp = fopen(ln, "w");
2008
fprintf(fp, "%d\n%d\n", 1, mnb);
2010
CreateFinalSumm(pre, muladd, lat, mnb, mu, nu, ku,
2011
ffetch, ifetch, nfetch, mf);
2013
void GetMulAdd(char pre, int *MULADD, int *lat)
2015
char nam[64], ln[128];
2018
sprintf(nam, "res/%cMULADD", pre);
2019
if (!FileExists(nam))
2021
sprintf(ln, "make RunMulAdd pre=%c maxlat=%d mflop=%d\n", pre, 6, 200);
2022
assert(system(ln) == 0);
2024
fp = fopen(nam, "r");
2026
fscanf(fp, "%d", MULADD);
2027
fscanf(fp, "%d", lat);
2031
int GetNumRegsMM(char pre, int MaxL1Size, int MAX_NREG)
2034
nreg = GetNumRegs(pre, MaxL1Size, MAX_NREG);
2038
"\nUNABLE TO FIND NUMBER OF REGISTERS, ASSUMMING 32.\n\n");
2043
fprintf(stderr, "FOUND NUMBER OF REGISTERS TO BE %d; THIS WOULD TAKE TOO LONG TO SEARCH, SO SETTING TO 128.\n", nreg);
2049
"FOUND # OF REGISTERS TO BE %d; TRYING 8 FOR SAFETY.\n", nreg);
2055
"FOUND # OF REGISTERS TO BE %d; TRYING 16 FOR SAFETY.\n", nreg);
2061
void GetMMRES(char pre, int nreg, int MaxL1Size, int ForceLat)
2065
int i, nNBs, muladd, lat;
2069
if (pre == 'c') upre = 's';
2070
else if (pre == 'z') upre = 'd';
2073
sprintf(ln, "res/%cNB", pre);
2074
if (!FileExists(ln)) findNBs(upre, ln, MaxL1Size);
2075
assert( (fp = fopen(ln, "r")) != NULL );
2076
fscanf(fp, "%d", &nNBs);
2077
fprintf(stdout, "\nNB's to try: ");
2078
NBs = malloc(nNBs*sizeof(int));
2079
for (i=0; i != nNBs; i++)
2081
fscanf(fp, "%d", NBs+i);
2082
fprintf(stdout, "%d ",NBs[i]);
2084
fprintf(stdout, "\n\n");
2086
GetMulAdd(upre, &muladd, &lat);
2087
if (ForceLat != -1) lat = ForceLat;
2090
if (pre == 'c' || pre == 'z')
2091
cmmsearch(pre, muladd, 0, nNBs, NBs, nreg, lat, 0);
2092
else mmsearch(pre, muladd, 0, nNBs, NBs, nreg, lat, 0);
2096
#define ATL_MAXNREG 64
2098
main(int nargs, char *args[])
2100
char prec, upre, lang;
2101
int MULADD, MaxL1Size, ForceLat, i, nreg, ROUT, FRC;
2102
int muladd, lat, nb, mu, nu, ku, ffetch, ifetch, nfetch;
2106
char ln[128], auth[65];
2108
GetSettings(nargs, args, &prec, &lang, &ku, &ForceLat, &FRC, &nreg,
2110
assert(FRC == 0 && ku == 0); /* obsolete flags */
2113
if (prec == 'z') upre = 'd';
2114
else if (prec == 'c') upre = 's';
2121
else if (ROUT == -4)
2123
FindAllUserClean0(prec);
2126
else if (ROUT == -5) /* produce ATL_mmnreg for sysinfo */
2128
GetNumRegs(prec, MaxL1Size, ATL_MAXNREG);
2131
fprintf(stderr, "Precision='%c', FORCE=%d, LAT=%d, nreg=%d, MaxL1=%d\n",
2132
prec, FRC, ForceLat, nreg, MaxL1Size);
2134
if (nreg == -1) nreg = GetNumRegsMM(upre, MaxL1Size, ATL_MAXNREG);
2135
GetMMRES(prec, nreg, MaxL1Size, ForceLat);
2139
sprintf(ln, "res/%cMMRES", prec);
2140
fp = fopen(ln, "r");
2142
assert( fgets(ln, 128, fp) != NULL );
2143
GetInstLogLine(fp, &muladd, &lat, &nb, &mu, &nu, &ku, &ffetch,
2144
&ifetch, &nfetch, &mf);
2145
assert( fgets(ln, 128, fp) != NULL );
2146
assert(fscanf(fp, " %d %d %lf \"%[^\"]\" \"%[^\"]", &icase, &unb, &umf,
2151
fprintf(stdout, "\n\nFor this run, the best parameters found were MULADD=%d, lat=%d, NB=%d, MU=%d, NU=%d, KU=%d\n",
2152
muladd, lat, nb, mu, nu, ku);
2158
"\n\nFor this run, the best case found was NB=%d user case %d\n",
2160
fprintf(stdout, "written by %s.\n", auth);
2162
fprintf(stdout, "This gave a performance = %f MFLOP.\n", mf);
2164
"The necessary files have been created. If you are happy with\n");
2166
"the above mflops for your system, type 'make %cinstall'.\n", prec);
2168
"Otherwise, try the xmmsearch with different parameters, or hand\n");
2169
fprintf(stdout, "tweak the code.\n");