2
* Automatically Tuned Linear Algebra Software v3.2
3
* (C) Copyright 1999 R. Clint Whaley
5
* Redistribution and use in source and binary forms, with or without
6
* modification, are permitted provided that the following conditions
8
* 1. Redistributions of source code must retain the above copyright
9
* notice, this list of conditions and the following disclaimer.
10
* 2. Redistributions in binary form must reproduce the above copyright
11
* notice, this list of conditions, and the following disclaimer in the
12
* documentation and/or other materials provided with the distribution.
13
* 3. The name of the University of Tennessee, the ATLAS group,
14
* or the names of its contributers may not be used to endorse
15
* or promote products derived from this software without specific
18
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
22
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
* POSSIBILITY OF SUCH DAMAGE.
32
#include "atlas_misc.h"
33
#include "atlas_lvl3.h"
36
/*
 * Short local names for the block kernels called by the GEMM drivers in
 * this file.  Mjoin/PATL come from the included ATLAS headers (not shown);
 * presumably Mjoin pastes the precision prefix PATL onto each suffix --
 * confirm in atlas_misc.h.
 *
 * Kernel / C-block pairing as seen at the call sites in this file:
 *   KBmm   - called with (MB, NB, kb, ...) for the K-remainder block
 *   IBNBmm - paired with gescal(ib, NB, ...): ib x NB block of C
 *   NBJBmm - paired with gescal(NB, jb, ...): NB x jb block of C
 *   IBJBmm - paired with gescal(ib, jb, ...): ib x jb block of C
 */
#define KBmm Mjoin(PATL,pKBmm)
37
#define IBNBmm Mjoin(PATL,IBNBmm)
38
/*
 * NOTE(review): NBJBmm expands to the MBJBmm kernel, not "NBJBmm".  Its only
 * visible call site operates on an NB x jb block of C, so this is presumably
 * an intentional alias (full-height x partial-width kernel) rather than a
 * typo -- confirm against the kernel declarations before renaming.
 */
#define NBJBmm Mjoin(PATL,MBJBmm)
39
#define IBJBmm Mjoin(PATL,IBJBmm)
41
void Mjoin(PATL,mmJIK2)
42
(int K, int nMb, int nNb, int nKb, int ib, int jb, int kb,
43
const SCALAR alpha, const TYPE *pA0, const TYPE *B, int ldb,
44
TYPE *pB0, int incB, MAT2BLK B2blk, const SCALAR beta,
45
TYPE *C, int ldc, MATSCAL gescal, NBMM0 NBmm0)
47
const int incK = ATL_MulByNB(K)SHIFT, incC = ATL_MulByNB(ldc-nMb) SHIFT;
48
const int ZEROC = ((gescal == NULL) && SCALAR_IS_ZERO(beta));
51
const TYPE rbeta = ( (gescal) ? ATL_rone : *beta );
52
TYPE *pB=pB0, *stB=pB0+(ATL_MulByNBNB(nKb)SHIFT);
56
do /* Loop over full column panels of B */
60
B2blk(K, NB, B, ldb, pB, alpha);
66
do /* loop over full row panels of A */
68
if (gescal) gescal(NB, NB, beta, C, ldc);
69
if (nKb) /* loop over full blocks in panels */
71
NBmm0(MB, NB, KB, ATL_rone, pA, KB, pB, KB, rbeta, C, ldc);
78
NBmm_b1(MB, NB, KB, ATL_rone, pA, KB, pB, KB, ATL_rone,
87
KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, ATL_rone,
89
pA += ATL_MulByNB(kb)<<1;
94
if (ZEROC) Mjoin(PATL,gezero)(MB, NB, C, ldc);
95
KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, rbeta, C, ldc);
96
pA += ATL_MulByNB(kb)<<1;
105
if (gescal) gescal(ib, NB, beta, C, ldc);
106
IBNBmm(ib, K, pA, pB, rbeta, C, ldc);
121
if (B) B2blk(K, jb, B, ldb, pB, alpha);
124
if (gescal) gescal(NB, jb, beta, C, ldc);
125
NBJBmm(jb, K, pA, pB, rbeta, C, ldc);
131
if (gescal) gescal(ib, jb, beta, C, ldc);
132
IBJBmm(ib, jb, K, pA, pB, rbeta, C, ldc);
137
int Mjoin(PATL,mmJIK)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
138
const int M0, const int N, const int K,
139
const SCALAR alpha, const TYPE *A, const int lda,
140
const TYPE *B, const int ldb, const SCALAR beta,
141
TYPE *C, const int ldc)
143
* Outer three loops for matmul with outer loop over columns of B
147
int nMb, nNb, nKb, ib, jb, kb, ib2, h, i, j, k, m, n, incA, incB, incC;
149
const int incK = ATL_MulByNB(K);
150
void *vB=NULL, *vC=NULL;
152
const TYPE one[2] = {1.0,0.0}, zero[2] = {0.0,0.0};
153
MAT2BLK A2blk, B2blk;
157
nMb = ATL_DivByNB(M);
158
nNb = ATL_DivByNB(N);
159
nKb = ATL_DivByNB(K);
160
ib = M - ATL_MulByNB(nMb);
161
jb = N - ATL_MulByNB(nNb);
162
kb = K - ATL_MulByNB(nKb);
165
if (beta[1] == ATL_rzero)
168
if (*beta == ATL_rone) NBmm0 = Mjoin(PATL,CNBmm_b1);
169
else if (*beta == ATL_rzero) NBmm0 = Mjoin(PATL,CNBmm_b0);
170
else NBmm0 = Mjoin(PATL,CNBmm_bX);
174
NBmm0 = Mjoin(PATL,CNBmm_b1);
175
gescal = Mjoin(PATL,gescal_bX);
178
* Special case for when what we are really doing is
179
* C <- beta*C + alpha * A * A' or C <- beta*C + alpha * A' * A
181
if ( A == B && M == N && TA != TB && (SCALAR_IS_ONE(alpha) || M <= NB)
182
&& TA != AtlasConjTrans && TB != AtlasConjTrans )
184
AlphaIsOne = SCALAR_IS_ONE(alpha);
185
i = ATL_MulBySize(M * K);
186
if (!AlphaIsOne && pC == C && !SCALAR_IS_ZERO(beta))
187
i += ATL_MulBySize(M*N);
188
if (i <= ATL_MaxMalloc) vB = malloc(i + ATL_Cachelen);
191
pA = ATL_AlignPtr(vB);
192
if (TA == AtlasNoTrans)
193
Mjoin(PATL,row2blkT2_a1)(M, K, A, lda, pA, alpha);
194
else Mjoin(PATL,col2blk_a1)(K, M, A, lda, pA, alpha);
196
* Can't write directly to C if alpha is not one
200
if (SCALAR_IS_ZERO(beta)) h = ldc;
203
pC = pA + (M * K SHIFT);
207
Mjoin(PATL,mmJIK2)(K, nMb, nNb, nKb, ib, jb, kb, one, pA, NULL,
208
ldb, pA, 0, NULL, zero, pC, h,
209
Mjoin(PATL,gescal_b0), Mjoin(PATL,CNBmm_b0));
211
if (alpha[1] == ATL_rzero)
212
Mjoin(PATL,gescal_bXi0)(M, N, alpha, pC, h);
213
else Mjoin(PATL,gescal_bX)(M, N, alpha, pC, h);
217
if (beta[1] == ATL_rzero)
219
if (*beta == ATL_rone)
220
Mjoin(PATL,putblk_b1)(M, N, pC, C, ldc, beta);
221
else if (*beta == ATL_rnone)
222
Mjoin(PATL,putblk_bn1)(M, N, pC, C, ldc, beta);
223
else if (*beta == ATL_rzero)
224
Mjoin(PATL,putblk_b0)(M, N, pC, C, ldc, beta);
225
else Mjoin(PATL,putblk_bXi0)(M, N, pC, C, ldc, beta);
227
else Mjoin(PATL,putblk_bX)(M, N, pC, C, ldc, beta);
230
else Mjoin(PATL,mmJIK2)(K, nMb, nNb, nKb, ib, jb, kb, alpha, pA, NULL,
231
ldb, pA, 0, NULL, beta, C, ldc, gescal, NBmm0);
237
i = ATL_Cachelen + ATL_MulBySize(M*K + incK);
238
if (i <= ATL_MaxMalloc) vB = malloc(i);
241
if (TA != AtlasNoTrans && TB != AtlasNoTrans) return(1);
249
h = ATL_Cachelen + ATL_MulBySize((k+1)*incK);
250
if (h <= ATL_MaxMalloc) vB = malloc(h);
263
pB = ATL_AlignPtr(vB);
264
if (TA == AtlasNoTrans)
267
if (alpha[1] == ATL_rzero)
269
if (*alpha == ATL_rone) A2blk = Mjoin(PATL,row2blkT2_a1);
270
else A2blk = Mjoin(PATL,row2blkT2_aXi0);
272
else A2blk = Mjoin(PATL,row2blkT2_aX);
274
else if (TA == AtlasConjTrans)
277
if (alpha[1] == ATL_rzero)
279
if (*alpha == ATL_rone) A2blk = Mjoin(PATL,col2blkConj2_a1);
280
else A2blk = Mjoin(PATL,col2blkConj2_aXi0);
282
else A2blk = Mjoin(PATL,col2blkConj2_aX);
287
if (alpha[1] == ATL_rzero)
289
if (*alpha == ATL_rone) A2blk = Mjoin(PATL,col2blk2_a1);
290
else A2blk = Mjoin(PATL,col2blk2_aXi0);
292
else A2blk = Mjoin(PATL,col2blk2_aX);
294
if (TB == AtlasNoTrans)
296
incB = ATL_MulByNB(ldb) SHIFT;
297
B2blk = Mjoin(PATL,col2blk_a1);
299
else if (TB == AtlasConjTrans)
302
B2blk = Mjoin(PATL,row2blkC_a1);
307
B2blk = Mjoin(PATL,row2blkT_a1);
311
pA = pB + (incK SHIFT);
314
if (TA == AtlasNoTrans) A2blk(m, K, A, lda, pA, alpha);
315
else A2blk(K, m, A, lda, pA, alpha);
316
Mjoin(PATL,mmJIK2)(K, n, nNb, nKb, ib2, jb, kb, alpha, pA, B, ldb, pB,
317
incB, B2blk, beta, C, ldc, gescal, NBmm0);