/*
 *              Automatically Tuned Linear Algebra Software v3.2
 *                    (C) Copyright 1999 R. Clint Whaley
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions, and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *   3. The name of the University of Tennessee, the ATLAS group,
 *      or the names of its contributers may not be used to endorse
 *      or promote products derived from this software without specific
 *      prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
32
#include "atlas_misc.h"
33
#include "atlas_level1.h"
34
#include "atlas_level2.h"
35
#include "atlas_lvl2.h"
38
static void gemvMlt8(const int M, const int N, const TYPE *A, const int lda,
39
const TYPE *X, const SCALAR beta, TYPE *Y);
40
static void gemvNle4(const int M, const int N, const TYPE *A, const int lda,
41
const TYPE *X, const SCALAR beta, TYPE *Y);
/*
 * Yget(y_, yp_, bet_): load the output element yp_, pre-scaled by beta,
 * into the register y_.  Specialized at compile time, one case per ATLAS
 * beta-variant object file:
 *   BETA0 -> beta == 0: ignore the old Y element, start the sum at zero
 *   BETAX -> general beta: scale the old Y element by beta
 *   else  -> beta == 1 (BETA1): use the old Y element unchanged
 * NOTE(review): the extracted source had all three #defines unguarded,
 * which is a macro-redefinition error; the guards below are reconstructed
 * from ATLAS's standard BETA0/BETAX/BETA1 build convention (also visible
 * in the BNM component of the entry-point name) -- confirm against the
 * pristine source tree.
 */
#ifdef BETA0
   #define Yget(y_, yp_, bet_) (y_) = ATL_rzero
#elif defined(BETAX)
   #define Yget(y_, yp_, bet_) (y_) = (yp_) * (bet_)
#else
   #define Yget(y_, yp_, bet_) (y_) = (yp_)
#endif
49
/*
 * gemvN32x4: register-blocked, rank-4 axpy-based NoTrans GEMV panel kernel.
 * Accumulates y(0:M-1) = beta*y + A(:,0:3)*x(0:3) over 4 columns (A0..A3),
 * software-pipelined in 16-element row chunks (M16 = M rounded down to a
 * multiple of 16), alternating between the y0..y7 and z0..z7 register sets
 * while Yget pre-loads the next beta-scaled output elements.  Leftover rows
 * (M-M16 < 16) are finished by the dot-product based gemvMlt8.
 *
 * NOTE(review): this file is a corrupted extraction.  The bare numeric lines
 * below are the original file's line numbers, and the braces, #ifdef guards
 * and do/while loop headers that structured this pipeline (original lines
 * 42-315, many missing) were dropped.  Code lines are preserved verbatim for
 * reference; this translation unit does NOT compile as-is -- restore from
 * the pristine ATLAS source before building.
 */
static void gemvN32x4(const int M, const int N, const TYPE *A, const int lda,
50
const TYPE *x, const SCALAR beta0, TYPE *y)
52
* rank-4 daxpy based NoTrans gemv
55
const int M16 = (M>>4)<<4;
56
TYPE *stY = y + M16 - 32;
57
const TYPE *A0 = A, *A1 = A+lda, *A2 = A1 + lda, *A3 = A2 + lda;
58
register TYPE z0, z1, z2, z3, z4, z5, z6, z7;
59
register TYPE y0, y1, y2, y3, y4, y5, y6, y7;
60
const register TYPE x0 = *x, x1 = x[1], x2 = x[2], x3 = x[3];
62
const register TYPE beta = beta0;
71
/* beta-variant prologues (presumably #ifdef BETA0 / BETAX / BETA1 arms
 * in the original -- TODO confirm): zero, load+scale, or load y(0:7). */
y0 = y1 = y2 = y3 = y4 = y5 = y6 = y7 = ATL_rzero;
73
y0 = *y; y1 = y[1]; y2 = y[2]; y3 = y[3];
74
y4 = y[4]; y5 = y[5]; y6 = y[6]; y7 = y[7];
76
y0 *= beta; y1 *= beta; y2 *= beta; y3 *= beta;
77
y4 *= beta; y5 *= beta; y6 *= beta; y7 *= beta;
80
/* pipeline startup: accumulate columns 0-3 into y0..y7 while Yget
 * pre-loads the next 8 output elements into z0..z7. */
y0 += x0 * *A0; Yget(z0, y[8], beta);
82
y2 += x2 * A2[2]; Yget(z1, y[9], beta);
84
y4 += x0 * A0[4]; Yget(z2, y[10], beta);
86
y6 += x2 * A2[6]; Yget(z3, y[11], beta);
89
y0 += x1 * *A1; Yget(z4, y[12], beta);
91
y2 += x3 * A3[2]; Yget(z5, y[13], beta);
93
y4 += x1 * A1[4]; Yget(z6, y[14], beta);
95
y6 += x3 * A3[6]; Yget(z7, y[15], beta);
116
z0 += x0 * A0[8]; *y = y0;
118
z2 += x2 * A2[10]; y[1] = y1;
120
z4 += x0 * A0[12]; y[2] = y2;
122
z6 += x2 * A2[14]; y[3] = y3;
125
z0 += x1 * A1[8]; y[4] = y4;
127
z2 += x3 * A3[10]; y[5] = y5;
129
z4 += x1 * A1[12]; y[6] = y6;
131
z6 += x3 * A3[14]; y[7] = y7;
134
z0 += x2 * A2[8]; Yget(y0, y[16], beta);
136
z2 += x0 * A0[10]; Yget(y1, y[17], beta);
138
z4 += x2 * A2[12]; Yget(y2, y[18], beta);
140
z6 += x0 * A0[14]; Yget(y3, y[19], beta);
143
z0 += x3 * A3[8]; Yget(y4, y[20], beta);
145
z2 += x1 * A1[10]; Yget(y5, y[21], beta);
147
z4 += x3 * A3[12]; Yget(y6, y[22], beta); A3 += 16;
148
z5 += x0 * A0[13]; A0 += 16;
149
z6 += x1 * A1[14]; Yget(y7, y[23], beta); A1 += 16;
150
z7 += x2 * A2[15]; A2 += 16;
155
/* steady-state body (originally a loop over 16-row chunks up to stY --
 * the do/while header was lost in extraction; TODO restore). */
y0 += x0 * *A0; y[8] = z0;
157
y2 += x2 * A2[2]; y[9] = z1;
159
y4 += x0 * A0[4]; y[10] = z2;
161
y6 += x2 * A2[6]; y[11] = z3;
164
y0 += x1 * *A1; y[12] = z4;
166
y2 += x3 * A3[2]; y[13] = z5;
168
y4 += x1 * A1[4]; y[14] = z6;
170
y6 += x3 * A3[6]; y[15] = z7; y += 16;
173
y0 += x2 * *A2; Yget(z0, y[8], beta);
175
y2 += x0 * A0[2]; Yget(z1, y[9], beta);
177
y4 += x2 * A2[4]; Yget(z2, y[10], beta);
179
y6 += x0 * A0[6]; Yget(z3, y[11], beta);
182
y0 += x3 * *A3; Yget(z4, y[12], beta);
184
y2 += x1 * A1[2]; Yget(z5, y[13], beta);
186
y4 += x3 * A3[4]; Yget(z6, y[14], beta);
188
y6 += x1 * A1[6]; Yget(z7, y[15], beta);
191
z0 += x0 * A0[8]; *y = y0;
193
z2 += x2 * A2[10]; y[1] = y1;
195
z4 += x0 * A0[12]; y[2] = y2;
197
z6 += x2 * A2[14]; y[3] = y3;
200
z0 += x1 * A1[8]; y[4] = y4;
202
z2 += x3 * A3[10]; y[5] = y5;
204
z4 += x1 * A1[12]; y[6] = y6;
206
z6 += x3 * A3[14]; y[7] = y7;
209
z0 += x2 * A2[8]; Yget(y0, y[16], beta);
211
z2 += x0 * A0[10]; Yget(y1, y[17], beta);
213
z4 += x2 * A2[12]; Yget(y2, y[18], beta);
215
z6 += x0 * A0[14]; Yget(y3, y[19], beta);
218
z0 += x3 * A3[8]; Yget(y4, y[20], beta);
220
z2 += x1 * A1[10]; Yget(y5, y[21], beta);
222
z4 += x3 * A3[12]; Yget(y6, y[22], beta); A3 += 16;
223
z5 += x0 * A0[13]; A0 += 16;
224
z6 += x1 * A1[14]; Yget(y7, y[23], beta); A1 += 16;
225
z7 += x2 * A2[15]; A2 += 16;
229
/* pipeline drain: last two 16-row chunks without further pre-loads. */
y0 += x0 * *A0; y[8] = z0;
231
y2 += x2 * A2[2]; y[9] = z1;
233
y4 += x0 * A0[4]; y[10] = z2;
235
y6 += x2 * A2[6]; y[11] = z3;
238
y0 += x1 * *A1; y[12] = z4;
240
y2 += x3 * A3[2]; y[13] = z5;
242
y4 += x1 * A1[4]; y[14] = z6;
244
y6 += x3 * A3[6]; y[15] = z7; y += 16;
247
y0 += x2 * *A2; Yget(z0, y[8], beta);
249
y2 += x0 * A0[2]; Yget(z1, y[9], beta);
251
y4 += x2 * A2[4]; Yget(z2, y[10], beta);
253
y6 += x0 * A0[6]; Yget(z3, y[11], beta);
256
y0 += x3 * *A3; Yget(z4, y[12], beta);
258
y2 += x1 * A1[2]; Yget(z5, y[13], beta);
260
y4 += x3 * A3[4]; Yget(z6, y[14], beta);
262
y6 += x1 * A1[6]; Yget(z7, y[15], beta);
265
z0 += x0 * A0[8]; *y = y0;
267
z2 += x2 * A2[10]; y[1] = y1;
269
z4 += x0 * A0[12]; y[2] = y2;
271
z6 += x2 * A2[14]; y[3] = y3;
274
z0 += x1 * A1[8]; y[4] = y4;
276
z2 += x3 * A3[10]; y[5] = y5;
278
z4 += x1 * A1[12]; y[6] = y6;
280
z6 += x3 * A3[14]; y[7] = y7;
308
/* cleanup: remaining M-M16 rows (or whole call when M < 16) handled by
 * the dot-product kernel. */
if (M-M16) gemvMlt8(M-M16, N, A0+16, lda, x, beta, y+16);
310
else if (N) gemvMlt8(M, N, A, lda, x, beta, y);
313
/*
 * gemv32x4: driver that sweeps the N columns in panels of 4, calling the
 * register-blocked gemvN32x4 kernel per panel; the first panel applies the
 * caller's beta, subsequent panels accumulate with beta == 1 (ATL_rone).
 * Column remainder (N mod 4) falls through to gemvNle4, and degenerate
 * shapes to gemvMlt8 / gemvNle4.
 * NOTE(review): corrupted extraction -- bare numeric lines are original
 * line numbers, and the braces/if-else skeleton (original lines ~313-339)
 * connecting these statements is missing.  Code preserved verbatim; does
 * not compile as-is.
 */
static void gemv32x4(const int M, const int N, const TYPE *A, const int lda,
314
const TYPE *X, const SCALAR beta, TYPE *Y)
319
const int incA = lda<<2;  /* stride to advance A by 4 columns */
326
for (j=(N>>2); j; j--, A += incA, X += 4)
327
gemvN32x4(M, 4, A, lda, X, ATL_rone, Y);
328
if ( (j = N-((N>>2)<<2)) ) gemvNle4(M, j, A, lda, X, ATL_rone, Y);
330
gemvN32x4(M, 4, A, lda, X, beta, Y);
332
/* remaining N-4 columns delegated to the general alpha=1/beta=1 gemvN */
Mjoin(PATL,gemvN_a1_x1_b1_y1)
333
(M, N-4, ATL_rone, A+incA, lda, X+4, 1, ATL_rone, Y, 1);
336
else gemvMlt8(M, N, A, lda, X, beta, Y);
338
else if (M) gemvNle4(M, N, A, lda, X, beta, Y);
341
/*
 * gemvMlt8: fallback for few rows (name suggests M < 8 -- TODO confirm):
 * each output element is a dot product of one row of A (stride lda) with X.
 * The two dot calls below are the beta-variant forms (overwrite vs.
 * accumulate into y0).
 * NOTE(review): corrupted extraction -- the row loop, beta scaling and the
 * stores to Y (original lines ~341-357) are missing; bare numeric lines are
 * original line numbers.  Code preserved verbatim; does not compile as-is.
 */
static void gemvMlt8(const int M, const int N, const TYPE *A, const int lda,
342
const TYPE *X, const SCALAR beta, TYPE *Y)
349
y0 = Mjoin(PATL,dot)(N, A, lda, X, 1);
352
y0 += Mjoin(PATL,dot)(N, A, lda, X, 1);
358
/*
 * gemvNle4: NoTrans gemv specialized for N <= 4 columns.  The N == 1 case
 * maps directly onto Level-1 routines (move / axpby / axpy -- presumably
 * the BETA0 / BETAX / BETA1 arms; TODO confirm); N == 2, 3, 4 are fused
 * row loops, each again with overwrite / beta-scale / accumulate variants
 * of the Y update.  For N == 4 with M >= 32 it reroutes to the blocked
 * gemv32x4 driver.
 * NOTE(review): corrupted extraction -- the switch/if skeleton over N and
 * the beta-variant guards (original lines ~358-419) are missing; bare
 * numeric lines are original line numbers.  Code preserved verbatim; does
 * not compile as-is.
 */
static void gemvNle4(const int M, const int N, const TYPE *A, const int lda,
359
const TYPE *X, const SCALAR beta, TYPE *Y)
362
const TYPE *A0 = A, *A1 = A+lda, *A2 = A1+lda, *A3 = A2+lda;
363
register TYPE x0, x1, x2, x3;
365
const register TYPE bet=beta;
372
/* N == 1: Y = (*X)*A, beta handled by the Level-1 routine chosen */
Mjoin(PATL,move)(M, *X, A, 1, Y, 1);
374
Mjoin(PATL,axpby)(M, *X, A, 1, beta, Y, 1);
376
Mjoin(PATL,axpy)(M, *X, A, 1, Y, 1);
381
/* N == 2: fused two-column update */
for (i=0; i != M; i++)
383
Y[i] = A0[i] * x0 + A1[i] * x1;
385
Y[i] = Y[i]*bet + A0[i] * x0 + A1[i] * x1;
387
Y[i] += A0[i] * x0 + A1[i] * x1;
391
/* N == 3: fused three-column update */
x0 = *X; x1 = X[1]; x2 = X[2];
392
for (i=0; i != M; i++)
394
Y[i] = A0[i] * x0 + A1[i] * x1 + A2[i] * x2;
396
Y[i] = Y[i]*bet + A0[i] * x0 + A1[i] * x1 + A2[i] * x2;
398
Y[i] += A0[i] * x0 + A1[i] * x1 + A2[i] * x2;
402
/* N == 4: use the register-blocked kernel when M is large enough */
if (M >= 32) gemv32x4(M, 4, A, lda, X, beta, Y);
405
x0 = *X; x1 = X[1]; x2 = X[2]; x3 = X[3];
406
for (i=0; i != M; i++)
408
Y[i] = A0[i] * x0 + A1[i] * x1 + A2[i] * x2 + A3[i] * x3;
410
Y[i] = Y[i]*bet + A0[i] * x0 + A1[i] * x1 + A2[i] * x2 + A3[i] * x3;
412
Y[i] += A0[i] * x0 + A1[i] * x1 + A2[i] * x2 + A3[i] * x3;
421
/*
 * Entry point: the type/alpha/beta-specialized NoTrans gemv,
 * Y = alpha*A*X + beta*Y, with the exported name assembled by Mjoin from
 * PATL (type prefix), NM (alpha variant), _x1 (incX == 1), BNM (beta
 * variant) and _y1 (incY == 1).  The visible declarations set up a
 * 4-row by 2-column register-blocked sweep (m4 = M rounded down to a
 * multiple of 4, stX = end of the even part of X), with a dot-product
 * cleanup loop over the trailing M - m4 rows and a gemvNle4 fallback for
 * small N.
 * NOTE(review): corrupted extraction -- bare numeric lines are original
 * line numbers; the vast majority of this function's body (original lines
 * 421-645, including the main unrolled loops) is missing.  Code preserved
 * verbatim; does not compile as-is -- restore from the pristine ATLAS
 * source before building.
 */
void Mjoin(Mjoin(Mjoin(Mjoin(Mjoin(PATL,gemvN),NM),_x1),BNM),_y1)
422
(const int M, const int N, const SCALAR alpha, const TYPE *A, const int lda,
423
const TYPE *X, const int incX, const SCALAR beta, TYPE *Y, const int incY)
425
const int incA = lda<<1, incAm = 4 - ((N>>1)<<1)*lda;
426
const int m4 = (M>>2)<<2;
428
register TYPE y0, y1, y2, y3, z0, z1, z2, z3, x0, x1, m0, m1, m2, m3;
429
register TYPE a00, a10, a20, a30, a01, a11, a21, a31;
430
const TYPE *x, *stX = X + ((N>>1)<<1)-2, *A0 = A, *A1 = A + lda;
443
z0 = z1 = z2 = z3 = y0 = y1 = y2 = y3 = ATL_rzero;
456
y0 = y1 = y2 = y3 = ATL_rzero;
629
/* cleanup: dot-product per remaining row (overwrite vs accumulate arms) */
for (nr=M-m4; nr; nr--)
632
y0 = Mjoin(PATL,dot)(N, A0, lda, X, 1);
639
y0 += Mjoin(PATL,dot)(N, A0, lda, X, 1);
645
else if (M) gemvNle4(M, N, A, lda, X, beta, Y);