~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to tune/blas/gemm/mmgen_sse.c

  • Committer: Package Import Robot
  • Author(s): Sébastien Villemot
  • Date: 2013-06-11 15:58:16 UTC
  • mfrom: (1.1.3 upstream)
  • mto: (2.2.21 experimental)
  • mto: This revision was merged to the branch mainline in revision 26.
  • Revision ID: package-import@ubuntu.com-20130611155816-b72z8f621tuhbzn0
Tags: upstream-3.10.1
Import upstream version 3.10.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *             Automatically Tuned Linear Algebra Software v3.10.1
 
3
 * Copyright (C) 2009 Chad Zalkin
 
4
 *
 
5
 * Code contributers : Chad Zalkin, R. Clint Whaley
 
6
 *
 
7
 * Redistribution and use in source and binary forms, with or without
 
8
 * modification, are permitted provided that the following conditions
 
9
 * are met:
 
10
 *   1. Redistributions of source code must retain the above copyright
 
11
 *      notice, this list of conditions and the following disclaimer.
 
12
 *   2. Redistributions in binary form must reproduce the above copyright
 
13
 *      notice, this list of conditions, and the following disclaimer in the
 
14
 *      documentation and/or other materials provided with the distribution.
 
15
 *   3. The name of the ATLAS group or the names of its contributers may
 
16
 *      not be used to endorse or promote products derived from this
 
17
 *      software without specific written permission.
 
18
 *
 
19
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 
20
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 
21
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 
22
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
 
23
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 
24
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 
25
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 
26
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 
27
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 
28
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 
29
 * POSSIBILITY OF SUCH DAMAGE.
 
30
 *
 
31
 */
 
32
/* #define DEBUG */
 
33
#ifdef DEBUG
 
34
   #define VERIFY 1
 
35
#endif
 
36
#include <stdio.h>
 
37
#include <stdarg.h>
 
38
#include <assert.h>
 
39
#include <math.h>
 
40
#include <stdlib.h>
 
41
#include <string.h>
 
42
 
 
43
#define ATL_INT int
 
44
 
 
45
typedef int BOOL;
 
46
#define TRUE 1
 
47
#define FALSE 0
 
48
 
 
49
#define SINGLE 0
 
50
#define DOUBLE 1
 
51
#define COMPLEX_SINGLE 2
 
52
#define COMPLEX_DOUBLE 3
 
53
 
 
54
#define BETAN1 -1
 
55
#define BETA0 0
 
56
#define BETA1 1
 
57
#define BETAX 2
 
58
 
 
59
#define PARAMETER 0
 
60
#define USE_KB -1
 
61
 
 
62
#define ALIGNED 0
 
63
#define ALIGN_NALIGN 1
 
64
#define NALIGN_ALIGN 2
 
65
#define NALIGNED 3
 
66
 
 
67
#define CACHE_LINE_SIZE 64
 
68
 
 
69
 
 
70
 
 
71
BOOL intrinSrcDest = TRUE;
 
72
BOOL useVoidPointersForC = FALSE;
 
73
BOOL useVoidPointersForA = FALSE;
 
74
BOOL useVoidPointersForB = FALSE;
 
75
 
 
76
 
 
77
/* Prefetch Options */
 
78
typedef struct
 
79
{
 
80
   BOOL ACols;
 
81
   BOOL ABlock;
 
82
   BOOL BCols;
 
83
   BOOL fetchC;
 
84
   BOOL prefetchC;
 
85
} Prefetch;
 
86
 
 
87
 
 
88
typedef struct
 
89
{
 
90
        ATL_INT a;
 
91
        ATL_INT b;
 
92
        ATL_INT k;
 
93
   ATL_INT mb;
 
94
   ATL_INT nb;
 
95
   ATL_INT kb;
 
96
} Unrolling;
 
97
 
 
98
 
 
99
typedef struct
 
100
{
 
101
        ATL_INT size;
 
102
        ATL_INT vector_stride;
 
103
        ATL_INT vector_length_bytes;
 
104
        ATL_INT shift;
 
105
        ATL_INT type;
 
106
   char cType;
 
107
   char type_name[7];
 
108
 
 
109
   /* Load vectors from A and B */
 
110
   char load_ab[25];
 
111
 
 
112
   /* Load/Store a single element */
 
113
        char sLoad[25];
 
114
        char sStore[25];
 
115
 
 
116
   /* Store an aligned/unaligned vector */
 
117
   char aStore[25];
 
118
   char uStore[25];
 
119
   char intrinsic[8];
 
120
} Element;
 
121
 
 
122
typedef struct
 
123
{
 
124
        ATL_INT cAlignment;
 
125
   BOOL ABAligned;
 
126
        ATL_INT lda;
 
127
        ATL_INT ldb;
 
128
   ATL_INT ldc;
 
129
        BOOL verifyUnrollings;
 
130
        BOOL treatLoadsAsFloat;
 
131
        BOOL treatStoresAsFloat;
 
132
   ATL_INT beta;
 
133
 
 
134
   BOOL constantFolding;
 
135
   FILE* outputLocation;
 
136
} GlobalOpts;
 
137
 
 
138
/* Load options from the command line */
 
139
void loadOptions( int argc, char **argv );
 
140
static void loadInt( char* tag, ATL_INT *value, int argc, char** argv );
 
141
static void loadFloat( char* tag, float *value, int argc, char** argv );
 
142
static void loadBool( char* tag, BOOL *value, int argc, char** argv );
 
143
static char* loadString( char* tag, int argc, char** argv );
 
144
static int requestHelp( int argc, char** argv );
 
145
static void convertElementType( char specifier );
 
146
static void setOutputLocation( char* file );
 
147
static void printHelp();
 
148
int numArgsProcessed = 0;
 
149
 
 
150
 
 
151
void getNBString( char* out );
 
152
void getMBString( char* out );
 
153
void getKBString( char* out );
 
154
 
 
155
/* These functions print the sections of the kernel */
 
156
void printMainLoops( int alignmentOfC, char *name );
 
157
void printILoop( int alignmentOfC, BOOL prefetchA, BOOL prefetchB );
 
158
 
 
159
void printPreamble();
 
160
void printIntro();
 
161
void printBody( BOOL simple );
 
162
 
 
163
/* These functions print the unrollings of the k loop */
 
164
void printAllKUnrollings( BOOL prefetchA, BOOL prefetchB );
 
165
void k_unrolling0();
 
166
void k_partialUnrolling( ATL_INT offset );
 
167
void k_unrollingFullStep( ATL_INT delta );
 
168
void printKRolled( BOOL prefetchA, BOOL prefetchB );
 
169
void printPartiallyUnrolledK( BOOL prefetchA, BOOL prefetchB );
 
170
void printFullyUnrolledK( BOOL prefetchA, BOOL prefetchB );
 
171
 
 
172
/* This compresses the vectors back to scalars */
 
173
void printScalarCompression();
 
174
void printScalarCompressionSingle();
 
175
void storeResults( int alignmentOfC );
 
176
void applyBeta();
 
177
 
 
178
/* These functions print the code to the console *
 
179
 * and manage the indention */
 
180
void emit( const char *fmt, ...);
 
181
void emitCat( const char *fmt, ...);
 
182
void indent( int delta );
 
183
 
 
184
 
 
185
/*
 
186
 * Keeps track of the number of indents in the emitted code.
 
187
 * This allows the system to properly nest braces in the output.
 
188
 */
 
189
int tabwidth=0;
 
190
 
 
191
/*
 
192
 * Create some shorthand for the load/store instructions
 
193
 * These are aliased to strings because one option adds a
 
194
 * typecast to the load/store instructions.  It wouldn't do
 
195
 * to have to test that option all the time....
 
196
 */
 
197
char nb[20];
 
198
char mb[20];
 
199
 
 
200
 
 
201
 
 
202
 
 
203
/* Options Variables */
 
204
Prefetch prefetch;
 
205
Unrolling unroll;
 
206
Element element;
 
207
GlobalOpts options;
 
208
 
 
209
 
 
210
 
 
211
int main(
 
212
      int argc,   /* Number of command line args */
 
213
      char** argv /* Array of command line args */
 
214
)
 
215
/*
 
216
 * Prints an implementation of GEMM given the parameters on the
 
217
 * command line.
 
218
 */
 
219
{
 
220
   loadOptions( argc, argv );  /* Read options from the command line */
 
221
 
 
222
        if( unroll.nb == USE_KB )
 
223
   {
 
224
                assert( unroll.kb % unroll.a == 0 );
 
225
      assert( unroll.kb % unroll.b == 0 );
 
226
   }
 
227
   else if( unroll.nb != PARAMETER )
 
228
   {
 
229
                assert( unroll.nb % unroll.a == 0 );
 
230
      assert( unroll.nb % unroll.b == 0 );
 
231
   }
 
232
 
 
233
   printPreamble();  /* Emit includes, defines, and variables */
 
234
 
 
235
 
 
236
/*
 
237
 * Emit three cases: Beta = 0, Beta = 1, Beta != 0,1
 
238
 * This allows conditional compilation of complex calculaions
 
239
 * which need all three cases.
 
240
 *
 
241
 * Each case includes a call to printIntro() which emits the
 
242
 * function header, and a call to printBody() which emits the
 
243
 * alignment cases as needed.
 
244
 */
 
245
 
 
246
   printIntro();
 
247
   printBody( options.cAlignment );
 
248
 
 
249
   assert( tabwidth == 0 );
 
250
 
 
251
   return 0;
 
252
}
 
253
 
 
254
 
 
255
 
 
256
void printPreamble()
 
257
/*
 
258
 * Print setup information such as CPP includes, data type defininitions,
 
259
 * and defines used to name constants.
 
260
 */
 
261
{
 
262
   emit( "#define ATL_INT int\n" );
 
263
 
 
264
/* Store some strings to represent the NB and MB constants */
 
265
   getNBString( nb );
 
266
   getMBString( mb );
 
267
 
 
268
 
 
269
/* Print includes */
 
270
   emit("#include <stdio.h>\n" );
 
271
   emit("#include <stdint.h>\n" );
 
272
   emit("#include <pmmintrin.h>\n" );
 
273
   emit("\n");
 
274
 
 
275
 
 
276
/* Emit some defines, so that the code is readable */
 
277
   emit( "#define I_UNROLL %d\n", unroll.a );
 
278
   emit( "#define J_UNROLL %d\n", unroll.b );
 
279
 
 
280
 
 
281
/* Setup the prefetch options */
 
282
   emit("/* Is prefetched data written or just read? */\n");
 
283
   emit( "#define PF_READONLY 0\n" );
 
284
   emit( "#define PF_READWRITE 1\n" );
 
285
   emit( "#define PF_NO_REUSE 0\n" );
 
286
 
 
287
   emit("\n/* Default temporality of cache prefetch (1-3) */\n");
 
288
   emit( "#define PF_DEF 1\n" );
 
289
   emit( "#define CACHE_LINE_SIZE %d\n", CACHE_LINE_SIZE );
 
290
 
 
291
   if( options.treatLoadsAsFloat )
 
292
   {
 
293
      emit( "#define MMCAST( a ) (float*)(a)\n" );
 
294
   }
 
295
   else
 
296
   {
 
297
      emit( "#define MMCAST( a ) (a)\n" );
 
298
   }
 
299
 
 
300
   if( options.treatStoresAsFloat )
 
301
   {
 
302
      emit( "#define MMCASTStore( a ) (float*)(a)\n" );
 
303
      emit( "#define MMCASTStoreintrin( a ) (__m128)(a)\n" );
 
304
   }
 
305
   else
 
306
   {
 
307
      emit( "#define MMCASTStore( a ) (a)\n" );
 
308
      emit( "#define MMCASTStoreintrin( a ) (a)\n" );
 
309
   }
 
310
}
 
311
 
 
312
 
 
313
void printIntro()
 
314
/*
 
315
 * Emit code that deduces constants used in this configuration,
 
316
 * data types, and the function prototype
 
317
 */
 
318
{
 
319
   ATL_INT a;
 
320
   ATL_INT b;
 
321
   ATL_INT x;
 
322
 
 
323
 
 
324
   emit( "#define TYPE %s\n", element.type_name );
 
325
   emit("void ATL_USERMM( const ATL_INT M, const ATL_INT N, const ATL_INT K,\n");
 
326
   emit("                 const TYPE alpha, const TYPE *A, const ATL_INT lda,\n");
 
327
   emit("                 const TYPE *B, const ATL_INT ldb,\n");
 
328
   emit("                 const TYPE beta, TYPE *C, const ATL_INT ldc )\n");
 
329
   emit("{\n");
 
330
 
 
331
   indent( 1 );
 
332
 
 
333
   emit("register ATL_INT i, j, k;\n");
 
334
   emit("\n");
 
335
 
 
336
 
 
337
/* Create variables for each of the vector registers */
 
338
   emit( "/* Vector registers to hold the elements of C */\n" );
 
339
   for( b=0; b<unroll.b; ++b )
 
340
   {
 
341
      emit( "%s c%d_0", element.intrinsic, b );
 
342
      for( a=1; a<unroll.a; ++a )
 
343
      {
 
344
         emitCat( ", c%d_%d", b, a );
 
345
      }
 
346
      emitCat( ";\n" );
 
347
   }
 
348
 
 
349
/*
 
350
 * If beta must be applied, create some registers for the
 
351
 * result
 
352
 */
 
353
   if( options.beta != BETA0 )
 
354
   {
 
355
      emit( "/* Vector register to hold C*beta */\n" );
 
356
      for( b=0; b<unroll.b; ++b )
 
357
      {
 
358
         ATL_INT remaining = unroll.a;
 
359
         ATL_INT a = 0;
 
360
         emit( "%s ", element.intrinsic );
 
361
 
 
362
/*       Create registers that include the entire stride */
 
363
         for( ; remaining > element.vector_stride;
 
364
                remaining -= element.vector_stride )
 
365
         {
 
366
            emitCat( "bc%d_%d", b, a );
 
367
            a += element.vector_stride;
 
368
            if( remaining != element.vector_stride )
 
369
            {
 
370
               emitCat( ", " );
 
371
            }
 
372
         }
 
373
/*       Create registers for elements that do not fit in the stride */
 
374
         for( ; remaining > 0; --remaining )
 
375
         {
 
376
            emitCat( "bc%d_%d", b, a );
 
377
            if( remaining != 1 )
 
378
            {
 
379
               emitCat( ", " );
 
380
            }
 
381
            ++a;
 
382
         }
 
383
 
 
384
         emitCat( ";\n" );
 
385
      }
 
386
   }
 
387
 
 
388
   emit( "/* Temporary vector registers for use in inner loop */\n" );
 
389
   emit("%s temp; \n", element.intrinsic );
 
390
 
 
391
   if( element.type == COMPLEX_DOUBLE || element.type == COMPLEX_SINGLE )
 
392
   {
 
393
      for( x=0; x<element.vector_stride; ++x )
 
394
      {
 
395
         emit("%s temp%d;  \n", element.intrinsic, x );
 
396
      }
 
397
   }
 
398
 
 
399
 
 
400
   if( options.verifyUnrollings == TRUE )
 
401
   {
 
402
      emit("assert(M%%%d==0);\n", unroll.a );
 
403
      emit("assert(N%%%d==0);\n", unroll.b );
 
404
   }
 
405
 
 
406
/* Start prefetching from B */
 
407
   if( prefetch.BCols )
 
408
   {
 
409
      emit("__builtin_prefetch( B, PF_READONLY, PF_DEF );\n");
 
410
   }
 
411
 
 
412
/* Load the beta factor so it will be ready to apply later */
 
413
   if( options.beta == BETAX || options.beta == BETAN1 )
 
414
   {
 
415
      if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
 
416
      {
 
417
         emit("const %s betaV = _mm_set1_ps( beta ); \n", element.intrinsic );
 
418
      } else {
 
419
         emit("const %s betaV = _mm_set1_pd( beta ); \n", element.intrinsic );
 
420
      }
 
421
   }
 
422
 
 
423
   emit("/* Pointer adjustments */  \n");
 
424
 
 
425
   if( options.ldb == PARAMETER )
 
426
   {
 
427
      emit("register const ATL_INT ldb_bytes = ldb << %d;\n", element.shift );
 
428
 
 
429
      if( unroll.b > 2 )
 
430
      {
 
431
         emit("register const ATL_INT ldb_bytes3 = ldb_bytes*3;\n" );
 
432
      }
 
433
   }
 
434
 
 
435
 
 
436
 
 
437
   if( options.lda == PARAMETER )
 
438
   {
 
439
      emit("register const ATL_INT lda_bytes = lda << %d;\n",
 
440
            element.shift );
 
441
      emit("register const ATL_INT lda_bytes3 = lda_bytes * 3;\n");
 
442
   }
 
443
 
 
444
 
 
445
/* Since complex matricies are strided by 2, ldc is off by a factor of 2 */
 
446
  if( element.type == COMPLEX_SINGLE || element.type == COMPLEX_DOUBLE )
 
447
  {
 
448
      switch( options.ldc )
 
449
      {
 
450
      case PARAMETER:
 
451
         if( useVoidPointersForC )
 
452
            emit("register const ATL_INT ldc_bytes = ldc << %d;\n",
 
453
                  element.shift+1);
 
454
         else
 
455
            emit("register const ATL_INT ldc_bytes = 2*ldc;\n");
 
456
         break;
 
457
 
 
458
      case USE_KB:
 
459
         if( useVoidPointersForC )
 
460
            emit( "register const ATL_INT ldc_bytes = 2*KB%d;\n",
 
461
                  element.size );
 
462
         else
 
463
            emit( "register const ATL_INT ldc_bytes = 2*KB;\n" );
 
464
         break;
 
465
 
 
466
      default:
 
467
         if( useVoidPointersForC )
 
468
            emit( "register const ATL_INT ldc_bytes = 2*%d*%d;\n",
 
469
                  element.size, options.ldc );
 
470
         else
 
471
            emit( "register const ATL_INT ldc_bytes = 2*%d;\n",
 
472
                  options.ldc );
 
473
      }
 
474
      emit("\n");
 
475
   }
 
476
   else
 
477
   {
 
478
      switch( options.ldc )
 
479
      {
 
480
      case PARAMETER:
 
481
         if( useVoidPointersForC )
 
482
            emit("register const ATL_INT ldc_bytes = ldc << %d;\n",
 
483
                  element.shift);
 
484
         else
 
485
            emit("register const ATL_INT ldc_bytes = ldc;\n" );
 
486
         break;
 
487
 
 
488
      case USE_KB:
 
489
         if( useVoidPointersForC )
 
490
            emit( "register const ATL_INT ldc_bytes = KB*%d;\n",
 
491
                  element.size );
 
492
         else
 
493
            emit( "register const ATL_INT ldc_bytes = KB;\n" );
 
494
         break;
 
495
 
 
496
      default:
 
497
         if( useVoidPointersForC )
 
498
            emit( "register const ATL_INT ldc_bytes = %d*%d;\n",
 
499
                   element.size, options.ldc );
 
500
         else
 
501
            emit( "register const ATL_INT ldc_bytes = %d;\n",
 
502
                   options.ldc );
 
503
      }
 
504
      emit("\n");
 
505
   }
 
506
 
 
507
 
 
508
   if( useVoidPointersForB )
 
509
      emit("register void const *B0_off = (void*)B;\n");
 
510
   else
 
511
      emit("register TYPE const *B0_off = B;\n");
 
512
 
 
513
   emit("   \n");
 
514
 
 
515
   if( prefetch.ABlock )
 
516
   {
 
517
      emit("register void const * prefetchABlock = " );
 
518
      if( options.lda == PARAMETER )
 
519
         emitCat(" (void*)(A + %s*lda); \n", nb );
 
520
      else if( options.lda == USE_KB )
 
521
         emitCat(" (void*)(A + %s*KB); \n", nb );
 
522
      else
 
523
         emitCat(" (void*)(A + %s*%d);\n", nb, options.lda*element.size );
 
524
   }
 
525
 
 
526
   if( prefetch.ACols )
 
527
   {
 
528
      emit("register void const * prefetchACols = " );
 
529
      if( options.lda == PARAMETER )
 
530
         emitCat(" (void*)(A + %s*lda); \n", nb );
 
531
      else if( options.lda == USE_KB )
 
532
         emitCat(" (void*)(A + %s*KB); \n", nb );
 
533
      else
 
534
         emitCat(" (void*)(A + %s*%d);\n", nb, options.lda*element.size );
 
535
   }
 
536
 
 
537
   if( prefetch.BCols )
 
538
   {
 
539
      emit("register void const *prefetchB = " );
 
540
      if( unroll.mb == PARAMETER )
 
541
         emitCat(" (void*)(B + MB*ldb);\n" );
 
542
      else if( unroll.mb == USE_KB )
 
543
         emitCat(" (void*)(B + %d*ldb);\n", unroll.kb );
 
544
      else
 
545
         emitCat(" (void*)(B + %d*ldb);\n", unroll.mb );
 
546
   }
 
547
 
 
548
   if( prefetch.ACols )
 
549
   {
 
550
      emit("__builtin_prefetch( prefetchACols, PF_READONLY, PF_DEF );\n");
 
551
   }
 
552
 
 
553
   if( prefetch.BCols )
 
554
   {
 
555
      emit("__builtin_prefetch( prefetchB, PF_READONLY, PF_DEF );\n");
 
556
   }
 
557
 
 
558
   emit("\n");
 
559
 
 
560
   emit( "/* Unroll A */\n");
 
561
   emit( "%s A0, a0", element.intrinsic );
 
562
   for( a=1; a < unroll.a; ++a )
 
563
      emitCat( ", A%d, a%d", a, a );
 
564
   emitCat( ";\n" );
 
565
 
 
566
   emit( "/* Unroll B */\n" );
 
567
   emit( "%s B0", element.intrinsic );
 
568
   for( b=1; b < unroll.b; ++b )
 
569
      emitCat( ", B%d", b );
 
570
   if( unroll.b == 1 )
 
571
      emitCat( ", B1", b );
 
572
   emitCat( ";\n" );
 
573
 
 
574
 
 
575
   emit("\n\n");
 
576
 
 
577
   if( options.lda == PARAMETER )
 
578
   {
 
579
      emit("register const ATL_INT unroll_a = I_UNROLL*lda_bytes;\n");
 
580
   }
 
581
   else if( options.lda == USE_KB )
 
582
   {
 
583
      if( useVoidPointersForA )
 
584
         emit("register const ATL_INT unroll_a = I_UNROLL*KB%d;\n",
 
585
               element.size );
 
586
      else
 
587
         emit("register const ATL_INT unroll_a = I_UNROLL*KB;\n" );
 
588
   }
 
589
   else
 
590
   {
 
591
      if( useVoidPointersForA )
 
592
      {
 
593
         emit( "register const ATL_INT unroll_a = I_UNROLL*%d*%d;\n",
 
594
               options.lda, element.size );
 
595
      } else {
 
596
         emit( "register const ATL_INT unroll_a = I_UNROLL*%d;\n",
 
597
               options.lda );
 
598
      }
 
599
   }
 
600
 
 
601
 
 
602
   if( useVoidPointersForC )
 
603
      emit("register void* cPtr = (void*)C;\n" );
 
604
   else
 
605
      emit("register TYPE* cPtr = C;\n" );
 
606
 
 
607
 
 
608
   emit("\n\n");
 
609
 
 
610
}
 
611
 
 
612
 
 
613
void printBody
 
614
(
 
615
  BOOL simple /* If 1, unaligned case is assumed, else, generate aligned cases */
 
616
)
 
617
/*
 
618
 * Generate the I,J,K loops, accounting for the possibility that 4 loops
 
619
 * are needed to account for the aligned, unaligned, and two alternating
 
620
 * alignment cases.
 
621
 */
 
622
{
 
623
/* MAIN LOOPS */
 
624
 
 
625
   if( simple == TRUE )
 
626
   {
 
627
      printMainLoops( NALIGNED, "Non aligned" );
 
628
   }
 
629
   else
 
630
   {
 
631
      emit("const intptr_t ci = (intptr_t)cPtr;\n");
 
632
      emit("if( (ci + 15) >> %d << %d == ci )\n", element.shift, element.shift );
 
633
      emit("{\n" );
 
634
         indent(1);
 
635
         emit("if( ldc %% 2 == 0 )\n" );
 
636
         emit("{\n" );
 
637
            indent( 1 );
 
638
            printMainLoops( ALIGNED, "C Aligned" );
 
639
            indent(-1 );
 
640
         emit("} else {\n" );
 
641
            indent( 1 );
 
642
            printMainLoops( ALIGN_NALIGN, "C Aligned/Nonaligned columns" );
 
643
            indent(-1 );
 
644
         emit("} \n" );
 
645
         indent(-1);
 
646
      emit("} else { \n");
 
647
         indent( 1);
 
648
         emit("if( ldc %% 2 == 0 )\n" );
 
649
         emit("{\n");
 
650
            indent( 1 );
 
651
            printMainLoops( NALIGNED, "C Nonaligned" );
 
652
            indent(-1 );
 
653
         emit("} else {\n");
 
654
            indent( 1 );
 
655
            printMainLoops( NALIGN_ALIGN, "C Nonaligned/Aligned columns" );
 
656
            indent(-1 );
 
657
         emit("}\n");
 
658
         indent(-1);
 
659
      emit("}\n" );
 
660
   }
 
661
   indent(-1);
 
662
   emit("}\n" );
 
663
   return;
 
664
}
 
665
 
 
666
 
 
667
char* ldaOffset
 
668
(
 
669
 ATL_INT times,  /* How far to offset from the base value? */
 
670
 ATL_INT offset
 
671
)
 
672
/*
 
673
 * Returns a compilable code string that will evaluate to
 
674
 * a byte offset in terms of lda.
 
675
 * RETURNS: char* describing the offset.
 
676
 */
 
677
{
 
678
   if( options.lda == USE_KB )
 
679
   {
 
680
      if( options.constantFolding )
 
681
      {
 
682
         offset = times*unroll.kb + offset;
 
683
         times = 0;
 
684
      }
 
685
 
 
686
      char *out = malloc( 255 );
 
687
 
 
688
      if( times > 0 )
 
689
      {
 
690
         if( offset > 0 )
 
691
         {
 
692
            if( useVoidPointersForA )
 
693
               sprintf( out, "A0_off + %d*KB%d + %d",
 
694
                        times, element.size, offset );
 
695
            else
 
696
               sprintf( out, "A0_off + %d*KB + %d",
 
697
                        times, offset );
 
698
         } else {
 
699
            if( useVoidPointersForA )
 
700
               sprintf( out, "A0_off + %d*KB%d", times, element.size );
 
701
            else
 
702
               sprintf( out, "A0_off + %d*KB", times );
 
703
         }
 
704
      } else {
 
705
         if( offset > 0 )
 
706
         {
 
707
            sprintf( out, "A0_off + %d", offset );
 
708
         } else {
 
709
            sprintf( out, "A0_off" );
 
710
         }
 
711
      }
 
712
 
 
713
 
 
714
      return out;
 
715
   }
 
716
   else if( options.lda == PARAMETER )
 
717
   {
 
718
      char *out = malloc( 255 );
 
719
      if( offset > 0 )
 
720
      {
 
721
         switch( times )
 
722
         {
 
723
            case 0: sprintf( out, "A0_off + %d", offset ); break;
 
724
            case 1: sprintf( out, "A0_off + lda_bytes + %d", offset ); break;
 
725
            case 2: sprintf( out, "A0_off + 2*lda_bytes + %d", offset ); break;
 
726
            case 3: sprintf( out, "A3_off + %d", offset ); break;
 
727
            case 4: sprintf( out, "A0_off + 4*lda_bytes + %d", offset ); break;
 
728
            case 5: sprintf( out, "A3_off + 2*lda_bytes + %d", offset ); break;
 
729
            case 6: sprintf( out, "A0_off + 2*lda_bytes3 + %d", offset ); break;
 
730
            case 7: sprintf( out, "A3_off + 4*lda_bytes + %d", offset ); break;
 
731
            case 8: sprintf( out, "A0_off + 8*lda_bytes + %d", offset ); break;
 
732
            case 9: sprintf( out, "A3_off + 2*lda_bytes3 + %d", offset ); break;
 
733
            default: sprintf( out, "A0_off + %d*lda_bytes + %d",
 
734
                              times, offset );
 
735
         }
 
736
      } else {
 
737
         switch( times )
 
738
         {
 
739
            case 0: sprintf( out, "A0_off" ); break;
 
740
            case 1: sprintf( out, "A0_off + lda_bytes" ); break;
 
741
            case 2: sprintf( out, "A0_off + 2*lda_bytes" ); break;
 
742
            case 3: sprintf( out, "A3_off" ); break;
 
743
            case 4: sprintf( out, "A0_off + 4*lda_bytes" ); break;
 
744
            case 5: sprintf( out, "A3_off + 2*lda_bytes" ); break;
 
745
            case 6: sprintf( out, "A0_off + 2*lda_bytes3" ); break;
 
746
            case 7: sprintf( out, "A3_off + 4*lda_bytes" ); break;
 
747
            case 8: sprintf( out, "A0_off + 8*lda_bytes" ); break;
 
748
            case 9: sprintf( out, "A3_off + 2*lda_bytes3" ); break;
 
749
            default: sprintf( out, "A0_off + %d*lda_bytes", times );
 
750
         }
 
751
      }
 
752
      return out;
 
753
   }
 
754
   else
 
755
   {
 
756
                char *out = malloc( 255 );
 
757
                if( times > 0 )
 
758
                {
 
759
      int delta;
 
760
         if( useVoidPointersForA )
 
761
            delta = options.lda*element.size;
 
762
         else
 
763
            delta = options.lda;
 
764
 
 
765
         if( options.constantFolding )
 
766
         {
 
767
            delta = times*delta;
 
768
            offset = 0;
 
769
         }
 
770
 
 
771
         if( offset > 0 )
 
772
         {
 
773
                           sprintf( out, "A0_off + %d*%d + %d",
 
774
                     times, delta, offset );
 
775
         } else {
 
776
                           sprintf( out, "A0_off + %d*%d",
 
777
                     times, delta );
 
778
         }
 
779
                } else {
 
780
         if( offset > 0 )
 
781
         {
 
782
           sprintf( out, "A0_off + %d", offset );
 
783
         } else {
 
784
           sprintf( out, "A0_off" );
 
785
         }
 
786
                }
 
787
        return out;
 
788
   }
 
789
}
 
790
 
 
791
char* ldbOffset
 
792
(
 
793
 ATL_INT times,  /* number of multiples of ldb to offset */
 
794
 ATL_INT offset  /* Extra offset */
 
795
)
 
796
/*
 
797
 * Returns a compilable code string that will evaluate to
 
798
 * a byte offset in terms of ldb.
 
799
 * RETURNS: char* describing the offset.
 
800
 */
 
801
{
 
802
   if( options.ldb == USE_KB )
 
803
   {
 
804
      if( options.constantFolding )
 
805
      {
 
806
         offset = (times*unroll.kb + offset);
 
807
         times = 0;
 
808
      }
 
809
 
 
810
      char *out = malloc( 255 );
 
811
      if( times > 0 )
 
812
      {
 
813
         if( offset > 0 )
 
814
         {
 
815
            if( useVoidPointersForB )
 
816
               sprintf( out, "B0_off + %d*KB%d + %d",
 
817
                     times, element.size, offset );
 
818
            else
 
819
               sprintf( out, "B0_off + %d*KB + %d",
 
820
                     times, offset );
 
821
         } else {
 
822
            if( useVoidPointersForB )
 
823
               sprintf( out, "B0_off + %d*KB%d", times, element.size );
 
824
            else
 
825
               sprintf( out, "B0_off + %d*KB%d", times, element.size );
 
826
         }
 
827
      } else {
 
828
         if( offset > 0 )
 
829
         {
 
830
            sprintf( out, "B0_off + %d", offset );
 
831
         } else {
 
832
            sprintf( out, "B0_off" );
 
833
         }
 
834
      }
 
835
      return out;
 
836
   } else if( options.ldb == PARAMETER ) {
 
837
      if( offset > 0 )
 
838
      {
 
839
         char *out = malloc( 255 );
 
840
         switch( times )
 
841
         {
 
842
         case 0: sprintf( out, "B0_off + %d", offset ); break;
 
843
         case 1: sprintf( out, "B0_off + ldb_bytes + %d", offset ); break;
 
844
         case 2: sprintf( out, "B0_off + 2*ldb_bytes + %d", offset ); break;
 
845
         case 3: sprintf( out, "B0_off + ldb_bytes3 + %d", offset ); break;
 
846
         case 4: sprintf( out, "B0_off + 4*ldb_bytes + %d", offset ); break;
 
847
         case 6: sprintf( out, "B0_off + 2*ldb_bytes3 + %d", offset ); break;
 
848
         default:
 
849
            {
 
850
               if( useVoidPointersForB )
 
851
                  sprintf( out, "B0_off + %d*%d + %d",
 
852
                        times, options.ldb*element.size, offset );
 
853
               else
 
854
                  sprintf( out, "B0_off + %d*%d + %d",
 
855
                        times, options.ldb, offset );
 
856
            }
 
857
         }
 
858
         return out;
 
859
      } else {
 
860
         switch( times )
 
861
         {
 
862
         case 0: return "B0_off";
 
863
         case 1: return "B0_off + ldb_bytes";
 
864
         case 2: return "B0_off + 2*ldb_bytes";
 
865
         case 3: return "B0_off + ldb_bytes3";
 
866
         case 4: return "B0_off + 4*ldb_bytes";
 
867
         case 6: return "B0_off + 2*ldb_bytes3";
 
868
         default:
 
869
            {
 
870
               char *out = malloc( 255 );
 
871
               if( useVoidPointersForB )
 
872
                  sprintf( out, "B0_off + %d*%d",
 
873
                        times, options.ldb*element.size );
 
874
               else
 
875
                  sprintf( out, "B0_off + %d*%d",
 
876
                        times, options.ldb*element.size );
 
877
               return out;
 
878
            }
 
879
         }
 
880
      }
 
881
   } else {
 
882
      char *out = malloc( 255 );
 
883
      if( offset > 0 )
 
884
      {
 
885
         switch( times )
 
886
         {
 
887
            case 0: sprintf( out, "B0_off + %d", offset ); break;
 
888
            case 1:
 
889
                    if( useVoidPointersForB )
 
890
                       sprintf( out, "B0_off + %d*sizeof(TYPE) + %d",
 
891
                             options.ldb, offset );
 
892
                    else
 
893
                       sprintf( out, "B0_off + %d + %d",
 
894
                             options.ldb, offset );
 
895
            default:
 
896
               if( useVoidPointersForB )
 
897
                  sprintf( out, "B0_off + %d*%d*sizeof(TYPE) + %d",
 
898
                           times, options.ldb, offset );
 
899
               else
 
900
                  sprintf( out, "B0_off + %d*%d + %d",
 
901
                        times, options.ldb, offset );
 
902
         }
 
903
         return out;
 
904
      } else {
 
905
         switch( times )
 
906
         {
 
907
            case 0: sprintf( out, "B0_off" ); break;
 
908
            case 1:
 
909
                  if( useVoidPointersForB )
 
910
                    sprintf( out, "B0_off + %d*sizeof(TYPE)", options.ldb );
 
911
                  else
 
912
                    sprintf( out, "B0_off + %d", options.ldb );
 
913
            default:
 
914
               if( useVoidPointersForB )
 
915
                  sprintf( out, "B0_off + %d*%d*sizeof(TYPE)",
 
916
                           times, options.ldb );
 
917
               else
 
918
                  sprintf( out, "B0_off + %d*%d", times, options.ldb );
 
919
         }
 
920
         return out;
 
921
      }
 
922
   }
 
923
}
 
924
 
 
925
 
 
926
 
 
927
void k_unrolling0()
 
928
/*
 
929
 * Emit code for the initial iteration of the K loop.
 
930
 * This iteration is special because it does not need to
 
931
 * accumulate, it only needs to initialize the scalar
 
932
 * expansion registers.
 
933
 */
 
934
{
 
935
   ATL_INT a, b;
 
936
   emit( "/* K_Unrolling0 */\n" );
 
937
 
 
938
 
 
939
   for( a=0; a<unroll.a; ++a )
 
940
   {
 
941
      char* deltaLDA = ldaOffset( a, 0 );
 
942
      emit( "A%d = %s( MMCAST(%s) );\n", a, element.load_ab, deltaLDA );
 
943
   }
 
944
 
 
945
   for( b=0; b<unroll.b; ++b )
 
946
   {
 
947
      emit( "B%d = %s( MMCAST(%s) );\n",
 
948
            b, element.load_ab, ldbOffset(b, 0) );
 
949
 
 
950
 
 
951
      for( a=0; a<unroll.a; ++a )
 
952
      {
 
953
         emit( "c%d_%d = B%d;\n", b, a, b );
 
954
         if( intrinSrcDest )
 
955
         {
 
956
            emit( "c%d_%d = _mm_mul_p%c( A%d, c%d_%d );\n",
 
957
                  b, a, element.cType, a, b, a );
 
958
         } else {
 
959
            emit( "c%d_%d = _mm_mul_p%c( c%d_%d, A%d );\n",
 
960
                  b, a, element.cType, b, a, a );
 
961
         }
 
962
      }
 
963
 
 
964
      emit( "\n" );
 
965
   }
 
966
}
 
967
 
 
968
 
 
969
void k_unrollingFullStep
 
970
(
 
971
  ATL_INT delta    /* Number of bytes to offset during this unrolling */
 
972
)
 
973
/* Emit code that performs an unrolling step of the K loop.
 
974
 * This function will be called repeatedly, increasing the delta value
 
975
 * to account for the unrolling
 
976
 */
 
977
{
 
978
   ATL_INT a, b;
 
979
 
 
980
   char* tempId[] = { "B0", "B1" };
 
981
 
 
982
 
 
983
   emit( "/* K_Unrolling: %d */\n", delta );
 
984
   assert( delta < unroll.kb );
 
985
 
 
986
 
 
987
   for( a=0; a<unroll.a; ++a )
 
988
   {
 
989
      if( useVoidPointersForA )
 
990
         emit("A%d = %s( MMCAST(%s) );\n",
 
991
               a, element.load_ab, ldaOffset(a, delta*element.size ) );
 
992
      else
 
993
         emit("A%d = %s( MMCAST(%s) );\n",
 
994
               a, element.load_ab, ldaOffset(a, delta) );
 
995
   }
 
996
 
 
997
 
 
998
   /* b = unroll.b-1 */
 
999
   for( b=0; b<unroll.b-1; ++b )
 
1000
   {
 
1001
      emit( "\n" );
 
1002
      if( useVoidPointersForB )
 
1003
         emit( "B%d = %s( MMCAST(%s) );\n",
 
1004
               b, element.load_ab, ldbOffset( b, delta*element.size ) );
 
1005
      else
 
1006
         emit( "B%d = %s( MMCAST(%s) );\n",
 
1007
               b, element.load_ab, ldbOffset( b, delta ) );
 
1008
 
 
1009
 
 
1010
      for( a=0; a<unroll.a; ++a )
 
1011
      {
 
1012
         emit( "a%d = A%d;\n", a, a );
 
1013
         if( intrinSrcDest )
 
1014
         {
 
1015
            emit( "a%d = _mm_mul_p%c( B%d, a%d );\n",
 
1016
                  a, element.cType, b, a );
 
1017
            emit( "c%d_%d = _mm_add_p%c( a%d, c%d_%d );\n",
 
1018
                  b, a, element.cType,  a, b, a );
 
1019
         } else {
 
1020
            emit( "a%d = _mm_mul_p%c( a%d, B%d );\n",
 
1021
                  a, element.cType, a, b );
 
1022
             emit( "c%d_%d = _mm_add_p%c( c%d_%d, a%d );\n",
 
1023
                  b, a, element.cType,  b, a, a );
 
1024
         }
 
1025
      }
 
1026
   }
 
1027
 
 
1028
   emit( "\n" );
 
1029
   if( useVoidPointersForB )
 
1030
      emit( "B%d = %s( MMCAST(%s) );\n",
 
1031
            unroll.b-1, element.load_ab,
 
1032
            ldbOffset( unroll.b-1, delta*element.size )
 
1033
          );
 
1034
   else
 
1035
      emit( "B%d = %s( MMCAST(%s) );\n",
 
1036
            unroll.b-1, element.load_ab, ldbOffset( unroll.b-1, delta ) );
 
1037
 
 
1038
 
 
1039
   for( a=0; a<unroll.a; ++a )
 
1040
   {
 
1041
      if( intrinSrcDest )
 
1042
      {
 
1043
         emit( "A%d = _mm_mul_p%c( B%d, A%d );\n",
 
1044
               a, element.cType, unroll.b-1, a );
 
1045
         emit( "c%d_%d = _mm_add_p%c( A%d, c%d_%d );\n",
 
1046
               unroll.b-1, a, element.cType, a, unroll.b-1, a );
 
1047
      } else {
 
1048
         emit( "A%d = _mm_mul_p%c( A%d, B%d );\n",
 
1049
               a, element.cType, a, unroll.b-1 );
 
1050
         emit( "c%d_%d = _mm_add_p%c( c%d_%d, A%d );\n",
 
1051
               unroll.b-1, a, element.cType, unroll.b-1, a );
 
1052
      }
 
1053
   }
 
1054
 
 
1055
}
 
1056
 
 
1057
 
 
1058
void k_partialUnrolling
 
1059
(
 
1060
  ATL_INT offset    /* Number of bytes to offset during this unrolling */
 
1061
)
 
1062
/* Emit code that performs an unrolling step of the K loop.
 
1063
 * This function will be called repeatedly, increasing the delta value
 
1064
 * to account for the unrolling
 
1065
 */
 
1066
{
 
1067
   ATL_INT a, b;
 
1068
 
 
1069
   assert( offset < unroll.kb );
 
1070
        assert( offset < unroll.k );
 
1071
 
 
1072
   emit( "/* k_partialUnrolling: %d */\n", offset );
 
1073
 
 
1074
   for( b=0; b<unroll.b; ++b )
 
1075
   {
 
1076
                if( offset > 0 )
 
1077
      {
 
1078
         if( useVoidPointersForB )
 
1079
            emit("B%d = %s( MMCAST( %s + (k+%d)*sizeof(TYPE) ) );\n",
 
1080
               b, element.load_ab, ldbOffset( b, 0 ), offset );
 
1081
         else
 
1082
            emit("B%d = %s( MMCAST( %s + (k+%d) ) );\n",
 
1083
               b, element.load_ab, ldbOffset( b, 0 ), offset );
 
1084
      } else {
 
1085
         if( useVoidPointersForB )
 
1086
            emit("B%d = %s( MMCAST( %s + k*sizeof(TYPE) ) );\n",
 
1087
                 b, element.load_ab, ldbOffset( b, 0 ) );
 
1088
         else
 
1089
            emit("B%d = %s( MMCAST( %s + k ) );\n",
 
1090
                 b, element.load_ab, ldbOffset( b, 0 ) );
 
1091
      }
 
1092
        }
 
1093
 
 
1094
   for( b=0; b<unroll.b; ++b )
 
1095
   {
 
1096
      for( a=0; a<unroll.a; ++a )
 
1097
      {
 
1098
         if( b == 0 )
 
1099
         {
 
1100
            if( offset > 0 )
 
1101
            {
 
1102
               if( useVoidPointersForA )
 
1103
                  emit("A%d = %s( MMCAST(%s + (k+%d)*sizeof(TYPE)) );\n",
 
1104
                     a, element.load_ab, ldaOffset(a, 0), offset );
 
1105
               else
 
1106
                  emit("A%d = %s( MMCAST( %s + (k+%d) ) );\n",
 
1107
                     a, element.load_ab, ldaOffset(a, 0), offset );
 
1108
            } else {
 
1109
               if( useVoidPointersForA )
 
1110
                  emit("A%d = %s( MMCAST( %s + k*sizeof(TYPE)));\n",
 
1111
                     a, element.load_ab, ldaOffset(a, 0) );
 
1112
               else
 
1113
                  emit("A%d = %s( MMCAST(%s + k));\n",
 
1114
                     a, element.load_ab, ldaOffset(a, 0) );
 
1115
            }
 
1116
         }
 
1117
         emit("temp = _mm_mul_p%c( B%d, A%d );\n", element.cType, b, a );
 
1118
 
 
1119
         if( intrinSrcDest )
 
1120
         {
 
1121
         emit("c%d_%d = _mm_add_p%c( temp, c%d_%d );\n",
 
1122
               b, a, element.cType, b, a );
 
1123
         } else {
 
1124
         emit("c%d_%d = _mm_add_p%c(c%d_%d, temp );\n",
 
1125
               b, a, element.cType, b, a );
 
1126
         }
 
1127
      }
 
1128
   }
 
1129
}
 
1130
 
 
1131
 
 
1132
 
 
1133
 
 
1134
void printAllKUnrollings
 
1135
(
 
1136
  BOOL prefetchA,  /* Prefetch A during this call? */
 
1137
  BOOL prefetchB   /* Prefetch B during this call? */
 
1138
)
 
1139
{
 
1140
 
 
1141
        if( unroll.k == 1 )
 
1142
   {
 
1143
      printKRolled( prefetchA, prefetchB );
 
1144
   }
 
1145
   else if( unroll.k < unroll.kb )
 
1146
   {
 
1147
      printf( "unrollings: %d, %d\n", unroll.kb, unroll.k );
 
1148
                assert( unroll.kb % unroll.k == 0 );
 
1149
      printPartiallyUnrolledK( prefetchA, prefetchB );
 
1150
   }
 
1151
   else
 
1152
   {
 
1153
      printFullyUnrolledK( prefetchA, prefetchB );
 
1154
   }
 
1155
 
 
1156
 
 
1157
   if( prefetchA && prefetch.ABlock )
 
1158
   {
 
1159
 
 
1160
      const ATL_INT numABlockPrefetches =
 
1161
         unroll.a * unroll.b * element.size / CACHE_LINE_SIZE + 1;
 
1162
      emit("prefetchABlock += %d*pfBlockDistance;\n", numABlockPrefetches );
 
1163
   }
 
1164
 
 
1165
 
 
1166
   if( options.ldb == PARAMETER )
 
1167
   {
 
1168
      if( prefetchB && prefetch.BCols )
 
1169
      {
 
1170
         emit( "prefetchB += J_UNROLL*ldb_bytes;\n" );
 
1171
      }
 
1172
   } else if( options.ldb == USE_KB ) {
 
1173
      if( prefetchB && prefetch.BCols )
 
1174
      {
 
1175
         emit( "prefetchB += J_UNROLL*KB*%d;\n", element.size );
 
1176
      }
 
1177
   } else {
 
1178
      if( prefetchB && prefetch.BCols )
 
1179
      {
 
1180
         emit( "prefetchB += J_UNROLL*%d;\n", element.size*options.ldb );
 
1181
      }
 
1182
   }
 
1183
 
 
1184
}
 
1185
 
 
1186
void printFullyUnrolledK(
 
1187
      BOOL prefetchA,  /* Allow prefetches of A? */
 
1188
      BOOL prefetchB   /* Allow prefetches of B? */
 
1189
)
 
1190
/*
 
1191
 * Print the K loop fully unrolled.
 
1192
 */
 
1193
{
 
1194
   ATL_INT prefetchB_counter = 0;
 
1195
   ATL_INT blockACounter = 0;
 
1196
 
 
1197
   ATL_INT pfColumn = 0;
 
1198
   ATL_INT prefetchACols = 0;
 
1199
   ATL_INT k;
 
1200
 
 
1201
   const ATL_INT numABlockPrefetches =
 
1202
      unroll.a * unroll.b * element.size / CACHE_LINE_SIZE + 1;
 
1203
 
 
1204
   const ATL_INT numAColPrefetches =
 
1205
      unroll.a * element.size / CACHE_LINE_SIZE + 1;
 
1206
 
 
1207
/*
 
1208
 * (Number of b unrolls: so we can fetch one line each time)
 
1209
 * (KB * element.size): number of bytes to fetch
 
1210
 */
 
1211
   const ATL_INT numBPrefetches =
 
1212
      unroll.b * unroll.kb * element.size / CACHE_LINE_SIZE + 1;
 
1213
 
 
1214
   const ATL_INT prefetchBDelta =
 
1215
      unroll.kb * element.size / numBPrefetches;
 
1216
 
 
1217
   k_unrolling0(); /* Specialize first unrolling */
 
1218
 
 
1219
/*
 
1220
 * Emit unrollings by vector size
 
1221
 */
 
1222
   for( k=element.vector_stride; k<unroll.kb; k+=element.vector_stride )
 
1223
   {
 
1224
/*
 
1225
 *    Prefetch an element from the next block of A
 
1226
 */
 
1227
      if( prefetch.ABlock && prefetchA && blockACounter < numABlockPrefetches )
 
1228
      {
 
1229
         emit("/* Prefetch one element from the next block of A */\n");
 
1230
         emit("__builtin_prefetch( prefetchABlock + %d*pfBlockDistance,"
 
1231
              "PF_READONLY, PF_DEF );\n", blockACounter );
 
1232
         blockACounter++;
 
1233
      }
 
1234
 
 
1235
      k_unrollingFullStep( k );
 
1236
 
 
1237
      if( prefetchB && prefetchB_counter < numBPrefetches )
 
1238
      {
 
1239
         ATL_INT row = prefetchB_counter % unroll.b;
 
1240
 
 
1241
         if( options.ldb == PARAMETER )
 
1242
         {
 
1243
            emit( "__builtin_prefetch( prefetchB + %d*ldb_bytes + %d,"
 
1244
                  "PF_READONLY, PF_DEF);\n",
 
1245
                  row, prefetchB_counter / unroll.b * prefetchBDelta );
 
1246
         } else {
 
1247
            emit( "__builtin_prefetch( prefetchB + %d + %d,"
 
1248
                  "PF_READONLY, PF_DEF);\n",
 
1249
                  row*options.ldb*element.size,
 
1250
                  prefetchB_counter / unroll.b * prefetchBDelta );
 
1251
         }
 
1252
 
 
1253
         prefetchB_counter++;
 
1254
      }
 
1255
 
 
1256
 
 
1257
/*
 
1258
 *    Prefetch some columns of A further along
 
1259
 */
 
1260
      if( prefetch.ACols )
 
1261
      {
 
1262
         switch( options.lda )
 
1263
         {
 
1264
            case USE_KB:
 
1265
            emit( "__builtin_prefetch( prefetchACols+%d+KB*%d,"
 
1266
                  "PF_READONLY, PF_DEF );\n",
 
1267
                  prefetchACols, pfColumn*element.size );
 
1268
            break;
 
1269
 
 
1270
            case PARAMETER:
 
1271
            emit( "__builtin_prefetch( prefetchACols+%d+lda_bytes*%d,"
 
1272
                  "PF_READONLY, PF_DEF );\n",
 
1273
                  prefetchACols, pfColumn );
 
1274
            break;
 
1275
 
 
1276
            default:
 
1277
            emit( "__builtin_prefetch( prefetchACols+lda_bytes*%d+%d,"
 
1278
                  "PF_READONLY, PF_DEF );\n",
 
1279
                  pfColumn, prefetchACols );
 
1280
         }
 
1281
 
 
1282
      pfColumn++;
 
1283
         if( pfColumn == unroll.a )
 
1284
         {
 
1285
            prefetchACols += CACHE_LINE_SIZE;
 
1286
            pfColumn = 0;
 
1287
         }
 
1288
      }
 
1289
   }
 
1290
}
 
1291
 
 
1292
void printPartiallyUnrolledK
 
1293
(
 
1294
   BOOL prefetchA,  /* Allow prefetches of A? */
 
1295
   BOOL prefetchB   /* Allor prefetches of B? */
 
1296
)
 
1297
/*
 
1298
 * Print a partially unrolled K loop, only used when
 
1299
 * K != KB, otherwise the fully unrolled k loop function
 
1300
 * is called.
 
1301
 *
 
1302
 * The prefetch parameters allow prefetching on this iteration,
 
1303
 * but do not require it.  This allows separate passes to enable
 
1304
 * prefetching on each input (For peeling, etc.).
 
1305
 */
 
1306
{
 
1307
   ATL_INT offset;
 
1308
   ATL_INT prefetchB_counter = 0;
 
1309
 
 
1310
   k_unrolling0();
 
1311
   for( offset=element.vector_stride; offset<unroll.k; offset+=element.vector_stride )
 
1312
        {
 
1313
      k_unrollingFullStep( offset );
 
1314
        }
 
1315
 
 
1316
   emit( "/* k unroll factor: %d */\n", unroll.k );
 
1317
 
 
1318
   emit( "for( k=%d; k<%d; k+=%d)\n", unroll.k, unroll.kb, unroll.k );
 
1319
   emit( "{\n" );
 
1320
   indent( 1 );
 
1321
 
 
1322
   for( offset=0; offset<unroll.k; offset+=element.vector_stride )
 
1323
   {
 
1324
      k_partialUnrolling( offset );
 
1325
   }
 
1326
 
 
1327
        indent(-1);
 
1328
   emit( "}\n" );
 
1329
}
 
1330
 
 
1331
 
 
1332
void printKRolled
 
1333
(
 
1334
   BOOL prefetchA,  /* Allow prefetch of A */
 
1335
   BOOL prefetchB   /* Allow prefetch of B */
 
1336
)
 
1337
/* Print a rolled K Loop */
 
1338
{
 
1339
   char base[100];
 
1340
 
 
1341
   if( prefetchB )
 
1342
   {
 
1343
      ATL_INT k;
 
1344
      for( k=0; k<unroll.b; ++k )
 
1345
      {
 
1346
         emit( "__builtin_prefetch( prefetchB + %d*KB, PF_READONLY, PF_DEF );"
 
1347
               "\n", k );
 
1348
         emit( "prefetchB += CACHE_LINE_SIZE;\n" );
 
1349
      }
 
1350
   }
 
1351
 
 
1352
 
 
1353
 
 
1354
/*
 
1355
 *    If prefetching is allowed on this iteration, and it is
 
1356
 *    globally enabled, prefetch an element from the next block of A
 
1357
 */
 
1358
   if( prefetchA && prefetch.ABlock )
 
1359
   {
 
1360
      emit("/* Prefetch one element from the next block of A */\n");
 
1361
      emit("__builtin_prefetch( prefetchABlock,PF_READONLY,PF_DEF );\n");
 
1362
      emit("prefetchABlock += pfBlockDistance;\n");
 
1363
      emit("\n");
 
1364
   }
 
1365
 
 
1366
   k_unrolling0();  /* Print the initial unrolling */
 
1367
 
 
1368
/*
 
1369
 * Print the rolled K loop, adjusting for the peeled iteration
 
1370
 */
 
1371
        switch( unroll.kb )
 
1372
        {
 
1373
                case PARAMETER:
 
1374
                emit( "for( k=%d; k<K; k+=%d )\n", element.vector_stride, element.vector_stride );
 
1375
                break;
 
1376
 
 
1377
                case USE_KB:
 
1378
           emit( "for( k=%d; k<KB; k+=%d )\n",
 
1379
          element.vector_stride, element.vector_stride );
 
1380
                break;
 
1381
 
 
1382
                default:
 
1383
                emit( "for( k=%d; k<%d; k+=%d )\n",
 
1384
                        element.vector_stride, unroll.kb, element.vector_stride );
 
1385
        }
 
1386
 
 
1387
   emit( "{\n" );
 
1388
   indent(1);
 
1389
   k_partialUnrolling( 0 );
 
1390
   indent(-1);
 
1391
   emit( "}\n" );
 
1392
}
 
1393
 
 
1394
 
 
1395
 
 
1396
void printScalarCompressionSingle()
 
1397
/*
 
1398
 * Print the scalar compression routine for single precision.
 
1399
 * This will take the four floats in the vector unit and combine
 
1400
 * them into one.  If possible, four vector units will be compressed
 
1401
 * into one vector that can be written directly to memeory.
 
1402
 */
 
1403
{
 
1404
   ATL_INT a, b;
 
1405
   emit("/* Single Scalar Compression */\n" );
 
1406
   for( b=0; b<unroll.b; ++b )
 
1407
   {
 
1408
      ATL_INT remaining;
 
1409
      remaining = unroll.a;
 
1410
      a = 0;
 
1411
      for( ; remaining>=4; remaining-=4 )
 
1412
      {
 
1413
/*
 
1414
 *       -> [c0a+c0b, c0c+c0d, c1a+c1b, c1c+c1d]
 
1415
 *       -> [c2a+c2b, c2c+c2d, c3a+c3v, c3c+c3d]
 
1416
 *       -> [c0a+c0b+c0c+c0d, c1a+c1b+c1c+c1d,
 
1417
 *           c2a+c2b+c2c+c2d, c3a+c3b+c3c+c3d ]
 
1418
 */
 
1419
         emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a, b,a+1 );
 
1420
         emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a+2, b, a+2, b, a+3 );
 
1421
         emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a,b, a+2);
 
1422
         emit( "\n" );
 
1423
         a += 4;
 
1424
      }
 
1425
      for( ; remaining > 0; remaining-- )
 
1426
      {
 
1427
         emit( "/* additional remaining step */\n" );
 
1428
         emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a, b, a );
 
1429
         emit( "c%d_%d = _mm_hadd_ps( c%d_%d, c%d_%d );\n", b, a, b, a, b, a );
 
1430
         a += 1;
 
1431
      }
 
1432
   }
 
1433
}
 
1434
 
 
1435
 
 
1436
void printScalarCompression()
 
1437
/*
 
1438
 * Emit the scalar compression algorithm for double precision
 
1439
 */
 
1440
{
 
1441
   ATL_INT a, b;
 
1442
 
 
1443
   emit("/* Combine scalar expansion back to scalar */\n");
 
1444
 
 
1445
/*
 
1446
 * If I is unrolled an even number of times, the horizontal adds
 
1447
 * will have no cleanup case (for doubles)
 
1448
 */
 
1449
   if( unroll.a % element.vector_stride == 0 )
 
1450
   {
 
1451
      for( b=0; b<unroll.b; ++b )
 
1452
      {
 
1453
         for( a=0; a<unroll.a; a+=2 )
 
1454
         {
 
1455
            emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1456
                  b, a, element.cType, b, a, b, a+1 );
 
1457
         }
 
1458
      }
 
1459
 
 
1460
   }
 
1461
 
 
1462
/*
 
1463
 * There are not an even number of vectors in this store
 
1464
 */
 
1465
   else
 
1466
   {
 
1467
      emit( "/* handling uneven case */\n" );
 
1468
      for( b=0; b<unroll.b; ++b )
 
1469
      {
 
1470
         for( a=0; a<unroll.a-1; a+=2 )
 
1471
         {
 
1472
            if( element.type == SINGLE || element.type == COMPLEX_SINGLE  )
 
1473
            {
 
1474
            emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1475
                  b, a, element.cType, b, a, b, a+1 );
 
1476
            emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1477
                  b, a, element.cType, b, a, b, a+1 );
 
1478
            } else {
 
1479
               emit( "/* double */\n" );
 
1480
               emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1481
                     b, a, element.cType, b, a, b, a+1 );
 
1482
            }
 
1483
         }
 
1484
 
 
1485
         if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
 
1486
         {
 
1487
/*
 
1488
 *       Cleanup case, run when there are not an even number of vectors.
 
1489
 */
 
1490
         emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1491
               b, a, element.cType, b, a, b, a );
 
1492
         emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1493
               b, a, element.cType, b, a, b, a );
 
1494
         } else {
 
1495
         emit("c%d_%d = _mm_hadd_p%c( c%d_%d, c%d_%d );\n",
 
1496
               b, a, element.cType, b, a, b, a );
 
1497
         }
 
1498
      }
 
1499
   }
 
1500
}
 
1501
 
 
1502
 
 
1503
void storeResults(
 
1504
   int alignmentOfC  /* Is C aligned, unaligned, or alternating? */
 
1505
)
 
1506
/*
 
1507
 * Store the results of the I loop iteration to memory.
 
1508
 * This will update the C matrix
 
1509
 */
 
1510
{
 
1511
   ATL_INT a, b;
 
1512
   emit("/* Store results back to memory  */\n");
 
1513
        for( b=0; b<unroll.b; ++b )
 
1514
   {
 
1515
      ATL_INT remaining;
 
1516
      remaining = unroll.a;
 
1517
      a = 0;
 
1518
 
 
1519
 
 
1520
 
 
1521
                char* store;
 
1522
                if( alignmentOfC == ALIGNED )
 
1523
                {
 
1524
                        store = element.aStore;
 
1525
                }
 
1526
                else
 
1527
                {
 
1528
                        store = element.uStore;
 
1529
                }
 
1530
 
 
1531
 
 
1532
                if( element.type == SINGLE || element.type == DOUBLE )
 
1533
      {
 
1534
         for( ; remaining>=element.vector_stride;
 
1535
                remaining-=element.vector_stride )
 
1536
         {
 
1537
            if( a > 0 )
 
1538
            {
 
1539
               if( useVoidPointersForC )
 
1540
               {
 
1541
                  emit("%s( MMCAST( cPtrI%d+%d ),  MMCASTStoreintrin( c%d_%d ) );\n",
 
1542
                        store, b, a*element.size,b,a );
 
1543
               }
 
1544
               else
 
1545
                  emit("%s( MMCAST( cPtrI%d+%d ),  MMCASTStoreintrin( c%d_%d ) );\n",
 
1546
                        store, b, a, b, a );
 
1547
            } else {
 
1548
               emit("%s( MMCAST( cPtrI%d ),  MMCASTStoreintrin( c%d_%d ) );\n",
 
1549
                     store, b, b, a );
 
1550
            }
 
1551
            a += element.vector_stride;
 
1552
          }
 
1553
 
 
1554
         for( ; remaining > 0; remaining-- )
 
1555
         {
 
1556
            if( a > 0 )
 
1557
            {
 
1558
               if( useVoidPointersForC )
 
1559
                  emit("%s( cPtrI%d+%d,  c%d_%d );\n",
 
1560
                     element.sStore, b, a*element.size,b,a );
 
1561
               else
 
1562
                  emit("%s( cPtrI%d+%d,  c%d_%d );\n",
 
1563
                     element.sStore, b, a, b, a );
 
1564
            } else {
 
1565
               emit("%s( cPtrI%d,  c%d_%d );\n",
 
1566
                    element.sStore, b, b, a );
 
1567
            }
 
1568
            a += 1;
 
1569
         }
 
1570
      }
 
1571
      else
 
1572
      {
 
1573
                        ATL_INT x = 0;
 
1574
         ATL_INT IOffset=0;
 
1575
         ATL_INT COffset=0;
 
1576
 
 
1577
         for( ; remaining>=element.vector_stride;
 
1578
                remaining -= element.vector_stride )
 
1579
         {
 
1580
            if( element.type == COMPLEX_SINGLE )
 
1581
            {
 
1582
               if( useVoidPointersForC )
 
1583
               {
 
1584
                  /* @TODO: Fix this copy */
 
1585
                  emit( "temp = c%d_0;\n", b );
 
1586
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x) );
 
1587
 
 
1588
                  emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
 
1589
                        "_MM_SHUFFLE(1,1,1,1));\n", b, a, b, a );
 
1590
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+8) );
 
1591
 
 
1592
                  emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
 
1593
                        "_MM_SHUFFLE(2,2,2,2));\n", b, a, b, a );
 
1594
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+16) );
 
1595
 
 
1596
                  emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
 
1597
                        "_MM_SHUFFLE(3,3,3,3));\n", b, a, b, a );
 
1598
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+24) );
 
1599
               } else {
 
1600
                  emit( "temp = c%d_%d;\n", b, a );
 
1601
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x) );
 
1602
 
 
1603
                  emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
 
1604
                        "_MM_SHUFFLE(1,1,1,1));\n", b, a, b, a );
 
1605
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+2) );
 
1606
 
 
1607
                  emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
 
1608
                        "_MM_SHUFFLE(2,2,2,2));\n", b, a, b, a );
 
1609
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+4) );
 
1610
 
 
1611
                  emit( "temp = _mm_shuffle_ps( c%d_%d, c%d_%d,"
 
1612
                        "_MM_SHUFFLE(3,3,3,3));\n", b, a, b, a );
 
1613
                  emit( "_mm_store_ss( cPtrI%d+%d, temp );\n", b, (x+6) );
 
1614
 
 
1615
                  a += 4;
 
1616
                  x += 8;
 
1617
               }
 
1618
            } else if( element.type == COMPLEX_DOUBLE ) {
 
1619
 
 
1620
               if( useVoidPointersForC )
 
1621
               {
 
1622
                  emit( "_mm_store_sd( cPtrI%d+%d, c%d_%d );\n",
 
1623
                        b, IOffset, b, COffset );
 
1624
                  emit( "temp = _mm_shuffle_pd( c%d_%d, c%d_%d,"
 
1625
                        "_MM_SHUFFLE2(1,1));\n", b, COffset, b, COffset );
 
1626
 
 
1627
                  IOffset+= 8;
 
1628
                  emit( "_mm_store_sd( cPtrI%d+%d, temp );\n\n",
 
1629
                        b, IOffset );
 
1630
 
 
1631
                  IOffset+= 8;
 
1632
               } else {
 
1633
                  emit( "_mm_store_sd( cPtrI%d+%d, c%d_%d );\n",
 
1634
                        b, IOffset, b, COffset );
 
1635
                  emit( "temp = _mm_shuffle_pd( c%d_%d, c%d_%d,"
 
1636
                        "_MM_SHUFFLE2(1,1));\n", b, COffset, b, COffset );
 
1637
 
 
1638
                  IOffset+= 2;
 
1639
                  emit( "_mm_store_sd( cPtrI%d+%d, temp );\n\n",
 
1640
                        b, IOffset );
 
1641
                  IOffset += 2;
 
1642
                  COffset += 2;
 
1643
               }
 
1644
            }
 
1645
         }
 
1646
 
 
1647
         for( ; remaining >0; --remaining )
 
1648
         {
 
1649
            if( useVoidPointersForC )
 
1650
               emit( "_mm_store_s%c( cPtrI%d+%d, c%d_%d );\n",
 
1651
                     element.cType, b, 2*(a+x)*element.size, b,  a+x );
 
1652
            else
 
1653
               emit( " _mm_store_s%c( cPtrI%d+%d, c%d_%d );\n",
 
1654
                     element.cType,
 
1655
                     b,
 
1656
                     2*(unroll.a-remaining),
 
1657
                     b,
 
1658
                     (unroll.a-remaining) );
 
1659
            a+=1;
 
1660
         }
 
1661
 
 
1662
      }
 
1663
   }
 
1664
}
 
1665
 
 
1666
 
 
1667
void applyBeta()
 
1668
/*
 
1669
 * Apply the beta factor, if one is needed.
 
1670
 */
 
1671
{
 
1672
   ATL_INT a, b;
 
1673
 
 
1674
   emit( "/* Applying Beta */\n" );
 
1675
 
 
1676
   if( options.beta == BETA0 )
 
1677
   {
 
1678
      emit("/* No beta will be appied */\n" );
 
1679
   }
 
1680
   else
 
1681
   {
 
1682
      indent( 1 );
 
1683
      emit( "/* Apply Beta Factor */\n" );
 
1684
      for( b=0; b<unroll.b; ++b )
 
1685
      {
 
1686
         ATL_INT remaining = unroll.a;
 
1687
 
 
1688
         if( element.type == COMPLEX_SINGLE || element.type == COMPLEX_DOUBLE )
 
1689
         {
 
1690
 
 
1691
            emit( "/* Load C from memory */\n" );
 
1692
            for( a=0;
 
1693
               remaining>=element.vector_stride;
 
1694
               remaining-=element.vector_stride )
 
1695
            {
 
1696
               ATL_INT x=0;
 
1697
               for( ; x<element.vector_stride; ++x )
 
1698
               {
 
1699
                  if( useVoidPointersForC )
 
1700
                  {
 
1701
                     emit("temp%d = %s( cPtrI%d + %d );\n", x,
 
1702
                          element.sLoad, b, 2*(element.size*(x+a)) );
 
1703
                  } else {
 
1704
                     emit("temp%d = %s( cPtrI%d + %d );\n", x,
 
1705
                          element.sLoad, b, 2*(x+a) );
 
1706
                  }
 
1707
 
 
1708
               }
 
1709
 
 
1710
               if( element.type == COMPLEX_SINGLE )
 
1711
               {
 
1712
/*                temp0 = temp0[0], temp1[0], temp0[1], temp1[1] */
 
1713
                  emit( "temp0 = _mm_unpacklo_ps( temp0, temp1 );\n" );
 
1714
 
 
1715
/*                temp2 = temp2[0], temp3[0], temp2[1], temp3[1] */
 
1716
                  emit( "temp2 = _mm_unpacklo_ps( temp2, temp3 );\n" );
 
1717
 
 
1718
 
 
1719
                  emit( "bc%d_%d = _mm_movelh_ps( temp0, temp2 );\n",
 
1720
                         b, a );
 
1721
 
 
1722
/*                b?_? = temp0[0], temp1[0], temp2[0], temp2[0] */
 
1723
               }
 
1724
 
 
1725
               if( element.type == COMPLEX_DOUBLE )
 
1726
               {
 
1727
                  emit( "bc%d_%d = _mm_shuffle_pd( temp0, temp1,"
 
1728
                        "_MM_SHUFFLE2(0, 0 )  );\n", b, a );
 
1729
               }
 
1730
 
 
1731
 
 
1732
               if( options.beta != BETA1 )
 
1733
               {
 
1734
                  if( intrinSrcDest )
 
1735
                  {
 
1736
                  emit("bc%d_%d = _mm_mul_p%c( betaV, bc%d_%d );\n",
 
1737
                     b,a, element.cType, b,a );
 
1738
                  } else {
 
1739
                  emit("bc%d_%d = _mm_mul_pd%c bc%d_%d, betaV );\n",
 
1740
                     b,a, element.cType, b,a );
 
1741
                  }
 
1742
               }
 
1743
               a+= element.vector_stride;
 
1744
            }
 
1745
            for( ; remaining > 0; --remaining )
 
1746
            {
 
1747
               emit( "/* %d remaining */\n", remaining );
 
1748
               if( useVoidPointersForC )
 
1749
               {
 
1750
                  emit("bc%d_%d = %s( cPtrI%d+%d );\n",
 
1751
                     b, a, element.sLoad, b, 2*a*element.size );
 
1752
               } else {
 
1753
                  emit("bc%d_%d = %s( cPtrI%d+%d );\n",
 
1754
                     b, a, element.sLoad, b, 2*a );
 
1755
               }
 
1756
 
 
1757
 
 
1758
               if( options.beta != BETA1 )
 
1759
               {
 
1760
                  if( intrinSrcDest )
 
1761
                  {
 
1762
                  emit("bc%d_%d = _mm_mul_s%c( betaV, bc%d_%d );\n",
 
1763
                        b,a, element.cType, b,a );
 
1764
                  } else {
 
1765
                  emit("bc%d_%d = _mm_mul_s%c( bc%d_%d, betaV );\n",
 
1766
                        b,a, element.cType, b,a );
 
1767
                  }
 
1768
               }
 
1769
               a++;
 
1770
            }
 
1771
 
 
1772
         } else {
 
1773
            emit( "/* Load C from memory */\n" );
 
1774
 
 
1775
            for( a=0;
 
1776
               remaining>=element.vector_stride;
 
1777
               remaining-=element.vector_stride )
 
1778
            {
 
1779
               if( useVoidPointersForC )
 
1780
                  emit("bc%d_%d = _mm_loadu_p%c( cPtrI%d+%d );\n", b, a,
 
1781
                     element.cType, b, a*element.size );
 
1782
               else
 
1783
                  emit("bc%d_%d = _mm_loadu_p%c( cPtrI%d+%d );\n", b, a,
 
1784
                     element.cType, b, a );
 
1785
 
 
1786
                  if( options.beta != BETA1 )
 
1787
                  {
 
1788
                     if( intrinSrcDest )
 
1789
                     {
 
1790
                     emit("bc%d_%d = _mm_mul_p%c( betaV, bc%d_%d );\n",
 
1791
                        b,a, element.cType, b,a );
 
1792
                     } else {
 
1793
                     emit("bc%d_%d = _mm_mul_p%c( bc%d_%d, betaV );\n",
 
1794
                        b,a, element.cType, b,a );
 
1795
                     }
 
1796
                  }
 
1797
               a+= element.vector_stride;
 
1798
            }
 
1799
            for( ; remaining > 0; --remaining )
 
1800
            {
 
1801
               emit( "/* %d remaining */\n", remaining );
 
1802
               if( useVoidPointersForC )
 
1803
               {
 
1804
                  emit("bc%d_%d = %s( cPtrI%d+%d );\n",
 
1805
                        b, a, element.sLoad, b, a*element.size );
 
1806
               } else {
 
1807
                  emit("bc%d_%d = %s( cPtrI%d+%d );\n",
 
1808
                        b, a, element.sLoad, b, a );
 
1809
               }
 
1810
 
 
1811
               if( options.beta != BETA1 )
 
1812
               {
 
1813
                  if( intrinSrcDest )
 
1814
                  {
 
1815
                  emit("bc%d_%d = _mm_mul_s%c( betaV, bc%d_%d );\n",
 
1816
                        b,a, element.cType, b,a );
 
1817
                  } else {
 
1818
                  emit("bc%d_%d = _mm_mul_s%c( bc%d_%d, betaV );\n",
 
1819
                        b,a, element.cType, b,a );
 
1820
                  }
 
1821
               }
 
1822
               a++;
 
1823
            }
 
1824
         }
 
1825
 
 
1826
      }
 
1827
      emit( "/* C = (beta*C) + (matrix multiply) */\n" );
 
1828
      for( b=0; b<unroll.b; ++b )
 
1829
      {
 
1830
         ATL_INT remaining = unroll.a;
 
1831
         a = 0;
 
1832
         for( ; remaining >= element.vector_stride;
 
1833
                remaining -= element.vector_stride )
 
1834
         {
 
1835
            if( intrinSrcDest )
 
1836
            {
 
1837
            emit("c%d_%d = _mm_add_p%c( bc%d_%d, c%d_%d );\n",
 
1838
                  b, a, element.cType, b, a, b, a );
 
1839
            } else {
 
1840
            emit("c%d_%d = _mm_add_p%c( c%d_%d, bc%d_%d );\n",
 
1841
                  b, a, element.cType, b, a, b, a );
 
1842
            }
 
1843
            a += element.vector_stride;
 
1844
         }
 
1845
 
 
1846
         for( ; remaining > 0; --remaining )
 
1847
         {
 
1848
            if( intrinSrcDest )
 
1849
            {
 
1850
            emit( "c%d_%d = _mm_add_s%c( bc%d_%d, c%d_%d );\n",
 
1851
                  b, a, element.cType, b, a, b, a );
 
1852
            } else {
 
1853
            emit( "c%d_%d = _mm_add_s%c( c%d_%d, bc%d_%d );\n",
 
1854
                  b, a, element.cType, b, a, b, a );
 
1855
            }
 
1856
            ++a;
 
1857
         }
 
1858
      }
 
1859
 
 
1860
      indent(-1 );
 
1861
   }
 
1862
}
 
1863
 
 
1864
 
 
1865
void printMainLoops
 
1866
(
 
1867
   int alignmentOfC,  /* How is C aligned? */
 
1868
   char* name      /* What is the name of this alignment */
 
1869
)
 
1870
/*
 
1871
 * Print the I,J,K loops, accounting for a specific form of alignment,
 
1872
 * aligned, unaligned, alternating
 
1873
 */
 
1874
{
 
1875
   ATL_INT a, b, k;
 
1876
   ATL_INT pf;
 
1877
 
 
1878
/*
 
1879
 * We are fetching (MB*KB*eltsize) bytes (one block of A).
 
1880
 * If we prefetch outside the KB loop, we will have (MB/mu)*(NB/nu)
 
1881
 * prefetch opportunities.
 
1882
 *
 
1883
 * Therefore, we must fetch (MB*KB*eltsize)/ [(MB/mu)*(NB/nu)]
 
1884
 * which is equal to (mu*nu*KB*eltsize)/NB.
 
1885
 *
 
1886
 * We still need to take care of the m-loop peeled case,
 
1887
 * when there is one less iteration of the m loop.
 
1888
 */
 
1889
 
 
1890
   if( prefetch.ABlock )
 
1891
   {
 
1892
      emit("const ATL_INT pfBlockDistance = (%d * %d * KB * %d) / %s;\n",
 
1893
            unroll.a, unroll.b, element.size, nb );
 
1894
   }
 
1895
 
 
1896
 
 
1897
        emit("/* =======================================\n" );
 
1898
   emit(" * Begin generated inner loops for case %s\n", name );
 
1899
   emit(" * ======================================= */\n" );
 
1900
 
 
1901
   switch( unroll.nb )
 
1902
   {
 
1903
      case PARAMETER:
 
1904
      emit("for( j=-NB; j!=0; j+=J_UNROLL) \n" );
 
1905
      break;
 
1906
      case USE_KB:
 
1907
      emit("for( j=-KB; j!=0; j+=J_UNROLL) \n" );
 
1908
      break;
 
1909
      default:
 
1910
      emit("for( j=-%d; j!=0; j+=J_UNROLL) \n", unroll.nb );
 
1911
      break;
 
1912
   }
 
1913
 
 
1914
   emit("{\n");
 
1915
   indent( 1 );
 
1916
 
 
1917
   if( useVoidPointersForA )
 
1918
      emit("register void const *A0_off = (void*)A; \n");
 
1919
   else
 
1920
      emit("register TYPE const *A0_off = A; \n");
 
1921
 
 
1922
 
 
1923
   if( options.lda == PARAMETER && unroll.a > 2 )
 
1924
   {
 
1925
      emit("register void const *A3_off = A0_off + lda_bytes3;\n");
 
1926
      if( unroll.a > 4 )
 
1927
      {
 
1928
         emit( "register void const *A5_off = A3_off + lda_bytes*2;\n" );
 
1929
      }
 
1930
   }
 
1931
   emit("\n");
 
1932
 
 
1933
   if( useVoidPointersForC )
 
1934
      emit( "register void *cPtrI0 = (void*)cPtr;\n" );
 
1935
   else
 
1936
      emit( "register TYPE *cPtrI0 = cPtr;\n" );
 
1937
 
 
1938
 
 
1939
   for( b=1; b<unroll.b; ++b )
 
1940
   {
 
1941
      emit("register TYPE *cPtrI%d = cPtrI%d + ldc_bytes;\n", b, b-1);
 
1942
   }
 
1943
 
 
1944
   emit("\n\n");
 
1945
 
 
1946
 
 
1947
   if( prefetch.fetchC == TRUE )
 
1948
   {
 
1949
      for( b=0; b<unroll.b; ++b )
 
1950
      {
 
1951
         emit("__builtin_prefetch( cPtrI%d, PF_READONLY, PF_DEF );\n", b );
 
1952
      }
 
1953
   }
 
1954
 
 
1955
 
 
1956
   char* deltaStr = "";
 
1957
/*
 
1958
 * Peel the last iteration of the inner loop if prefetch should run on B
 
1959
 */
 
1960
   if( prefetch.BCols == TRUE )
 
1961
   {
 
1962
      deltaStr = "+I_UNROLL";
 
1963
   }
 
1964
 
 
1965
 
 
1966
 
 
1967
   emit("for( i=-%s%s; i != 0; i+= I_UNROLL )\n", mb, deltaStr );
 
1968
 
 
1969
   emit("{\n");
 
1970
   indent( 1 );
 
1971
                printILoop( alignmentOfC, TRUE, FALSE );
 
1972
   indent( -1 );
 
1973
   emit("} /* End i/MB loop */\n\n");
 
1974
 
 
1975
   if( prefetch.BCols == TRUE )
 
1976
   {
 
1977
      printILoop( alignmentOfC, FALSE, TRUE );
 
1978
   }
 
1979
 
 
1980
   switch( options.ldb )
 
1981
   {
 
1982
      case PARAMETER:
 
1983
         if( useVoidPointersForB )
 
1984
            emit( "B0_off += J_UNROLL*ldb_bytes;\n");
 
1985
         else
 
1986
            emit( "B0_off += J_UNROLL*ldb_bytes;\n");
 
1987
      break;
 
1988
 
 
1989
      case USE_KB:
 
1990
         if( useVoidPointersForB )
 
1991
            emit( "B0_off += J_UNROLL*KB%d;\n", element.size );
 
1992
         else
 
1993
            emit( "B0_off += J_UNROLL*KB;\n" );
 
1994
      break;
 
1995
 
 
1996
      default:
 
1997
         if( useVoidPointersForB )
 
1998
            emit( "B0_off += J_UNROLL*%d*sizeof(TYPE);\n", options.ldb );
 
1999
         else
 
2000
            emit( "B0_off += J_UNROLL*%d;\n", options.ldb );
 
2001
   }
 
2002
   emit( "cPtr += J_UNROLL*ldc_bytes;\n" );
 
2003
 
 
2004
   indent( -1 );
 
2005
   emit("} /* End j/NB loop */\n");
 
2006
   emit("/* End of generated inner loops */\n");
 
2007
}
 
2008
 
 
2009
 
 
2010
/*
 
2011
 * Print one iteration of the middle loop, including beta adjustments,
 
2012
 * summing of the inner loop, and iteration along the matricies.
 
2013
 */
 
2014
void printILoop(
 
2015
  int alignmentOfC,   /* How is the data aligned? */
 
2016
  BOOL prefetchA,  /* Allow prefetch of A during this iteration? */
 
2017
  BOOL prefetchB   /* Allow prefetch of B during this iteration? */
 
2018
)
 
2019
{
 
2020
   ATL_INT pf;
 
2021
   ATL_INT offset;
 
2022
   ATL_INT a=0;
 
2023
   ATL_INT b;
 
2024
 
 
2025
 
 
2026
   if( prefetch.prefetchC )
 
2027
   {
 
2028
                if( element.type == SINGLE || element.type == DOUBLE )
 
2029
      {
 
2030
         for( b=0; b<unroll.b; ++b )
 
2031
         {
 
2032
            for( offset=0; offset<unroll.a; offset+=2 )
 
2033
            {
 
2034
               if( offset > 0 )
 
2035
               {
 
2036
                  emit("__builtin_prefetch( cPtrI%d+%d, PF_READONLY, PF_DEF );\n",
 
2037
                        b, offset*element.size );
 
2038
               } else {
 
2039
                  emit("__builtin_prefetch( cPtrI%d, PF_READONLY, PF_DEF );\n",
 
2040
                        b );
 
2041
               }
 
2042
            }
 
2043
         }
 
2044
      }
 
2045
   }
 
2046
 
 
2047
 
 
2048
 
 
2049
        printAllKUnrollings( prefetchA, prefetchB );
 
2050
 
 
2051
 
 
2052
 
 
2053
/*
 
2054
 * Scalar compression of singles and doubles behaves differently
 
2055
 */
 
2056
   if( element.type == SINGLE || element.type == COMPLEX_SINGLE )
 
2057
   {
 
2058
      printScalarCompressionSingle();
 
2059
   }
 
2060
   else
 
2061
   {
 
2062
      printScalarCompression();
 
2063
   }
 
2064
 
 
2065
 
 
2066
 
 
2067
/*
 
2068
 * Apply the beta scaling factor
 
2069
 */
 
2070
   applyBeta();
 
2071
 
 
2072
 
 
2073
 
 
2074
/*
 
2075
 * Move to the next iteration
 
2076
 */
 
2077
   emit("/* Move pointers to next iteration */  \n");
 
2078
   emit("A0_off += unroll_a;\n");
 
2079
 
 
2080
   if( options.lda == PARAMETER && unroll.a > 2 )
 
2081
   {
 
2082
      emit("A3_off += unroll_a;\n");
 
2083
      if( unroll.a > 5 )
 
2084
         emit( "A5_off += unroll_a;\n" );
 
2085
   }
 
2086
   emit("\n");
 
2087
 
 
2088
 
 
2089
/*
 
2090
 * Store the results of the computation back to memory
 
2091
 */
 
2092
   storeResults( alignmentOfC );
 
2093
 
 
2094
 
 
2095
 
 
2096
/*
 
2097
 * Increment Pointers
 
2098
 */
 
2099
   for( b=0; b<unroll.b; ++b )
 
2100
   {
 
2101
      if( element.type == COMPLEX_SINGLE || element.type == COMPLEX_DOUBLE )
 
2102
      {
 
2103
         if( useVoidPointersForC )
 
2104
            emit("cPtrI%d += %d*2*I_UNROLL;\n", b, element.size );
 
2105
         else
 
2106
            emit("cPtrI%d += 2*I_UNROLL;\n", b );
 
2107
      } else {
 
2108
         if( useVoidPointersForC )
 
2109
            emit("cPtrI%d += %d*I_UNROLL;\n", b, element.size );
 
2110
         else
 
2111
            emit("cPtrI%d += I_UNROLL;\n", b );
 
2112
      }
 
2113
   }
 
2114
 
 
2115
   emit("\n\n");
 
2116
 
 
2117
   if( prefetchA )
 
2118
   {
 
2119
      if( prefetch.ACols )
 
2120
      {
 
2121
         if( options.lda == PARAMETER )
 
2122
            emit( "prefetchACols += %d*lda;\n", unroll.a );
 
2123
         else if( options.lda == USE_KB )
 
2124
            emit( "prefetchACols += %d*KB;\n", unroll.a );
 
2125
         else
 
2126
            emit( "prefetchACols += %d;\n", unroll.a * options.lda * element.size );
 
2127
      }
 
2128
   }
 
2129
}
 
2130
 
 
2131
 
 
2132
void emit( const char *fmt, ...)
 
2133
/*
 
2134
 * Write a string to the output file.  It will be indented
 
2135
 * by the current indent value, vall indent(int) to adjust
 
2136
 * the indent factor.   indent(1) will do a standard scoping
 
2137
 * indent, and indent(-1) will remove that indent.
 
2138
 */
 
2139
{
 
2140
   assert( tabwidth >= 0 );
 
2141
 
 
2142
   va_list arg;
 
2143
   va_start(arg, fmt);
 
2144
   int t;
 
2145
   for( t=0; t<tabwidth; ++t )
 
2146
   {
 
2147
      fprintf( options.outputLocation, "   ");
 
2148
   }
 
2149
   vfprintf( options.outputLocation, fmt, arg);
 
2150
   va_end(arg);
 
2151
}
 
2152
 
 
2153
 
 
2154
void emitCat( const char *fmt, ...)
 
2155
/*
 
2156
 * Write a string to the output file.  Do not do indenting.
 
2157
 * This is used when a line of code is emitted in parts.
 
2158
 */
 
2159
{
 
2160
   va_list arg;
 
2161
   va_start(arg, fmt);
 
2162
   vfprintf( options.outputLocation, fmt, arg);
 
2163
   va_end(arg);
 
2164
}
 
2165
 
 
2166
 
 
2167
 
 
2168
void indent( int delta )
 
2169
/*
 
2170
 * Adjust the current indent for an emit call. See emit().
 
2171
 */
 
2172
{
 
2173
   tabwidth += delta;
 
2174
   assert( tabwidth >= 0 );
 
2175
}
 
2176
 
 
2177
 
 
2178
static void loadDefaults()
 
2179
/*
 
2180
 * Load all default settings
 
2181
 */
 
2182
{
 
2183
   prefetch.ACols = FALSE;
 
2184
   prefetch.ABlock = FALSE;
 
2185
   prefetch.BCols = FALSE;
 
2186
   prefetch.fetchC = FALSE;
 
2187
   prefetch.prefetchC = FALSE;
 
2188
 
 
2189
   unroll.nb = 60;
 
2190
   unroll.mb = 60;
 
2191
   unroll.kb = 60;
 
2192
 
 
2193
        unroll.a = 1;
 
2194
        unroll.b = 1;
 
2195
        unroll.k = unroll.kb;
 
2196
 
 
2197
        element.vector_length_bytes = 16;
 
2198
        element.type = DOUBLE;
 
2199
 
 
2200
 
 
2201
        options.cAlignment = 1;
 
2202
   options.ABAligned = TRUE;
 
2203
        options.lda = PARAMETER;
 
2204
        options.ldb = PARAMETER;
 
2205
   options.ldc = PARAMETER;
 
2206
        options.verifyUnrollings = FALSE;
 
2207
        options.treatLoadsAsFloat = FALSE;
 
2208
        options.treatStoresAsFloat = FALSE;
 
2209
   options.constantFolding = TRUE;
 
2210
   options.beta = BETA0;
 
2211
   options.outputLocation = stdout;
 
2212
}
 
2213
 
 
2214
 
 
2215
void loadOptions
 
2216
(
 
2217
   int argc, /* Number of command line elements */
 
2218
   char **argv  /* Command line elements */
 
2219
)
 
2220
/*
 
2221
 * Load the options from the command line
 
2222
 */
 
2223
{
 
2224
        char *s;
 
2225
   int tmp;
 
2226
 
 
2227
/* Determine if the user requested the help message */
 
2228
        requestHelp( argc, argv );
 
2229
 
 
2230
/* Load all default settings */
 
2231
        loadDefaults();
 
2232
 
 
2233
/* Load unrolling and blocking factors */
 
2234
        loadInt( "-m", &unroll.a, argc, argv  );
 
2235
        loadInt( "-n", &unroll.b, argc, argv  );
 
2236
        loadInt( "-k", &unroll.k, argc, argv  );
 
2237
 
 
2238
   loadInt( "-M", &unroll.mb, argc, argv );
 
2239
   loadInt( "-N", &unroll.nb, argc, argv );
 
2240
   loadInt( "-K", &unroll.kb, argc, argv );
 
2241
 
 
2242
 
 
2243
        if( unroll.k <= 0 )
 
2244
                unroll.k = unroll.kb;
 
2245
 
 
2246
/* Load the beta factor */
 
2247
   loadInt( "-beta", &options.beta, argc, argv );
 
2248
   if( options.beta > 1 || options.beta < -1 )
 
2249
      options.beta = BETAX;
 
2250
 
 
2251
 
 
2252
/* Determine if alignment checks are requested */
 
2253
        loadInt( "-CAlignment", &options.cAlignment, argc, argv );
 
2254
   loadBool( "-ABAligned", &options.ABAligned, argc, argv );
 
2255
 
 
2256
/* Load the element type: float or double, complex or real */
 
2257
   s = loadString( "-p", argc, argv );
 
2258
        convertElementType( s[0] );
 
2259
 
 
2260
/* Where should the file be written? */
 
2261
   s = loadString( "-f", argc, argv );
 
2262
   setOutputLocation( s );
 
2263
 
 
2264
 
 
2265
        loadBool( "-verifyUnrollings", &options.verifyUnrollings, argc, argv );
 
2266
        loadInt( "-lda", &options.lda, argc, argv );
 
2267
        loadInt( "-ldb", &options.ldb, argc, argv );
 
2268
        loadInt( "-ldc", &options.ldc, argc, argv );
 
2269
        loadBool( "-treatLoadsAsFloat", &options.treatLoadsAsFloat, argc, argv );
 
2270
        loadBool( "-treatStoresAsFloat", &options.treatStoresAsFloat, argc, argv );
 
2271
 
 
2272
   if( options.lda == unroll.kb )
 
2273
      options.lda = USE_KB;
 
2274
 
 
2275
   if( options.ldb == unroll.kb )
 
2276
      options.ldb = USE_KB;
 
2277
 
 
2278
 
 
2279
   loadBool( "-prefetchACols", &prefetch.ACols, argc, argv );
 
2280
   loadBool( "-prefetchABlock", &prefetch.ABlock, argc, argv );
 
2281
   loadBool( "-prefetchBCols", &prefetch.BCols, argc, argv );
 
2282
   loadBool( "-FF", &prefetch.fetchC, argc, argv );
 
2283
 
 
2284
   loadBool( "-prefetchCelts", &prefetch.prefetchC, argc, argv );
 
2285
   loadBool( "-constantFolding", &options.constantFolding, argc, argv );
 
2286
 
 
2287
 
 
2288
 
 
2289
        switch( element.type )
 
2290
        {
 
2291
      case COMPLEX_DOUBLE:
 
2292
                case DOUBLE:
 
2293
                element.size = sizeof( double );
 
2294
                element.shift = 3;
 
2295
      element.vector_stride = 2;
 
2296
 
 
2297
      strcpy( element.intrinsic,  "__m128d" );
 
2298
 
 
2299
 
 
2300
/*    Use the shorter instruction? */
 
2301
      if( options.treatLoadsAsFloat )
 
2302
      {
 
2303
         if( options.ABAligned )
 
2304
         {
 
2305
            strcpy( element.load_ab, "(__m128d)_mm_load_ps" );
 
2306
         } else {
 
2307
            strcpy( element.load_ab, "(__m128d)_mm_loadu_ps" );
 
2308
         }
 
2309
                        strcpy( element.sLoad, "(__m128d)_mm_load_ss" );
 
2310
      } else {
 
2311
         if( options.ABAligned )
 
2312
         {
 
2313
            strcpy( element.load_ab, "_mm_load_pd" );
 
2314
         } else {
 
2315
            strcpy( element.load_ab, "_mm_loadu_pd" );
 
2316
         }
 
2317
                        strcpy( element.sLoad, "_mm_load_sd" );
 
2318
      }
 
2319
 
 
2320
/*    Use the shorter instruction? */
 
2321
      if( options.treatStoresAsFloat )
 
2322
      {
 
2323
         strcpy( element.aStore, "_mm_store_ps" );
 
2324
         strcpy( element.uStore, "_mm_storeu_ps" );
 
2325
                        strcpy( element.sStore, "_mm_store_ss" );
 
2326
      } else {
 
2327
        strcpy( element.aStore, "_mm_store_pd" );
 
2328
        strcpy( element.uStore, "_mm_storeu_pd" );
 
2329
                  strcpy( element.sStore, "_mm_store_sd" );
 
2330
      }
 
2331
 
 
2332
      strcpy( element.type_name, "double" );
 
2333
                break;
 
2334
 
 
2335
      case COMPLEX_SINGLE:
 
2336
      case SINGLE:
 
2337
                element.size = sizeof( float );
 
2338
                element.shift = 2;
 
2339
      element.vector_stride = 4;
 
2340
 
 
2341
 
 
2342
      strcpy( element.intrinsic, "__m128" );
 
2343
      strcpy( element.aStore, "_mm_store_ps" );
 
2344
      strcpy( element.uStore, "_mm_storeu_ps" );
 
2345
                strcpy( element.sStore, "_mm_store_ss" );
 
2346
                strcpy( element.sLoad, "_mm_load_ss" );
 
2347
 
 
2348
 
 
2349
      if( options.ABAligned )
 
2350
      {
 
2351
         strcpy( element.load_ab, "_mm_load_ps" );
 
2352
      } else {
 
2353
         strcpy( element.load_ab, "_mm_loadu_ps" );
 
2354
      }
 
2355
 
 
2356
      strcpy( element.type_name, "float" );
 
2357
                break;
 
2358
 
 
2359
                default:
 
2360
                assert( 0 );
 
2361
        }
 
2362
 
 
2363
 
 
2364
 
 
2365
 
 
2366
   if( numArgsProcessed != (argc-1)/2 )
 
2367
   {
 
2368
      int i;
 
2369
      fprintf( stderr, "Commandline contained unknown arguments\n" );
 
2370
      printf( "There were %d args\n", argc );
 
2371
      for( i=0; i<argc; ++i )
 
2372
      {
 
2373
         printf( "  %s\n", argv[i] );
 
2374
      }
 
2375
      for( i=0; i<argc; ++i )
 
2376
      {
 
2377
         printf( "  %s\n", argv[i] );
 
2378
      }
 
2379
 
 
2380
        printHelp();
 
2381
        }
 
2382
 
 
2383
 
 
2384
 
 
2385
        assert( unroll.kb == 0 || unroll.kb % element.vector_stride == 0 );
 
2386
}
 
2387
 
 
2388
 
 
2389
 
 
2390
char* loadString
 
2391
(
 
2392
   char* tag,  /* The name of the parameter to load */
 
2393
   int argc,   /* Number of command line arguments */
 
2394
   char** argv /* Command line arguments */
 
2395
)
 
2396
/*
 
2397
 * Load a string from the command line and return it.
 
2398
 * RETURNS: The string loaded from the command line,
 
2399
 *          0 if the parameter was not present.
 
2400
 */
 
2401
{
 
2402
        int i;
 
2403
 
 
2404
        for( i=0; i<argc; ++i )
 
2405
        {
 
2406
                if( strcmp( tag, argv[i] ) == 0 )
 
2407
                {
 
2408
                        assert( i+1 < argc );
 
2409
         numArgsProcessed++;
 
2410
                        return argv[i+1];
 
2411
                }
 
2412
        }
 
2413
        return 0;
 
2414
}
 
2415
 
 
2416
 
 
2417
int requestHelp
 
2418
(
 
2419
   int argc,   /* Number of command line parameters */
 
2420
   char** argv /* The command line parameters */
 
2421
)
 
2422
/*
 
2423
 * Determine if the user requested the help message
 
2424
 * by passing a special switch to the program.
 
2425
 */
 
2426
{
 
2427
        int i;
 
2428
        int flag;
 
2429
        const char* tags[] = { "-?", "-h", "--help" };
 
2430
 
 
2431
        for( flag=0; flag<3; ++flag )
 
2432
        {
 
2433
                for( i=0; i<argc; ++i )
 
2434
                {
 
2435
                        if( strcmp( tags[ flag ], argv[i] ) == 0 )
 
2436
                        {
 
2437
            printHelp();
 
2438
         }
 
2439
      }
 
2440
   }
 
2441
return 1;
 
2442
}
 
2443
 
 
2444
void printHelp()
 
2445
/*
 
2446
 * Print the help message to standard output.
 
2447
 */
 
2448
{
 
2449
   fprintf( stdout, "Prints a listing of a GEMM kernel\n" );
 
2450
   fprintf( stdout, "Optional Arguments:\n" );
 
2451
 
 
2452
   fprintf( stdout, "  -p [s,c,z,d] \n" );
 
2453
   fprintf( stdout, "  -f <filename> => File to generate\n" );
 
2454
 
 
2455
/* Print Unrolling Options */
 
2456
   fprintf( stdout, "  -m <int>  => Number of columns of A to unroll\n" );
 
2457
   fprintf( stdout, "  -n <int>  => Number of rows of B to unroll\n" );
 
2458
   fprintf( stdout, "  -k <int>  => Numver of times to unroll along K\n" );
 
2459
 
 
2460
/* Print blocking options */
 
2461
   fprintf( stdout, "  -M <int> => block size, integer or 0 for runtime\n" );
 
2462
   fprintf( stdout, "  -N <int> => block size, integer or 0 for runtime\n" );
 
2463
   fprintf( stdout, "  -K <int> => block size, integer or 0 for runtime\n" );
 
2464
   fprintf( stdout, "  -lda <int> => lda, or 0 for runtime, -1 to use KB\n" );
 
2465
   fprintf( stdout, "  -ldb <int> => ldb, or 0 for runtime, -1 to use KB\n" );
 
2466
   fprintf( stdout, "  -ldc <int> => ldc, or 0 for runtime, -1 to use KB\n" );
 
2467
 
 
2468
/* print algorithm optimizations */
 
2469
   fprintf( stdout, "  -beta <int> => Assume beta value 0, 1, or" );
 
2470
   fprintf( stdout, "other.\n" );
 
2471
   fprintf( stdout, "  -CAlignment  => 0 to test all alignments, " );
 
2472
   fprintf( stdout, "1 to assume misaligned.\n" );
 
2473
   fprintf( stdout, "  -ABAligned <bool> ==> assume that A and B are aligned\n" );
 
2474
   fprintf( stdout, "  -verifyUnrollings <bool> ==> assert that unrolling "
 
2475
                    "params are correct\n" );
 
2476
   fprintf( stdout, "  -treatLoadsAsFloat <bool> ==> use loadps instead of "
 
2477
                    "loadpd\n" );
 
2478
   fprintf( stdout, "  -treatStoresAsFloat <bool> ==> use storeps instead of "
 
2479
                    "storepd\n" );
 
2480
 
 
2481
/* Print prefetch options */
 
2482
   fprintf( stdout, "  -FF <0/1> => Fetch C at top of loop\n" );
 
2483
   fprintf( stdout, "  -prefetchACols <bool> ==> Prefetch the next "
 
2484
                    "unrolling of A\n" );
 
2485
   fprintf( stdout, "  -prefetchABlock <bool> ==> Prefetch the next "
 
2486
                    "block of A\n");
 
2487
   fprintf( stdout, "  -prefetchBCols <bool> ==> Prefetch the next rows "
 
2488
                    " of B\n" );
 
2489
   fprintf( stdout, "  -prefetchC <bool> ==> Prefetch C at the top of "
 
2490
                    "loop\n" );
 
2491
 
 
2492
   fprintf( stdout, "  -constantFolding <bool> ==> Perform constant folding "
 
2493
                    "when possible.\n" );
 
2494
 
 
2495
   exit(1);
 
2496
}
 
2497
 
 
2498
 
 
2499
 
 
2500
void loadInt
 
2501
(
 
2502
   char* tag,   /* The parameter to find */
 
2503
   ATL_INT *value,  /* (output) Populated with the value following the tag */
 
2504
   int argc,    /* Number of command line parameters */
 
2505
   char** argv  /* The command line parameters */
 
2506
)
 
2507
/*
 
2508
 * Reads an integer value from the command line, as specified with a given
 
2509
 * tag.  Value is unchanged if the tag is not found on the command line.
 
2510
 */
 
2511
{
 
2512
        int i;
 
2513
 
 
2514
        for( i=0; i<argc; ++i )
 
2515
        {
 
2516
                if( strcmp( tag, argv[i] ) == 0 )
 
2517
                {
 
2518
                        assert( i+1 < argc );
 
2519
                        *value = atoi( argv[i+1] );
 
2520
         numArgsProcessed++;
 
2521
                        return;
 
2522
                }
 
2523
        }
 
2524
}
 
2525
 
 
2526
 
 
2527
void loadFloat
 
2528
(
 
2529
   char* tag,    /* The parameter to find */
 
2530
   float *value, /* (output) Populated with the value found */
 
2531
   int argc,     /* The number of command line parameters */
 
2532
   char** argv   /* The command line parameters */
 
2533
)
 
2534
/*
 
2535
 * Reads a float value from the command line, as specified with a given
 
2536
 * tag.  Value is unchanged if the tag is not found on the command line.
 
2537
 */
 
2538
{
 
2539
        int i;
 
2540
 
 
2541
        for( i=0; i<argc; ++i )
 
2542
        {
 
2543
                if( strcmp( tag, argv[i] ) == 0 )
 
2544
                {
 
2545
                        assert( i+1 < argc );
 
2546
                        *value = atof( argv[i+1] );
 
2547
         numArgsProcessed++;
 
2548
                        return;
 
2549
                }
 
2550
        }
 
2551
}
 
2552
 
 
2553
void loadBool
 
2554
(
 
2555
    char* tag,  /* The parameter to find */
 
2556
    int *value, /* The value of the parameter */
 
2557
    int argc,   /* Number of command line parameters */
 
2558
    char** argv /* Command line parameters */
 
2559
)
 
2560
/*
 
2561
 * Reads a boolean value from the command line, as specified with a given
 
2562
 * tag.  Value is unchanged if the tag is not found on the command line.
 
2563
 */
 
2564
{
 
2565
        int i;
 
2566
 
 
2567
        for( i=0; i<argc; ++i )
 
2568
        {
 
2569
                if( strcmp( tag, argv[i] ) == 0 )
 
2570
                {
 
2571
                        assert( i+1 < argc );
 
2572
                        if( strcmp( argv[i+1], "1" ) == 0 )
 
2573
                        {
 
2574
                                *value = TRUE;
 
2575
                        } else if( strcmp( argv[i+1], "0" ) == 0 ) {
 
2576
                                *value = FALSE;
 
2577
                        } else {
 
2578
                                fprintf( stderr, "ERROR: option tag \"%s\" requires"
 
2579
                                        " 1 or 0 value.\n", tag );
 
2580
                                exit( 1 );
 
2581
                        }
 
2582
 
 
2583
         numArgsProcessed++;
 
2584
                        return;
 
2585
                }
 
2586
        }
 
2587
}
 
2588
 
 
2589
 
 
2590
 
 
2591
void getNBString
 
2592
(
 
2593
   char* out   /* The string to write to */
 
2594
)
 
2595
/* Generate a c compatible expression that defines
 
2596
 * the value of NB.  This may be a constant or a
 
2597
 * variable from the generated program.
 
2598
 */
 
2599
{
 
2600
   switch( unroll.nb )
 
2601
   {
 
2602
   case PARAMETER:
 
2603
      sprintf( out, "%s", "N" );
 
2604
      break;
 
2605
   case USE_KB:
 
2606
      sprintf( out, "%s", "K" );
 
2607
      break;
 
2608
   default:
 
2609
      sprintf( out, "%d", unroll.nb );
 
2610
   }
 
2611
}
 
2612
 
 
2613
 
 
2614
 
 
2615
void getMBString
 
2616
(
 
2617
   char* out  /* The string to write to */
 
2618
)
 
2619
/* Generate a c compatible expression that defines
 
2620
 * the value of MB.  This may be a constant or a
 
2621
 * variable from the generated program.
 
2622
 */
 
2623
{
 
2624
   switch( unroll.mb )
 
2625
   {
 
2626
   case PARAMETER:
 
2627
      sprintf( out, "%s", "M" );
 
2628
      break;
 
2629
   case USE_KB:
 
2630
      sprintf( out, "%s", "K" );
 
2631
      break;
 
2632
   default:
 
2633
      sprintf( out, "%d", unroll.mb );
 
2634
   }
 
2635
}
 
2636
 
 
2637
 
 
2638
void getKBString
 
2639
(
 
2640
   char* out  /* The string to write to */
 
2641
)
 
2642
/*
 
2643
 * Generate a c compatible expression that defines
 
2644
 * the value of KB.  This may be a constant or a
 
2645
 * variable from the generated program.
 
2646
 */
 
2647
{
 
2648
   switch( unroll.kb )
 
2649
   {
 
2650
   case PARAMETER:
 
2651
      sprintf( out, "%s", "K" );
 
2652
      break;
 
2653
   case USE_KB:
 
2654
      sprintf( out, "%s", "KB" );
 
2655
      break;
 
2656
   default:
 
2657
      sprintf( out, "%d", unroll.kb );
 
2658
   }
 
2659
}
 
2660
 
 
2661
 
 
2662
 
 
2663
void convertElementType
 
2664
(
 
2665
   char specifier   /* The type specifier */
 
2666
)
 
2667
/*
 
2668
 * Interpret a type specifier to determine what
 
2669
 * datatype is generated.
 
2670
 */
 
2671
{
 
2672
 
 
2673
 
 
2674
   switch( specifier )
 
2675
   {
 
2676
      case 's':
 
2677
      element.type = SINGLE;
 
2678
      element.cType = 's';
 
2679
      break;
 
2680
 
 
2681
      case 'd':
 
2682
      element.type = DOUBLE;
 
2683
      element.cType = 'd';
 
2684
      break;
 
2685
 
 
2686
      case 'z':
 
2687
      element.type = COMPLEX_DOUBLE;
 
2688
      element.cType = 'd';
 
2689
      break;
 
2690
 
 
2691
      case 'c':
 
2692
      element.type = COMPLEX_SINGLE;
 
2693
      element.cType = 's';
 
2694
      break;
 
2695
 
 
2696
      default:
 
2697
      fprintf( stderr, "Element type \"%c\" is not valid\n", specifier );
 
2698
      printHelp();
 
2699
   }
 
2700
}
 
2701
 
 
2702
 
 
2703
void setOutputLocation
 
2704
(
 
2705
   char* file  /* The name of the file to generate */
 
2706
)
 
2707
/*
 
2708
 * Set the output file location.
 
2709
 * Use NULL for stdout, "" will not change the current setting.
 
2710
 */
 
2711
{
 
2712
   if( file == NULL )
 
2713
   {
 
2714
      options.outputLocation = stdout;
 
2715
      return;
 
2716
   }
 
2717
 
 
2718
   if( strcmp( file, "" ) != 0 )
 
2719
   {
 
2720
      options.outputLocation = fopen( file, "w" );
 
2721
   }
 
2722
}
 
2723