2
* Automatically Tuned Linear Algebra Software v3.10.1
3
* Copyright (C) 2009 Chad Zalkin
5
* Code contributers : Chad Zalkin, R. Clint Whaley
7
* Redistribution and use in source and binary forms, with or without
8
* modification, are permitted provided that the following conditions
10
* 1. Redistributions of source code must retain the above copyright
11
* notice, this list of conditions and the following disclaimer.
12
* 2. Redistributions in binary form must reproduce the above copyright
13
* notice, this list of conditions, and the following disclaimer in the
14
* documentation and/or other materials provided with the distribution.
15
* 3. The name of the ATLAS group or the names of its contributers may
16
* not be used to endorse or promote products derived from this
17
* software without specific written permission.
19
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
21
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
22
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
23
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29
* POSSIBILITY OF SUCH DAMAGE.
51
#define COMPLEX_SINGLE 2
52
#define COMPLEX_DOUBLE 3
63
#define ALIGN_NALIGN 1
64
#define NALIGN_ALIGN 2
67
#define CACHE_LINE_SIZE 64
71
BOOL intrinSrcDest = TRUE;
72
BOOL useVoidPointersForC = FALSE;
73
BOOL useVoidPointersForA = FALSE;
74
BOOL useVoidPointersForB = FALSE;
77
/* Prefetch Options */
102
ATL_INT vector_stride;
103
ATL_INT vector_length_bytes;
109
/* Load vectors from A and B */
112
/* Load/Store a single element */
116
/* Store an aligned/unaligned vector */
129
BOOL verifyUnrollings;
130
BOOL treatLoadsAsFloat;
131
BOOL treatStoresAsFloat;
134
BOOL constantFolding;
135
FILE* outputLocation;
138
/* Load options from the command line */
139
void loadOptions( int argc, char **argv );
140
static void loadInt( char* tag, ATL_INT *value, int argc, char** argv );
141
static void loadFloat( char* tag, float *value, int argc, char** argv );
142
static void loadBool( char* tag, BOOL *value, int argc, char** argv );
143
static char* loadString( char* tag, int argc, char** argv );
144
static int requestHelp( int argc, char** argv );
145
static void convertElementType( char specifier );
146
static void setOutputLocation( char* file );
147
static void printHelp();
148
int numArgsProcessed = 0;
151
void getNBString( char* out );
152
void getMBString( char* out );
153
void getKBString( char* out );
155
/* These functions print the sections of the kernel */
156
void printMainLoops( int alignmentOfC, char *name );
157
void printILoop( int alignmentOfC, BOOL prefetchA, BOOL prefetchB );
159
void printPreamble();
161
void printBody( BOOL simple );
163
/* These functions print the unrollings of the k loop */
164
void printAllKUnrollings( BOOL prefetchA, BOOL prefetchB );
166
void k_partialUnrolling( ATL_INT offset );
167
void k_unrollingFullStep( ATL_INT delta );
168
void printKRolled( BOOL prefetchA, BOOL prefetchB );
169
void printPartiallyUnrolledK( BOOL prefetchA, BOOL prefetchB );
170
void printFullyUnrolledK( BOOL prefetchA, BOOL prefetchB );
172
/* This compresses the vectors back to scalars */
173
void printScalarCompression();
174
void printScalarCompressionSingle();
175
void storeResults( int alignmentOfC );
178
/* These functions print the code to the console *
179
 * and manage the indentation */
180
void emit( const char *fmt, ...);
181
void emitCat( const char *fmt, ...);
182
void indent( int delta );
186
* Keeps track of the number of indents in the emitted code.
187
* This allows the system to properly nest braces in the output.
192
* Create some shorthand for the load/store instructions
193
* These are aliased to strings because one option adds a
194
* typecast to the load/store instructions. It wouldn't do
195
* to have to test that option all the time....
203
/* Options Variables */
212
int argc, /* Number of command line args */
213
char** argv /* Array of command line args */
216
* Prints an implementation of GEMM given the parameters on the
220
loadOptions( argc, argv ); /* Read options from the command line */
222
if( unroll.nb == USE_KB )
224
assert( unroll.kb % unroll.a == 0 );
225
assert( unroll.kb % unroll.b == 0 );
227
else if( unroll.nb != PARAMETER )
229
assert( unroll.nb % unroll.a == 0 );
230
assert( unroll.nb % unroll.b == 0 );
233
printPreamble(); /* Emit includes, defines, and variables */
237
* Emit three cases: Beta = 0, Beta = 1, Beta != 0,1
238
 * This allows conditional compilation of complex calculations
239
* which need all three cases.
241
* Each case includes a call to printIntro() which emits the
242
* function header, and a call to printBody() which emits the
243
* alignment cases as needed.
247
printBody( options.cAlignment );
249
assert( tabwidth == 0 );
258
 * Print setup information such as CPP includes, data type definitions,
259
* and defines used to name constants.
262
emit( "#define ATL_INT int\n" );
264
/* Store some strings to represent the NB and MB constants */
270
emit("#include <stdio.h>\n" );
271
emit("#include <stdint.h>\n" );
272
emit("#include <pmmintrin.h>\n" );
276
/* Emit some defines, so that the code is readable */
277
emit( "#define I_UNROLL %d\n", unroll.a );
278
emit( "#define J_UNROLL %d\n", unroll.b );
281
/* Setup the prefetch options */
282
emit("/* Is prefetched data written or just read? */\n");
283
emit( "#define PF_READONLY 0\n" );
284
emit( "#define PF_READWRITE 1\n" );
285
emit( "#define PF_NO_REUSE 0\n" );
287
emit("\n/* Default temporality of cache prefetch (1-3) */\n");
288
emit( "#define PF_DEF 1\n" );
289
emit( "#define CACHE_LINE_SIZE %d\n", CACHE_LINE_SIZE );
291
if( options.treatLoadsAsFloat )
293
emit( "#define MMCAST( a ) (float*)(a)\n" );
297
emit( "#define MMCAST( a ) (a)\n" );
300
if( options.treatStoresAsFloat )
302
emit( "#define MMCASTStore( a ) (float*)(a)\n" );
303
emit( "#define MMCASTStoreintrin( a ) (__m128)(a)\n" );
307
emit( "#define MMCASTStore( a ) (a)\n" );
308
emit( "#define MMCASTStoreintrin( a ) (a)\n" );
315
* Emit code that deduces constants used in this configuration,
316
* data types, and the function prototype
324
emit( "#define TYPE %s\n", element.type_name );
325
emit("void ATL_USERMM( const ATL_INT M, const ATL_INT N, const ATL_INT K,\n");
326
emit(" const TYPE alpha, const TYPE *A, const ATL_INT lda,\n");
327
emit(" const TYPE *B, const ATL_INT ldb,\n");
328
emit(" const TYPE beta, TYPE *C, const ATL_INT ldc )\n");
333
emit("register ATL_INT i, j, k;\n");
337
/* Create variables for each of the vector registers */
338
emit( "/* Vector registers to hold the elements of C */\n" );
339
for( b=0; b<unroll.b; ++b )
341
emit( "%s c%d_0", element.intrinsic, b );
342
for( a=1; a<unroll.a; ++a )
344
emitCat( ", c%d_%d", b, a );
350
* If beta must be applied, create some registers for the
353
if( options.beta != BETA0 )
355
emit( "/* Vector register to hold C*beta */\n" );
356
for( b=0; b<unroll.b; ++b )
358
ATL_INT remaining = unroll.a;
360
emit( "%s ", element.intrinsic );
362
/* Create registers that include the entire stride */
363
for( ; remaining > element.vector_stride;
364
remaining -= element.vector_stride )
366
emitCat( "bc%d_%d", b, a );
367
a += element.vector_stride;
368
if( remaining != element.vector_stride )
373
/* Create registers for elements that do not fit in the stride */
374
for( ; remaining > 0; --remaining )
376
emitCat( "bc%d_%d", b, a );
388
emit( "/* Temporary vector registers for use in inner loop */\n" );
389
emit("%s temp; \n", element.intrinsic );
391
if( element.type == COMPLEX_DOUBLE || element.type == COMPLEX_SINGLE )
393
for( x=0; x<element.vector_stride; ++x )
395
emit("%s temp%d; \n", element.intrinsic, x );
400
if( options.verifyUnrollings == TRUE )
402
emit("assert(M%%%d==0);\n", unroll.a );
403
emit("assert(N%%%d==0);\n", unroll.b );
406
/* Start prefetching from B */
409
emit("__builtin_prefetch( B, PF_READONLY, PF_DEF );\n");
412
/* Load the beta factor so it will be ready to apply later */
413
if( options.beta == BETAX || options.beta == BETAN1 )
415
if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
417
emit("const %s betaV = _mm_set1_ps( beta ); \n", element.intrinsic );
419
emit("const %s betaV = _mm_set1_pd( beta ); \n", element.intrinsic );
423
emit("/* Pointer adjustments */ \n");
425
if( options.ldb == PARAMETER )
427
emit("register const ATL_INT ldb_bytes = ldb << %d;\n", element.shift );
431
emit("register const ATL_INT ldb_bytes3 = ldb_bytes*3;\n" );
437
if( options.lda == PARAMETER )
439
emit("register const ATL_INT lda_bytes = lda << %d;\n",
441
emit("register const ATL_INT lda_bytes3 = lda_bytes * 3;\n");
445
/* Since complex matrices are strided by 2, ldc is off by a factor of 2 */
446
if( element.type == COMPLEX_SINGLE || element.type == COMPLEX_DOUBLE )
448
switch( options.ldc )
451
if( useVoidPointersForC )
452
emit("register const ATL_INT ldc_bytes = ldc << %d;\n",
455
emit("register const ATL_INT ldc_bytes = 2*ldc;\n");
459
if( useVoidPointersForC )
460
emit( "register const ATL_INT ldc_bytes = 2*KB%d;\n",
463
emit( "register const ATL_INT ldc_bytes = 2*KB;\n" );
467
if( useVoidPointersForC )
468
emit( "register const ATL_INT ldc_bytes = 2*%d*%d;\n",
469
element.size, options.ldc );
471
emit( "register const ATL_INT ldc_bytes = 2*%d;\n",
478
switch( options.ldc )
481
if( useVoidPointersForC )
482
emit("register const ATL_INT ldc_bytes = ldc << %d;\n",
485
emit("register const ATL_INT ldc_bytes = ldc;\n" );
489
if( useVoidPointersForC )
490
emit( "register const ATL_INT ldc_bytes = KB*%d;\n",
493
emit( "register const ATL_INT ldc_bytes = KB;\n" );
497
if( useVoidPointersForC )
498
emit( "register const ATL_INT ldc_bytes = %d*%d;\n",
499
element.size, options.ldc );
501
emit( "register const ATL_INT ldc_bytes = %d;\n",
508
if( useVoidPointersForB )
509
emit("register void const *B0_off = (void*)B;\n");
511
emit("register TYPE const *B0_off = B;\n");
515
if( prefetch.ABlock )
517
emit("register void const * prefetchABlock = " );
518
if( options.lda == PARAMETER )
519
emitCat(" (void*)(A + %s*lda); \n", nb );
520
else if( options.lda == USE_KB )
521
emitCat(" (void*)(A + %s*KB); \n", nb );
523
emitCat(" (void*)(A + %s*%d);\n", nb, options.lda*element.size );
528
emit("register void const * prefetchACols = " );
529
if( options.lda == PARAMETER )
530
emitCat(" (void*)(A + %s*lda); \n", nb );
531
else if( options.lda == USE_KB )
532
emitCat(" (void*)(A + %s*KB); \n", nb );
534
emitCat(" (void*)(A + %s*%d);\n", nb, options.lda*element.size );
539
emit("register void const *prefetchB = " );
540
if( unroll.mb == PARAMETER )
541
emitCat(" (void*)(B + MB*ldb);\n" );
542
else if( unroll.mb == USE_KB )
543
emitCat(" (void*)(B + %d*ldb);\n", unroll.kb );
545
emitCat(" (void*)(B + %d*ldb);\n", unroll.mb );
550
emit("__builtin_prefetch( prefetchACols, PF_READONLY, PF_DEF );\n");
555
emit("__builtin_prefetch( prefetchB, PF_READONLY, PF_DEF );\n");
560
emit( "/* Unroll A */\n");
561
emit( "%s A0, a0", element.intrinsic );
562
for( a=1; a < unroll.a; ++a )
563
emitCat( ", A%d, a%d", a, a );
566
emit( "/* Unroll B */\n" );
567
emit( "%s B0", element.intrinsic );
568
for( b=1; b < unroll.b; ++b )
569
emitCat( ", B%d", b );
571
emitCat( ", B1", b );
577
if( options.lda == PARAMETER )
579
emit("register const ATL_INT unroll_a = I_UNROLL*lda_bytes;\n");
581
else if( options.lda == USE_KB )
583
if( useVoidPointersForA )
584
emit("register const ATL_INT unroll_a = I_UNROLL*KB%d;\n",
587
emit("register const ATL_INT unroll_a = I_UNROLL*KB;\n" );
591
if( useVoidPointersForA )
593
emit( "register const ATL_INT unroll_a = I_UNROLL*%d*%d;\n",
594
options.lda, element.size );
596
emit( "register const ATL_INT unroll_a = I_UNROLL*%d;\n",
602
if( useVoidPointersForC )
603
emit("register void* cPtr = (void*)C;\n" );
605
emit("register TYPE* cPtr = C;\n" );
615
BOOL simple /* If 1, unaligned case is assumed, else, generate aligned cases */
618
* Generate the I,J,K loops, accounting for the possibility that 4 loops
619
* are needed to account for the aligned, unaligned, and two alternating
627
printMainLoops( NALIGNED, "Non aligned" );
631
emit("const intptr_t ci = (intptr_t)cPtr;\n");
632
emit("if( (ci + 15) >> %d << %d == ci )\n", element.shift, element.shift );
635
emit("if( ldc %% 2 == 0 )\n" );
638
printMainLoops( ALIGNED, "C Aligned" );
642
printMainLoops( ALIGN_NALIGN, "C Aligned/Nonaligned columns" );
648
emit("if( ldc %% 2 == 0 )\n" );
651
printMainLoops( NALIGNED, "C Nonaligned" );
655
printMainLoops( NALIGN_ALIGN, "C Nonaligned/Aligned columns" );
669
ATL_INT times, /* How far to offset from the base value? */
673
* Returns a compilable code string that will evaluate to
674
* a byte offset in terms of lda.
675
* RETURNS: char* describing the offset.
678
if( options.lda == USE_KB )
680
if( options.constantFolding )
682
offset = times*unroll.kb + offset;
686
char *out = malloc( 255 );
692
if( useVoidPointersForA )
693
sprintf( out, "A0_off + %d*KB%d + %d",
694
times, element.size, offset );
696
sprintf( out, "A0_off + %d*KB + %d",
699
if( useVoidPointersForA )
700
sprintf( out, "A0_off + %d*KB%d", times, element.size );
702
sprintf( out, "A0_off + %d*KB", times );
707
sprintf( out, "A0_off + %d", offset );
709
sprintf( out, "A0_off" );
716
else if( options.lda == PARAMETER )
718
char *out = malloc( 255 );
723
case 0: sprintf( out, "A0_off + %d", offset ); break;
724
case 1: sprintf( out, "A0_off + lda_bytes + %d", offset ); break;
725
case 2: sprintf( out, "A0_off + 2*lda_bytes + %d", offset ); break;
726
case 3: sprintf( out, "A3_off + %d", offset ); break;
727
case 4: sprintf( out, "A0_off + 4*lda_bytes + %d", offset ); break;
728
case 5: sprintf( out, "A3_off + 2*lda_bytes + %d", offset ); break;
729
case 6: sprintf( out, "A0_off + 2*lda_bytes3 + %d", offset ); break;
730
case 7: sprintf( out, "A3_off + 4*lda_bytes + %d", offset ); break;
731
case 8: sprintf( out, "A0_off + 8*lda_bytes + %d", offset ); break;
732
case 9: sprintf( out, "A3_off + 2*lda_bytes3 + %d", offset ); break;
733
default: sprintf( out, "A0_off + %d*lda_bytes + %d",
739
case 0: sprintf( out, "A0_off" ); break;
740
case 1: sprintf( out, "A0_off + lda_bytes" ); break;
741
case 2: sprintf( out, "A0_off + 2*lda_bytes" ); break;
742
case 3: sprintf( out, "A3_off" ); break;
743
case 4: sprintf( out, "A0_off + 4*lda_bytes" ); break;
744
case 5: sprintf( out, "A3_off + 2*lda_bytes" ); break;
745
case 6: sprintf( out, "A0_off + 2*lda_bytes3" ); break;
746
case 7: sprintf( out, "A3_off + 4*lda_bytes" ); break;
747
case 8: sprintf( out, "A0_off + 8*lda_bytes" ); break;
748
case 9: sprintf( out, "A3_off + 2*lda_bytes3" ); break;
749
default: sprintf( out, "A0_off + %d*lda_bytes", times );
756
char *out = malloc( 255 );
760
if( useVoidPointersForA )
761
delta = options.lda*element.size;
765
if( options.constantFolding )
773
sprintf( out, "A0_off + %d*%d + %d",
774
times, delta, offset );
776
sprintf( out, "A0_off + %d*%d",
782
sprintf( out, "A0_off + %d", offset );
784
sprintf( out, "A0_off" );
793
ATL_INT times, /* number of multiples of ldb to offset */
794
ATL_INT offset /* Extra offset */
797
* Returns a compilable code string that will evaluate to
798
* a byte offset in terms of ldb.
799
* RETURNS: char* describing the offset.
802
if( options.ldb == USE_KB )
804
if( options.constantFolding )
806
offset = (times*unroll.kb + offset);
810
char *out = malloc( 255 );
815
if( useVoidPointersForB )
816
sprintf( out, "B0_off + %d*KB%d + %d",
817
times, element.size, offset );
819
sprintf( out, "B0_off + %d*KB + %d",
822
if( useVoidPointersForB )
823
sprintf( out, "B0_off + %d*KB%d", times, element.size );
825
sprintf( out, "B0_off + %d*KB%d", times, element.size );
830
sprintf( out, "B0_off + %d", offset );
832
sprintf( out, "B0_off" );
836
} else if( options.ldb == PARAMETER ) {
839
char *out = malloc( 255 );
842
case 0: sprintf( out, "B0_off + %d", offset ); break;
843
case 1: sprintf( out, "B0_off + ldb_bytes + %d", offset ); break;
844
case 2: sprintf( out, "B0_off + 2*ldb_bytes + %d", offset ); break;
845
case 3: sprintf( out, "B0_off + ldb_bytes3 + %d", offset ); break;
846
case 4: sprintf( out, "B0_off + 4*ldb_bytes + %d", offset ); break;
847
case 6: sprintf( out, "B0_off + 2*ldb_bytes3 + %d", offset ); break;
850
if( useVoidPointersForB )
851
sprintf( out, "B0_off + %d*%d + %d",
852
times, options.ldb*element.size, offset );
854
sprintf( out, "B0_off + %d*%d + %d",
855
times, options.ldb, offset );
862
case 0: return "B0_off";
863
case 1: return "B0_off + ldb_bytes";
864
case 2: return "B0_off + 2*ldb_bytes";
865
case 3: return "B0_off + ldb_bytes3";
866
case 4: return "B0_off + 4*ldb_bytes";
867
case 6: return "B0_off + 2*ldb_bytes3";
870
char *out = malloc( 255 );
871
if( useVoidPointersForB )
872
sprintf( out, "B0_off + %d*%d",
873
times, options.ldb*element.size );
875
sprintf( out, "B0_off + %d*%d",
876
times, options.ldb*element.size );
882
char *out = malloc( 255 );
887
case 0: sprintf( out, "B0_off + %d", offset ); break;
889
if( useVoidPointersForB )
890
sprintf( out, "B0_off + %d*sizeof(TYPE) + %d",
891
options.ldb, offset );
893
sprintf( out, "B0_off + %d + %d",
894
options.ldb, offset );
896
if( useVoidPointersForB )
897
sprintf( out, "B0_off + %d*%d*sizeof(TYPE) + %d",
898
times, options.ldb, offset );
900
sprintf( out, "B0_off + %d*%d + %d",
901
times, options.ldb, offset );
907
case 0: sprintf( out, "B0_off" ); break;
909
if( useVoidPointersForB )
910
sprintf( out, "B0_off + %d*sizeof(TYPE)", options.ldb );
912
sprintf( out, "B0_off + %d", options.ldb );
914
if( useVoidPointersForB )
915
sprintf( out, "B0_off + %d*%d*sizeof(TYPE)",
916
times, options.ldb );
918
sprintf( out, "B0_off + %d*%d", times, options.ldb );
929
* Emit code for the initial iteration of the K loop.
930
* This iteration is special because it does not need to
931
* accumulate, it only needs to initialize the scalar
932
* expansion registers.
936
emit( "/* K_Unrolling0 */\n" );
939
for( a=0; a<unroll.a; ++a )
941
char* deltaLDA = ldaOffset( a, 0 );
942
emit( "A%d = %s( MMCAST(%s) );\n", a, element.load_ab, deltaLDA );
945
for( b=0; b<unroll.b; ++b )
947
emit( "B%d = %s( MMCAST(%s) );\n",
948
b, element.load_ab, ldbOffset(b, 0) );
951
for( a=0; a<unroll.a; ++a )
953
emit( "c%d_%d = B%d;\n", b, a, b );
956
emit( "c%d_%d = _mm_mul_p%c( A%d, c%d_%d );\n",
957
b, a, element.cType, a, b, a );
959
emit( "c%d_%d = _mm_mul_p%c( c%d_%d, A%d );\n",
960
b, a, element.cType, b, a, a );
969
void k_unrollingFullStep
971
ATL_INT delta /* Number of bytes to offset during this unrolling */
973
/* Emit code that performs an unrolling step of the K loop.
974
* This function will be called repeatedly, increasing the delta value
975
* to account for the unrolling
980
char* tempId[] = { "B0", "B1" };
983
emit( "/* K_Unrolling: %d */\n", delta );
984
assert( delta < unroll.kb );
987
for( a=0; a<unroll.a; ++a )
989
if( useVoidPointersForA )
990
emit("A%d = %s( MMCAST(%s) );\n",
991
a, element.load_ab, ldaOffset(a, delta*element.size ) );
993
emit("A%d = %s( MMCAST(%s) );\n",
994
a, element.load_ab, ldaOffset(a, delta) );
999
for( b=0; b<unroll.b-1; ++b )
1002
if( useVoidPointersForB )
1003
emit( "B%d = %s( MMCAST(%s) );\n",
1004
b, element.load_ab, ldbOffset( b, delta*element.size ) );
1006
emit( "B%d = %s( MMCAST(%s) );\n",
1007
b, element.load_ab, ldbOffset( b, delta ) );
1010
for( a=0; a<unroll.a; ++a )
1012
emit( "a%d = A%d;\n", a, a );
1015
emit( "a%d = _mm_mul_p%c( B%d, a%d );\n",
1016
a, element.cType, b, a );
1017
emit( "c%d_%d = _mm_add_p%c( a%d, c%d_%d );\n",
1018
b, a, element.cType, a, b, a );
1020
emit( "a%d = _mm_mul_p%c( a%d, B%d );\n",
1021
a, element.cType, a, b );
1022
emit( "c%d_%d = _mm_add_p%c( c%d_%d, a%d );\n",
1023
b, a, element.cType, b, a, a );
1029
if( useVoidPointersForB )
1030
emit( "B%d = %s( MMCAST(%s) );\n",
1031
unroll.b-1, element.load_ab,
1032
ldbOffset( unroll.b-1, delta*element.size )
1035
emit( "B%d = %s( MMCAST(%s) );\n",
1036
unroll.b-1, element.load_ab, ldbOffset( unroll.b-1, delta ) );
1039
for( a=0; a<unroll.a; ++a )
1043
emit( "A%d = _mm_mul_p%c( B%d, A%d );\n",
1044
a, element.cType, unroll.b-1, a );
1045
emit( "c%d_%d = _mm_add_p%c( A%d, c%d_%d );\n",
1046
unroll.b-1, a, element.cType, a, unroll.b-1, a );
1048
emit( "A%d = _mm_mul_p%c( A%d, B%d );\n",
1049
a, element.cType, a, unroll.b-1 );
1050
emit( "c%d_%d = _mm_add_p%c( c%d_%d, A%d );\n",
1051
unroll.b-1, a, element.cType, unroll.b-1, a );
1058
void k_partialUnrolling
1060
ATL_INT offset /* Number of bytes to offset during this unrolling */
1062
/* Emit code that performs an unrolling step of the K loop.
1063
* This function will be called repeatedly, increasing the delta value
1064
* to account for the unrolling
1069
assert( offset < unroll.kb );
1070
assert( offset < unroll.k );
1072
emit( "/* k_partialUnrolling: %d */\n", offset );
1074
for( b=0; b<unroll.b; ++b )
1078
if( useVoidPointersForB )
1079
emit("B%d = %s( MMCAST( %s + (k+%d)*sizeof(TYPE) ) );\n",
1080
b, element.load_ab, ldbOffset( b, 0 ), offset );
1082
emit("B%d = %s( MMCAST( %s + (k+%d) ) );\n",
1083
b, element.load_ab, ldbOffset( b, 0 ), offset );
1085
if( useVoidPointersForB )
1086
emit("B%d = %s( MMCAST( %s + k*sizeof(TYPE) ) );\n",
1087
b, element.load_ab, ldbOffset( b, 0 ) );
1089
emit("B%d = %s( MMCAST( %s + k ) );\n",
1090
b, element.load_ab, ldbOffset( b, 0 ) );
1094
for( b=0; b<unroll.b; ++b )
1096
for( a=0; a<unroll.a; ++a )
1102
if( useVoidPointersForA )
1103
emit("A%d = %s( MMCAST(%s + (k+%d)*sizeof(TYPE)) );\n",
1104
a, element.load_ab, ldaOffset(a, 0), offset );
1106
emit("A%d = %s( MMCAST( %s + (k+%d) ) );\n",
1107
a, element.load_ab, ldaOffset(a, 0), offset );
1109
if( useVoidPointersForA )
1110
emit("A%d = %s( MMCAST( %s + k*sizeof(TYPE)));\n",
1111
a, element.load_ab, ldaOffset(a, 0) );
1113
emit("A%d = %s( MMCAST(%s + k));\n",
1114
a, element.load_ab, ldaOffset(a, 0) );
1117
emit("temp = _mm_mul_p%c( B%d, A%d );\n", element.cType, b, a );
1121
emit("c%d_%d = _mm_add_p%c( temp, c%d_%d );\n",
1122
b, a, element.cType, b, a );
1124
emit("c%d_%d = _mm_add_p%c(c%d_%d, temp );\n",
1125
b, a, element.cType, b, a );
1134
void printAllKUnrollings
1136
BOOL prefetchA, /* Prefetch A during this call? */
1137
BOOL prefetchB /* Prefetch B during this call? */
1143
printKRolled( prefetchA, prefetchB );
1145
else if( unroll.k < unroll.kb )
1147
printf( "unrollings: %d, %d\n", unroll.kb, unroll.k );
1148
assert( unroll.kb % unroll.k == 0 );
1149
printPartiallyUnrolledK( prefetchA, prefetchB );
1153
printFullyUnrolledK( prefetchA, prefetchB );
1157
if( prefetchA && prefetch.ABlock )
1160
const ATL_INT numABlockPrefetches =
1161
unroll.a * unroll.b * element.size / CACHE_LINE_SIZE + 1;
1162
emit("prefetchABlock += %d*pfBlockDistance;\n", numABlockPrefetches );
1166
if( options.ldb == PARAMETER )
1168
if( prefetchB && prefetch.BCols )
1170
emit( "prefetchB += J_UNROLL*ldb_bytes;\n" );
1172
} else if( options.ldb == USE_KB ) {
1173
if( prefetchB && prefetch.BCols )
1175
emit( "prefetchB += J_UNROLL*KB*%d;\n", element.size );
1178
if( prefetchB && prefetch.BCols )
1180
emit( "prefetchB += J_UNROLL*%d;\n", element.size*options.ldb );
1186
void printFullyUnrolledK(
1187
BOOL prefetchA, /* Allow prefetches of A? */
1188
BOOL prefetchB /* Allow prefetches of B? */
1191
* Print the K loop fully unrolled.
1194
ATL_INT prefetchB_counter = 0;
1195
ATL_INT blockACounter = 0;
1197
ATL_INT pfColumn = 0;
1198
ATL_INT prefetchACols = 0;
1201
const ATL_INT numABlockPrefetches =
1202
unroll.a * unroll.b * element.size / CACHE_LINE_SIZE + 1;
1204
const ATL_INT numAColPrefetches =
1205
unroll.a * element.size / CACHE_LINE_SIZE + 1;
1208
* (Number of b unrolls: so we can fetch one line each time)
1209
* (KB * element.size): number of bytes to fetch
1211
const ATL_INT numBPrefetches =
1212
unroll.b * unroll.kb * element.size / CACHE_LINE_SIZE + 1;
1214
const ATL_INT prefetchBDelta =
1215
unroll.kb * element.size / numBPrefetches;
1217
k_unrolling0(); /* Specialize first unrolling */
1220
* Emit unrollings by vector size
1222
for( k=element.vector_stride; k<unroll.kb; k+=element.vector_stride )
1225
* Prefetch an element from the next block of A
1227
if( prefetch.ABlock && prefetchA && blockACounter < numABlockPrefetches )
1229
emit("/* Prefetch one element from the next block of A */\n");
1230
emit("__builtin_prefetch( prefetchABlock + %d*pfBlockDistance,"
1231
"PF_READONLY, PF_DEF );\n", blockACounter );
1235
k_unrollingFullStep( k );
1237
if( prefetchB && prefetchB_counter < numBPrefetches )
1239
ATL_INT row = prefetchB_counter % unroll.b;
1241
if( options.ldb == PARAMETER )
1243
emit( "__builtin_prefetch( prefetchB + %d*ldb_bytes + %d,"
1244
"PF_READONLY, PF_DEF);\n",
1245
row, prefetchB_counter / unroll.b * prefetchBDelta );
1247
emit( "__builtin_prefetch( prefetchB + %d + %d,"
1248
"PF_READONLY, PF_DEF);\n",
1249
row*options.ldb*element.size,
1250
prefetchB_counter / unroll.b * prefetchBDelta );
1253
prefetchB_counter++;
1258
* Prefetch some columns of A further along
1260
if( prefetch.ACols )
1262
switch( options.lda )
1265
emit( "__builtin_prefetch( prefetchACols+%d+KB*%d,"
1266
"PF_READONLY, PF_DEF );\n",
1267
prefetchACols, pfColumn*element.size );
1271
emit( "__builtin_prefetch( prefetchACols+%d+lda_bytes*%d,"
1272
"PF_READONLY, PF_DEF );\n",
1273
prefetchACols, pfColumn );
1277
emit( "__builtin_prefetch( prefetchACols+lda_bytes*%d+%d,"
1278
"PF_READONLY, PF_DEF );\n",
1279
pfColumn, prefetchACols );
1283
if( pfColumn == unroll.a )
1285
prefetchACols += CACHE_LINE_SIZE;
1292
void printPartiallyUnrolledK
1294
BOOL prefetchA, /* Allow prefetches of A? */
1295
BOOL prefetchB /* Allow prefetches of B? */
1298
* Print a partially unrolled K loop, only used when
1299
* K != KB, otherwise the fully unrolled k loop function
1302
* The prefetch parameters allow prefetching on this iteration,
1303
* but do not require it. This allows separate passes to enable
1304
* prefetching on each input (For peeling, etc.).
1308
ATL_INT prefetchB_counter = 0;
1311
for( offset=element.vector_stride; offset<unroll.k; offset+=element.vector_stride )
1313
k_unrollingFullStep( offset );
1316
emit( "/* k unroll factor: %d */\n", unroll.k );
1318
emit( "for( k=%d; k<%d; k+=%d)\n", unroll.k, unroll.kb, unroll.k );
1322
for( offset=0; offset<unroll.k; offset+=element.vector_stride )
1324
k_partialUnrolling( offset );
1334
BOOL prefetchA, /* Allow prefetch of A */
1335
BOOL prefetchB /* Allow prefetch of B */
1337
/* Print a rolled K Loop */
1344
for( k=0; k<unroll.b; ++k )
1346
emit( "__builtin_prefetch( prefetchB + %d*KB, PF_READONLY, PF_DEF );"
1348
emit( "prefetchB += CACHE_LINE_SIZE;\n" );
1355
* If prefetching is allowed on this iteration, and it is
1356
* globally enabled, prefetch an element from the next block of A
1358
if( prefetchA && prefetch.ABlock )
1360
emit("/* Prefetch one element from the next block of A */\n");
1361
emit("__builtin_prefetch( prefetchABlock,PF_READONLY,PF_DEF );\n");
1362
emit("prefetchABlock += pfBlockDistance;\n");
1366
k_unrolling0(); /* Print the initial unrolling */
1369
* Print the rolled K loop, adjusting for the peeled iteration
1374
emit( "for( k=%d; k<K; k+=%d )\n", element.vector_stride, element.vector_stride );
1378
emit( "for( k=%d; k<KB; k+=%d )\n",
1379
element.vector_stride, element.vector_stride );
1383
emit( "for( k=%d; k<%d; k+=%d )\n",
1384
element.vector_stride, unroll.kb, element.vector_stride );
1389
k_partialUnrolling( 0 );
1396
void printScalarCompressionSingle()
1398
* Print the scalar compression routine for single precision.
1399
* This will take the four floats in the vector unit and combine
1400
* them into one. If possible, four vector units will be compressed
1401
 * into one vector that can be written directly to memory.
1405
emit("/* Single Scalar Compression */\n" );
1406
for( b=0; b<unroll.b; ++b )
1409
remaining = unroll.a;
1411
for( ; remaining>=4; remaining-=4 )
1414
* -> [c0a+c0b, c0c+c0d, c1a+c1b, c1c+c1d]
1415
 * -> [c2a+c2b, c2c+c2d, c3a+c3b, c3c+c3d]
1416
* -> [c0a+c0b+c0c+c0d, c1a+c1b+c1c+c1d,
1417
* c2a+c2b+c2c+c2d, c3a+c3b+c3c+c3d ]
1419
emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a, b,a+1 );
1420
emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a+2, b, a+2, b, a+3 );
1421
emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a,b, a+2);
1425
for( ; remaining > 0; remaining-- )
1427
emit( "/* additional remaining step */\n" );
1428
emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a, b, a );
1429
emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a, b, a );
1436
void printScalarCompression()
1438
* Emit the scalar compression algorithm for double precision
1443
emit("/* Combine scalar expansion back to scalar */\n");
1446
* If I is unrolled an even number of times, the horizontal adds
1447
* will have no cleanup case (for doubles)
1449
if( unroll.a % element.vector_stride == 0 )
1451
for( b=0; b<unroll.b; ++b )
1453
for( a=0; a<unroll.a; a+=2 )
1455
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1456
b, a, element.cType, b, a, b, a+1 );
1463
* There are not an even number of vectors in this store
1467
emit( "/* handling uneven case */\n" );
1468
for( b=0; b<unroll.b; ++b )
1470
for( a=0; a<unroll.a-1; a+=2 )
1472
if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
1474
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1475
b, a, element.cType, b, a, b, a+1 );
1476
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1477
b, a, element.cType, b, a, b, a+1 );
1479
emit( "/* double */\n" );
1480
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1481
b, a, element.cType, b, a, b, a+1 );
1485
if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
1488
* Cleanup case, run when there are not an even number of vectors.
1490
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1491
b, a, element.cType, b, a, b, a );
1492
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1493
b, a, element.cType, b, a, b, a );
1495
emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
1496
b, a, element.cType, b, a, b, a );
1504
/*
 * NOTE(review): extraction-damaged fragment -- the function header line is
 * missing (presumably `void storeResults(` -- TODO confirm against the full
 * file); bare numeric lines are original-file line numbers.  Code tokens
 * left byte-identical; comments only were added.
 */
/*
 * Emits the code that writes the computed c<b>_<a> accumulators back into
 * the C matrix.  Real types use vector stores (aligned vs. unaligned picked
 * from alignmentOfC) plus scalar-store cleanup; complex types interleave
 * scalar stores with _mm_shuffle_ps/_mm_shuffle_pd lane extraction so only
 * the real parts land at the right strided offsets.
 */
int alignmentOfC /* Is C aligned, unaligned, or alternating? */
1507
* Store the results of the I loop iteration to memory.
1508
* This will update the C matrix
1512
emit("/* Store results back to memory */\n");
1513
for( b=0; b<unroll.b; ++b )
1516
remaining = unroll.a;
1522
/* Select aligned or unaligned vector-store intrinsic name. */
if( alignmentOfC == ALIGNED )
1524
store = element.aStore;
1528
store = element.uStore;
1532
/* Real (non-complex) path: plain vector stores, then scalar cleanup. */
if( element.type == SINGLE || element.type == DOUBLE )
1534
for( ; remaining>=element.vector_stride;
1535
remaining-=element.vector_stride )
1539
if( useVoidPointersForC )
1541
emit("%s( MMCAST( cPtrI%d+%d ), MMCASTStoreintrin( c%d_%d ) );\n",
1542
store, b, a*element.size,b,a );
1545
emit("%s( MMCAST( cPtrI%d+%d ), MMCASTStoreintrin( c%d_%d ) );\n",
1546
store, b, a, b, a );
1548
emit("%s( MMCAST( cPtrI%d ), MMCASTStoreintrin( c%d_%d ) );\n",
1551
a += element.vector_stride;
1554
/* Scalar cleanup for elements that don't fill a whole vector. */
for( ; remaining > 0; remaining-- )
1558
if( useVoidPointersForC )
1559
emit("%s( cPtrI%d+%d, c%d_%d );\n",
1560
element.sStore, b, a*element.size,b,a );
1562
emit("%s( cPtrI%d+%d, c%d_%d );\n",
1563
element.sStore, b, a, b, a );
1565
emit("%s( cPtrI%d, c%d_%d );\n",
1566
element.sStore, b, b, a );
1577
/* Complex path: extract each lane and store it at stride 2 (real parts). */
for( ; remaining>=element.vector_stride;
1578
remaining -= element.vector_stride )
1580
if( element.type == COMPLEX_SINGLE )
1582
if( useVoidPointersForC )
1584
/* @TODO: Fix this copy */
1585
emit( "temp = c%d_0;\n", b );
1586
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x) );
1588
emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
1589
"_MM_SHUFFLE(1,1,1,1));\n", b, a, b, a );
1590
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+8) );
1592
emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
1593
"_MM_SHUFFLE(2,2,2,2));\n", b, a, b, a );
1594
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+16) );
1596
emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
1597
"_MM_SHUFFLE(3,3,3,3));\n", b, a, b, a );
1598
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+24) );
1600
/* Non-void-pointer variant: offsets in elements (x+2/4/6), not bytes. */
emit( "temp = c%d_%d;\n", b, a );
1601
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x) );
1603
emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
1604
"_MM_SHUFFLE(1,1,1,1));\n", b, a, b, a );
1605
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+2) );
1607
emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
1608
"_MM_SHUFFLE(2,2,2,2));\n", b, a, b, a );
1609
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+4) );
1611
emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
1612
"_MM_SHUFFLE(3,3,3,3));\n", b, a, b, a );
1613
emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+6) );
1618
} else if( element.type == COMPLEX_DOUBLE ) {
1620
if( useVoidPointersForC )
1622
emit( "_mm_store_sd( cPtrI%d+%d, c%d_%d );\n",
1623
b, IOffset, b, COffset );
1624
emit( "temp = _mm_shuffle_pd( c%d_%d, c%d_%d,"
1625
"_MM_SHUFFLE2(1,1));\n", b, COffset, b, COffset );
1628
emit( "_mm_store_sd( cPtrI%d+%d, temp );\n\n",
1633
emit( "_mm_store_sd( cPtrI%d+%d, c%d_%d );\n",
1634
b, IOffset, b, COffset );
1635
emit( "temp = _mm_shuffle_pd( c%d_%d, c%d_%d,"
1636
"_MM_SHUFFLE2(1,1));\n", b, COffset, b, COffset );
1639
emit( "_mm_store_sd( cPtrI%d+%d, temp );\n\n",
1647
/* Scalar cleanup for the complex path (stride-2 offsets). */
for( ; remaining >0; --remaining )
1649
if( useVoidPointersForC )
1650
emit( "_mm_store_s%c( cPtrI%d+%d, c%d_%d );\n",
1651
element.cType, b, 2*(a+x)*element.size, b, a+x );
1653
emit( " _mm_store_s%c( cPtrI%d+%d, c%d_%d );\n",
1656
2*(unroll.a-remaining),
1658
(unroll.a-remaining) );
1669
/*
 * NOTE(review): extraction-damaged fragment of the beta-application routine
 * (header line missing -- presumably `applyBeta`; TODO confirm).  Bare
 * numeric lines are original-file line numbers.  Two defects in emitted
 * strings fixed below; all other code tokens left byte-identical:
 *   1. "appied" -> "applied" (typo in generated-code comment).
 *   2. `_mm_mul_pd%c bc%d_%d, betaV )` generated invalid C such as
 *      `_mm_mul_pds bc0_0, betaV );` -- corrected to
 *      `_mm_mul_p%c( bc%d_%d, betaV )`, matching the parallel
 *      real-type branch further down.
 */
/*
 * Emits code that loads C from memory into bc<b>_<a> registers, scales it
 * by beta (skipped for BETA0, load-only for BETA1), and adds the product
 * accumulators: C = beta*C + (matrix multiply).  Complex types gather the
 * real parts with unpack/shuffle before scaling.
 */
* Apply the beta factor, if one is needed.
1674
emit( "/* Applying Beta */\n" );
1676
/* BETA0: C is overwritten, no load/scale emitted at all. */
if( options.beta == BETA0 )
1678
emit("/* No beta will be applied */\n" );
1683
emit( "/* Apply Beta Factor */\n" );
1684
for( b=0; b<unroll.b; ++b )
1686
ATL_INT remaining = unroll.a;
1688
/* Complex path: scalar-load real parts, repack into one vector. */
if( element.type == COMPLEX_SINGLE || element.type == COMPLEX_DOUBLE )
1691
emit( "/* Load C from memory */\n" );
1693
remaining>=element.vector_stride;
1694
remaining-=element.vector_stride )
1697
for( ; x<element.vector_stride; ++x )
1699
if( useVoidPointersForC )
1701
emit("temp%d = %s( cPtrI%d + %d );\n", x,
1702
element.sLoad, b, 2*(element.size*(x+a)) );
1704
emit("temp%d = %s( cPtrI%d + %d );\n", x,
1705
element.sLoad, b, 2*(x+a) );
1710
if( element.type == COMPLEX_SINGLE )
1712
/* temp0 = temp0[0], temp1[0], temp0[1], temp1[1] */
1713
emit( "temp0 = _mm_unpacklo_ps( temp0, temp1 );\n" );
1715
/* temp2 = temp2[0], temp3[0], temp2[1], temp3[1] */
1716
emit( "temp2 = _mm_unpacklo_ps( temp2, temp3 );\n" );
1719
emit( "bc%d_%d = _mm_movelh_ps( temp0, temp2 );\n",
1722
/* b?_? = temp0[0], temp1[0], temp2[0], temp2[0] */
1725
if( element.type == COMPLEX_DOUBLE )
1727
emit( "bc%d_%d = _mm_shuffle_pd( temp0, temp1,"
1728
"_MM_SHUFFLE2(0, 0 ) );\n", b, a );
1732
/* BETA1 needs no multiply; anything else scales by betaV. */
if( options.beta != BETA1 )
1736
emit("bc%d_%d = _mm_mul_p%c( betaV, bc%d_%d );\n",
1737
b,a, element.cType, b,a );
1739
emit("bc%d_%d = _mm_mul_p%c( bc%d_%d, betaV );\n",
1740
b,a, element.cType, b,a );
1743
a+= element.vector_stride;
1745
for( ; remaining > 0; --remaining )
1747
emit( "/* %d remaining */\n", remaining );
1748
if( useVoidPointersForC )
1750
emit("bc%d_%d = %s( cPtrI%d+%d );\n",
1751
b, a, element.sLoad, b, 2*a*element.size );
1753
emit("bc%d_%d = %s( cPtrI%d+%d );\n",
1754
b, a, element.sLoad, b, 2*a );
1758
if( options.beta != BETA1 )
1762
emit("bc%d_%d = _mm_mul_s%c( betaV, bc%d_%d );\n",
1763
b,a, element.cType, b,a );
1765
emit("bc%d_%d = _mm_mul_s%c( bc%d_%d, betaV );\n",
1766
b,a, element.cType, b,a );
1773
/* Real (non-complex) path: unaligned vector loads, then scalar cleanup. */
emit( "/* Load C from memory */\n" );
1776
remaining>=element.vector_stride;
1777
remaining-=element.vector_stride )
1779
if( useVoidPointersForC )
1780
emit("bc%d_%d = _mm_loadu_p%c( cPtrI%d+%d );\n", b, a,
1781
element.cType, b, a*element.size );
1783
emit("bc%d_%d = _mm_loadu_p%c( cPtrI%d+%d );\n", b, a,
1784
element.cType, b, a );
1786
if( options.beta != BETA1 )
1790
emit("bc%d_%d = _mm_mul_p%c( betaV, bc%d_%d );\n",
1791
b,a, element.cType, b,a );
1793
emit("bc%d_%d = _mm_mul_p%c( bc%d_%d, betaV );\n",
1794
b,a, element.cType, b,a );
1797
a+= element.vector_stride;
1799
for( ; remaining > 0; --remaining )
1801
emit( "/* %d remaining */\n", remaining );
1802
if( useVoidPointersForC )
1804
emit("bc%d_%d = %s( cPtrI%d+%d );\n",
1805
b, a, element.sLoad, b, a*element.size );
1807
emit("bc%d_%d = %s( cPtrI%d+%d );\n",
1808
b, a, element.sLoad, b, a );
1811
if( options.beta != BETA1 )
1815
emit("bc%d_%d = _mm_mul_s%c( betaV, bc%d_%d );\n",
1816
b,a, element.cType, b,a );
1818
emit("bc%d_%d = _mm_mul_s%c( bc%d_%d, betaV );\n",
1819
b,a, element.cType, b,a );
1827
/* Finally fold the scaled C into the product accumulators. */
emit( "/* C = (beta*C) + (matrix multiply) */\n" );
1828
for( b=0; b<unroll.b; ++b )
1830
ATL_INT remaining = unroll.a;
1832
for( ; remaining >= element.vector_stride;
1833
remaining -= element.vector_stride )
1837
emit("c%d_%d = _mm_add_p%c( bc%d_%d, c%d_%d );\n",
1838
b, a, element.cType, b, a, b, a );
1840
emit("c%d_%d = _mm_add_p%c( c%d_%d, bc%d_%d );\n",
1841
b, a, element.cType, b, a, b, a );
1843
a += element.vector_stride;
1846
for( ; remaining > 0; --remaining )
1850
emit( "c%d_%d = _mm_add_s%c( bc%d_%d, c%d_%d );\n",
1851
b, a, element.cType, b, a, b, a );
1853
emit( "c%d_%d = _mm_add_s%c( c%d_%d, bc%d_%d );\n",
1854
b, a, element.cType, b, a, b, a );
1867
/*
 * NOTE(review): extraction-damaged fragment -- the function header line is
 * missing (this is the routine that prints the full J/I loop nest for one
 * C-alignment case; TODO confirm its name against the full file).  Bare
 * numeric lines are original-file line numbers.  Code bytes untouched.
 */
int alignmentOfC, /* How is C aligned? */
1868
char* name /* What is the name of this alignment */
1871
* Print the I,J,K loops, accounting for a specific form of alignment,
1872
* aligned, unaligned, alternating
1879
* We are fetching (MB*KB*eltsize) bytes (one block of A).
1880
* If we prefetch outside the KB loop, we will have (MB/mu)*(NB/nu)
1881
* prefetch opportunities.
1883
* Therefore, we must fetch (MB*KB*eltsize)/ [(MB/mu)*(NB/nu)]
1884
* which is equal to (mu*nu*KB*eltsize)/NB.
1886
* We still need to take care of the m-loop peeled case,
1887
* when there is one less iteration of the m loop.
1890
if( prefetch.ABlock )
1892
emit("const ATL_INT pfBlockDistance = (%d * %d * KB * %d) / %s;\n",
1893
unroll.a, unroll.b, element.size, nb );
1897
emit("/* =======================================\n" );
1898
emit(" * Begin generated inner loops for case %s\n", name );
1899
emit(" * ======================================= */\n" );
1904
/* Emit the j (column) loop header; bound depends on how NB is known. */
emit("for( j=-NB; j!=0; j+=J_UNROLL) \n" );
1907
emit("for( j=-KB; j!=0; j+=J_UNROLL) \n" );
1910
emit("for( j=-%d; j!=0; j+=J_UNROLL) \n", unroll.nb );
1917
if( useVoidPointersForA )
1918
emit("register void const *A0_off = (void*)A; \n");
1920
emit("register TYPE const *A0_off = A; \n");
1923
if( options.lda == PARAMETER && unroll.a > 2 )
1925
emit("register void const *A3_off = A0_off + lda_bytes3;\n");
1928
emit( "register void const *A5_off = A3_off + lda_bytes*2;\n" );
1933
if( useVoidPointersForC )
1934
emit( "register void *cPtrI0 = (void*)cPtr;\n" );
1936
emit( "register TYPE *cPtrI0 = cPtr;\n" );
1939
/* One C row pointer per unrolled row of B, each ldc_bytes apart. */
for( b=1; b<unroll.b; ++b )
1941
emit("register TYPE *cPtrI%d = cPtrI%d + ldc_bytes;\n", b, b-1);
1947
if( prefetch.fetchC == TRUE )
1949
for( b=0; b<unroll.b; ++b )
1951
emit("__builtin_prefetch( cPtrI%d, PF_READONLY, PF_DEF );\n", b );
1956
char* deltaStr = "";
1958
* Peel the last iteration of the inner loop if prefetch should run on B
1960
if( prefetch.BCols == TRUE )
1962
deltaStr = "+I_UNROLL";
1967
emit("for( i=-%s%s; i != 0; i+= I_UNROLL )\n", mb, deltaStr );
1971
printILoop( alignmentOfC, TRUE, FALSE );
1973
emit("} /* End i/MB loop */\n\n");
1975
/* Peeled final i iteration carries the B prefetch. */
if( prefetch.BCols == TRUE )
1977
printILoop( alignmentOfC, FALSE, TRUE );
1980
/* Advance B to the next set of columns; form depends on ldb source. */
switch( options.ldb )
1983
if( useVoidPointersForB )
1984
emit( "B0_off += J_UNROLL*ldb_bytes;\n");
1986
emit( "B0_off += J_UNROLL*ldb_bytes;\n");
1990
if( useVoidPointersForB )
1991
emit( "B0_off += J_UNROLL*KB%d;\n", element.size );
1993
emit( "B0_off += J_UNROLL*KB;\n" );
1997
if( useVoidPointersForB )
1998
emit( "B0_off += J_UNROLL*%d*sizeof(TYPE);\n", options.ldb );
2000
emit( "B0_off += J_UNROLL*%d;\n", options.ldb );
2002
emit( "cPtr += J_UNROLL*ldc_bytes;\n" );
2005
emit("} /* End j/NB loop */\n");
2006
emit("/* End of generated inner loops */\n");
2011
/*
 * NOTE(review): extraction-damaged fragment of printILoop (header line
 * missing; name inferred from the call sites above -- TODO confirm).
 * Bare numeric lines are original-file line numbers.  Code bytes untouched.
 */
* Print one iteration of the middle loop, including beta adjustments,
2012
* summing of the inner loop, and iteration along the matricies.
2015
int alignmentOfC, /* How is the data aligned? */
2016
BOOL prefetchA, /* Allow prefetch of A during this iteration? */
2017
BOOL prefetchB /* Allow prefetch of B during this iteration? */
2026
/* Optionally prefetch the C elements this iteration will touch. */
if( prefetch.prefetchC )
2028
if( element.type == SINGLE || element.type == DOUBLE )
2030
for( b=0; b<unroll.b; ++b )
2032
for( offset=0; offset<unroll.a; offset+=2 )
2036
emit("__builtin_prefetch( cPtrI%d+%d, PF_READONLY, PF_DEF );\n",
2037
b, offset*element.size );
2039
emit("__builtin_prefetch( cPtrI%d, PF_READONLY, PF_DEF );\n",
2049
/* Emit the unrolled K loop body (the actual multiply-accumulate). */
printAllKUnrollings( prefetchA, prefetchB );
2054
* Scalar compression of singles and doubles behaves differently
2056
if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
2058
printScalarCompressionSingle();
2062
printScalarCompression();
2068
* Apply the beta scaling factor
2075
* Move to the next iteration
2077
emit("/* Move pointers to next iteration */ \n");
2078
emit("A0_off += unroll_a;\n");
2080
if( options.lda == PARAMETER && unroll.a > 2 )
2082
emit("A3_off += unroll_a;\n");
2084
emit( "A5_off += unroll_a;\n" );
2090
* Store the results of the computation back to memory
2092
storeResults( alignmentOfC );
2097
* Increment Pointers
2099
/* Complex C advances by 2 elements per iteration (real+imag slots). */
for( b=0; b<unroll.b; ++b )
2101
if( element.type == COMPLEX_SINGLE || element.type == COMPLEX_DOUBLE )
2103
if( useVoidPointersForC )
2104
emit("cPtrI%d += %d*2*I_UNROLL;\n", b, element.size );
2106
emit("cPtrI%d += 2*I_UNROLL;\n", b );
2108
if( useVoidPointersForC )
2109
emit("cPtrI%d += %d*I_UNROLL;\n", b, element.size );
2111
emit("cPtrI%d += I_UNROLL;\n", b );
2119
/* Advance the A-column prefetch pointer; stride depends on lda source. */
if( prefetch.ACols )
2121
if( options.lda == PARAMETER )
2122
emit( "prefetchACols += %d*lda;\n", unroll.a );
2123
else if( options.lda == USE_KB )
2124
emit( "prefetchACols += %d*KB;\n", unroll.a );
2126
emit( "prefetchACols += %d;\n", unroll.a * options.lda * element.size );
2132
/*
 * NOTE(review): extraction-damaged fragment -- the va_list declaration and
 * va_start/va_end calls are presumably in the missing lines (TODO confirm);
 * bare numeric lines are original-file line numbers.  Code bytes untouched.
 */
/*
 * emit: printf-style output of one line of generated code, indented by the
 * current tabwidth (two spaces per level, written before the formatted text).
 */
void emit( const char *fmt, ...)
2134
* Write a string to the output file. It will be indented
2135
* by the current indent value, call indent(int) to adjust
2136
* the indent factor. indent(1) will do a standard scoping
2137
* indent, and indent(-1) will remove that indent.
2140
assert( tabwidth >= 0 );
2145
for( t=0; t<tabwidth; ++t )
2147
fprintf( options.outputLocation, "  ");
2149
vfprintf( options.outputLocation, fmt, arg);
2154
/*
 * emitCat: like emit() but with no leading indentation -- used when a single
 * generated line is emitted in several pieces.  NOTE(review): va_start/va_end
 * presumably sit in the extraction-dropped lines (TODO confirm).
 */
void emitCat( const char *fmt, ...)
2156
* Write a string to the output file. Do not do indenting.
2157
* This is used when a line of code is emitted in parts.
2162
vfprintf( options.outputLocation, fmt, arg);
2168
/*
 * indent: adjust the global indentation level used by emit(); the assert
 * guards against more indent(-1) calls than indent(1) calls.
 * NOTE(review): the line applying delta to tabwidth was dropped by
 * extraction (TODO confirm against the full file).
 */
void indent( int delta )
2170
* Adjust the current indent for an emit call. See emit().
2174
assert( tabwidth >= 0 );
2178
/*
 * loadDefaults: reset every generator option to its default before command
 * line parsing -- all prefetching off, double-precision elements, 16-byte
 * vectors, runtime (PARAMETER) leading dimensions, beta=0, output to stdout.
 * NOTE(review): bare numeric lines are original-file line numbers left by
 * extraction; code bytes untouched.
 */
static void loadDefaults()
2180
* Load all default settings
2183
prefetch.ACols = FALSE;
2184
prefetch.ABlock = FALSE;
2185
prefetch.BCols = FALSE;
2186
prefetch.fetchC = FALSE;
2187
prefetch.prefetchC = FALSE;
2195
unroll.k = unroll.kb;
2197
element.vector_length_bytes = 16;
2198
element.type = DOUBLE;
2201
options.cAlignment = 1;
2202
options.ABAligned = TRUE;
2203
options.lda = PARAMETER;
2204
options.ldb = PARAMETER;
2205
options.ldc = PARAMETER;
2206
options.verifyUnrollings = FALSE;
2207
options.treatLoadsAsFloat = FALSE;
2208
options.treatStoresAsFloat = FALSE;
2209
options.constantFolding = TRUE;
2210
options.beta = BETA0;
2211
options.outputLocation = stdout;
2217
/*
 * NOTE(review): extraction-damaged fragment -- the function header line is
 * missing (this is the command-line option loader; TODO confirm its name).
 * Bare numeric lines are original-file line numbers.  Code bytes untouched.
 */
/*
 * Parses all command-line switches into the global unroll/options/prefetch/
 * element structures, then fills in the per-type intrinsic name tables
 * (load/store instruction strings, element size, vector stride).
 */
int argc, /* Number of command line elements */
2218
char **argv /* Command line elements */
2221
* Load the options from the command line
2227
/* Determine if the user requested the help message */
2228
requestHelp( argc, argv );
2230
/* Load all default settings */
2233
/* Load unrolling and blocking factors */
2234
loadInt( "-m", &unroll.a, argc, argv );
2235
loadInt( "-n", &unroll.b, argc, argv );
2236
loadInt( "-k", &unroll.k, argc, argv );
2238
loadInt( "-M", &unroll.mb, argc, argv );
2239
loadInt( "-N", &unroll.nb, argc, argv );
2240
loadInt( "-K", &unroll.kb, argc, argv );
2244
unroll.k = unroll.kb;
2246
/* Load the beta factor */
2247
loadInt( "-beta", &options.beta, argc, argv );
2248
/* Any beta outside {-1,0,1} is treated as the generic BETAX case. */
if( options.beta > 1 || options.beta < -1 )
2249
options.beta = BETAX;
2252
/* Determine if alignment checks are requested */
2253
loadInt( "-CAlignment", &options.cAlignment, argc, argv );
2254
loadBool( "-ABAligned", &options.ABAligned, argc, argv );
2256
/* Load the element type: float or double, complex or real */
2257
s = loadString( "-p", argc, argv );
2258
convertElementType( s[0] );
2260
/* Where should the file be written? */
2261
s = loadString( "-f", argc, argv );
2262
setOutputLocation( s );
2265
loadBool( "-verifyUnrollings", &options.verifyUnrollings, argc, argv );
2266
loadInt( "-lda", &options.lda, argc, argv );
2267
loadInt( "-ldb", &options.ldb, argc, argv );
2268
loadInt( "-ldc", &options.ldc, argc, argv );
2269
loadBool( "-treatLoadsAsFloat", &options.treatLoadsAsFloat, argc, argv );
2270
loadBool( "-treatStoresAsFloat", &options.treatStoresAsFloat, argc, argv );
2272
/* A leading dimension equal to the K blocking collapses to USE_KB. */
if( options.lda == unroll.kb )
2273
options.lda = USE_KB;
2275
if( options.ldb == unroll.kb )
2276
options.ldb = USE_KB;
2279
loadBool( "-prefetchACols", &prefetch.ACols, argc, argv );
2280
loadBool( "-prefetchABlock", &prefetch.ABlock, argc, argv );
2281
loadBool( "-prefetchBCols", &prefetch.BCols, argc, argv );
2282
loadBool( "-FF", &prefetch.fetchC, argc, argv );
2284
loadBool( "-prefetchCelts", &prefetch.prefetchC, argc, argv );
2285
loadBool( "-constantFolding", &options.constantFolding, argc, argv );
2289
/* Fill the per-type instruction/size tables for the chosen precision. */
switch( element.type )
2291
case COMPLEX_DOUBLE:
2293
element.size = sizeof( double );
2295
element.vector_stride = 2;
2297
strcpy( element.intrinsic, "__m128d" );
2300
/* Use the shorter instruction? */
2301
if( options.treatLoadsAsFloat )
2303
if( options.ABAligned )
2305
strcpy( element.load_ab, "(__m128d)_mm_load_ps" );
2307
strcpy( element.load_ab, "(__m128d)_mm_loadu_ps" );
2309
strcpy( element.sLoad, "(__m128d)_mm_load_ss" );
2311
if( options.ABAligned )
2313
strcpy( element.load_ab, "_mm_load_pd" );
2315
strcpy( element.load_ab, "_mm_loadu_pd" );
2317
strcpy( element.sLoad, "_mm_load_sd" );
2320
/* Use the shorter instruction? */
2321
if( options.treatStoresAsFloat )
2323
strcpy( element.aStore, "_mm_store_ps" );
2324
strcpy( element.uStore, "_mm_storeu_ps" );
2325
strcpy( element.sStore, "_mm_store_ss" );
2327
strcpy( element.aStore, "_mm_store_pd" );
2328
strcpy( element.uStore, "_mm_storeu_pd" );
2329
strcpy( element.sStore, "_mm_store_sd" );
2332
strcpy( element.type_name, "double" );
2335
case COMPLEX_SINGLE:
2337
element.size = sizeof( float );
2339
element.vector_stride = 4;
2342
strcpy( element.intrinsic, "__m128" );
2343
strcpy( element.aStore, "_mm_store_ps" );
2344
strcpy( element.uStore, "_mm_storeu_ps" );
2345
strcpy( element.sStore, "_mm_store_ss" );
2346
strcpy( element.sLoad, "_mm_load_ss" );
2349
if( options.ABAligned )
2351
strcpy( element.load_ab, "_mm_load_ps" );
2353
strcpy( element.load_ab, "_mm_loadu_ps" );
2356
strcpy( element.type_name, "float" );
2366
/* Each recognized tag consumes 2 argv slots; any mismatch means junk. */
if( numArgsProcessed != (argc-1)/2 )
2369
fprintf( stderr, "Commandline contained unknown arguments\n" );
2370
printf( "There were %d args\n", argc );
2371
for( i=0; i<argc; ++i )
2373
printf( "   %s\n", argv[i] );
2375
for( i=0; i<argc; ++i )
2377
printf( "   %s\n", argv[i] );
2385
/* KB must be a whole number of vectors (or left at 0 = runtime). */
assert( unroll.kb == 0 || unroll.kb % element.vector_stride == 0 );
2392
/*
 * NOTE(review): extraction-damaged fragment of loadString (header line and
 * the return statements are in dropped lines).  Bare numeric lines are
 * original-file line numbers.  Code bytes untouched.
 */
char* tag, /* The name of the parameter to load */
2393
int argc, /* Number of command line arguments */
2394
char** argv /* Command line arguments */
2397
* Load a string from the command line and return it.
2398
* RETURNS: The string loaded from the command line,
2399
* 0 if the parameter was not present.
2404
for( i=0; i<argc; ++i )
2406
if( strcmp( tag, argv[i] ) == 0 )
2408
/* Tag found: its value must follow as the next argv entry. */
assert( i+1 < argc );
2419
/*
 * NOTE(review): extraction-damaged fragment of requestHelp (header line
 * missing).  Scans argv for any of -?, -h, --help; presumably prints the
 * help text and exits when found (the action lines were dropped -- TODO
 * confirm).  Bare numeric lines are original-file line numbers.
 */
int argc, /* Number of command line parameters */
2420
char** argv /* The command line parameters */
2423
* Determine if the user requested the help message
2424
* by passing a special switch to the program.
2429
const char* tags[] = { "-?", "-h", "--help" };
2431
for( flag=0; flag<3; ++flag )
2433
for( i=0; i<argc; ++i )
2435
if( strcmp( tags[ flag ], argv[i] ) == 0 )
2446
/*
 * NOTE(review): extraction-damaged fragment of the help printer (header
 * line missing).  Bare numeric lines are original-file line numbers.
 * Fixed one user-visible typo in the help text: "Numver" -> "Number";
 * all other code tokens are byte-identical.
 */
/*
 * Prints the usage/help listing for every supported command-line switch
 * to standard output.
 */
* Print the help message to standard output.
2449
fprintf( stdout, "Prints a listing of a GEMM kernel\n" );
2450
fprintf( stdout, "Optional Arguments:\n" );
2452
fprintf( stdout, "   -p [s,c,z,d] \n" );
2453
fprintf( stdout, "   -f <filename> => File to generate\n" );
2455
/* Print Unrolling Options */
2456
fprintf( stdout, "   -m <int> => Number of columns of A to unroll\n" );
2457
fprintf( stdout, "   -n <int> => Number of rows of B to unroll\n" );
2458
fprintf( stdout, "   -k <int> => Number of times to unroll along K\n" );
2460
/* Print blocking options */
2461
fprintf( stdout, "   -M <int> => block size, integer or 0 for runtime\n" );
2462
fprintf( stdout, "   -N <int> => block size, integer or 0 for runtime\n" );
2463
fprintf( stdout, "   -K <int> => block size, integer or 0 for runtime\n" );
2464
fprintf( stdout, "   -lda <int> => lda, or 0 for runtime, -1 to use KB\n" );
2465
fprintf( stdout, "   -ldb <int> => ldb, or 0 for runtime, -1 to use KB\n" );
2466
fprintf( stdout, "   -ldc <int> => ldc, or 0 for runtime, -1 to use KB\n" );
2468
/* print algorithm optimizations */
2469
fprintf( stdout, "   -beta <int> => Assume beta value 0, 1, or" );
2470
fprintf( stdout, "other.\n" );
2471
fprintf( stdout, "   -CAlignment => 0 to test all alignments, " );
2472
fprintf( stdout, "1 to assume misaligned.\n" );
2473
fprintf( stdout, "   -ABAligned <bool> ==> assume that A and B are aligned\n" );
2474
fprintf( stdout, "   -verifyUnrollings <bool> ==> assert that unrolling "
2475
"params are correct\n" );
2476
fprintf( stdout, "   -treatLoadsAsFloat <bool> ==> use loadps instead of "
2478
fprintf( stdout, "   -treatStoresAsFloat <bool> ==> use storeps instead of "
2481
/* Print prefetch options */
2482
fprintf( stdout, "   -FF <0/1> => Fetch C at top of loop\n" );
2483
fprintf( stdout, "   -prefetchACols <bool> ==> Prefetch the next "
2484
"unrolling of A\n" );
2485
fprintf( stdout, "   -prefetchABlock <bool> ==> Prefetch the next "
2487
fprintf( stdout, "   -prefetchBCols <bool> ==> Prefetch the next rows "
2489
fprintf( stdout, "   -prefetchC <bool> ==> Prefetch C at the top of "
2492
fprintf( stdout, "   -constantFolding <bool> ==> Perform constant folding "
2493
"when possible.\n" );
2502
/*
 * NOTE(review): extraction-damaged fragment of loadInt (header line and
 * closing braces dropped).  Bare numeric lines are original-file line
 * numbers.  Code bytes untouched.
 */
char* tag, /* The parameter to find */
2503
ATL_INT *value, /* (output) Populated with the value following the tag */
2504
int argc, /* Number of command line parameters */
2505
char** argv /* The command line parameters */
2508
* Reads an integer value from the command line, as specified with a given
2509
* tag. Value is unchanged if the tag is not found on the command line.
2514
for( i=0; i<argc; ++i )
2516
if( strcmp( tag, argv[i] ) == 0 )
2518
/* The tag's value must follow it on the command line. */
assert( i+1 < argc );
2519
*value = atoi( argv[i+1] );
2529
/*
 * NOTE(review): extraction-damaged fragment of loadFloat (header line and
 * closing braces dropped).  Bare numeric lines are original-file line
 * numbers.  Code bytes untouched.
 */
char* tag, /* The parameter to find */
2530
float *value, /* (output) Populated with the value found */
2531
int argc, /* The number of command line parameters */
2532
char** argv /* The command line parameters */
2535
* Reads a float value from the command line, as specified with a given
2536
* tag. Value is unchanged if the tag is not found on the command line.
2541
for( i=0; i<argc; ++i )
2543
if( strcmp( tag, argv[i] ) == 0 )
2545
/* The tag's value must follow it on the command line. */
assert( i+1 < argc );
2546
*value = atof( argv[i+1] );
2555
/*
 * NOTE(review): extraction-damaged fragment of loadBool (header line, the
 * TRUE/FALSE assignments, and closing braces are in dropped lines -- TODO
 * confirm).  Bare numeric lines are original-file line numbers.
 */
char* tag, /* The parameter to find */
2556
int *value, /* The value of the parameter */
2557
int argc, /* Number of command line parameters */
2558
char** argv /* Command line parameters */
2561
* Reads a boolean value from the command line, as specified with a given
2562
* tag. Value is unchanged if the tag is not found on the command line.
2567
for( i=0; i<argc; ++i )
2569
if( strcmp( tag, argv[i] ) == 0 )
2571
assert( i+1 < argc );
2572
/* Only the literal strings "1" and "0" are accepted as booleans. */
if( strcmp( argv[i+1], "1" ) == 0 )
2575
} else if( strcmp( argv[i+1], "0" ) == 0 ) {
2578
fprintf( stderr, "ERROR: option tag \"%s\" requires"
2579
" 1 or 0 value.\n", tag );
2593
/*
 * NOTE(review): extraction-damaged fragment (header line missing; this is
 * the NB-expression writer -- the branch conditions selecting between the
 * three sprintf calls were dropped; TODO confirm).  Writes into `out` a C
 * expression for NB: the runtime variable "N", the "K" blocking variable,
 * or the literal compile-time value.
 */
char* out /* The string to write to */
2595
/* Generate a c compatible expression that defines
2596
* the value of NB. This may be a constant or a
2597
* variable from the generated program.
2603
sprintf( out, "%s", "N" );
2606
sprintf( out, "%s", "K" );
2609
sprintf( out, "%d", unroll.nb );
2617
/*
 * NOTE(review): extraction-damaged fragment (header line missing; this is
 * the MB-expression writer, parallel to the NB one above -- branch
 * conditions dropped; TODO confirm).  Writes into `out` a C expression
 * for MB: "M", "K", or the literal compile-time value.
 */
char* out /* The string to write to */
2619
/* Generate a c compatible expression that defines
2620
* the value of MB. This may be a constant or a
2621
* variable from the generated program.
2627
sprintf( out, "%s", "M" );
2630
sprintf( out, "%s", "K" );
2633
sprintf( out, "%d", unroll.mb );
2640
/*
 * NOTE(review): extraction-damaged fragment (header line missing; this is
 * the KB-expression writer -- branch conditions dropped; TODO confirm).
 * Writes into `out` a C expression for KB: "K", "KB", or the literal
 * compile-time value.
 */
char* out /* The string to write to */
2643
* Generate a c compatible expression that defines
2644
* the value of KB. This may be a constant or a
2645
* variable from the generated program.
2651
sprintf( out, "%s", "K" );
2654
sprintf( out, "%s", "KB" );
2657
sprintf( out, "%d", unroll.kb );
2663
/*
 * convertElementType: map a single-character precision specifier (the -p
 * flag: s/d/z/c) onto element.type and the one-letter intrinsic suffix
 * element.cType ('s' or 'd').  Unknown specifiers produce an error message.
 * NOTE(review): the case labels and surrounding braces were dropped by
 * extraction; code bytes untouched.
 */
void convertElementType
2665
char specifier /* The type specifier */
2668
* Interpret a type specifier to determine what
2669
* datatype is generated.
2677
element.type = SINGLE;
2678
element.cType = 's';
2682
element.type = DOUBLE;
2683
element.cType = 'd';
2687
element.type = COMPLEX_DOUBLE;
2688
element.cType = 'd';
2692
element.type = COMPLEX_SINGLE;
2693
element.cType = 's';
2697
fprintf( stderr, "Element type \"%c\" is not valid\n", specifier );
2703
void setOutputLocation
2705
char* file /* The name of the file to generate */
2708
* Set the output file location.
2709
* Use NULL for stdout, "" will not change the current setting.
2714
options.outputLocation = stdout;
2718
if( strcmp( file, "" ) != 0 )
2720
options.outputLocation = fopen( file, "w" );