~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to src/blas/gemm/ATL_cmmJIK.c

  • Committer: Bazaar Package Importer
  • Author(s): Camm Maguire
  • Date: 2002-04-13 10:07:52 UTC
  • Revision ID: james.westby@ubuntu.com-20020413100752-va9zm0rd4gpurdkq
Tags: upstream-3.2.1ln
ImportĀ upstreamĀ versionĀ 3.2.1ln

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *             Automatically Tuned Linear Algebra Software v3.2
 
3
 *                    (C) Copyright 1999 R. Clint Whaley                     
 
4
 *
 
5
 * Redistribution and use in source and binary forms, with or without
 
6
 * modification, are permitted provided that the following conditions
 
7
 * are met:
 
8
 *   1. Redistributions of source code must retain the above copyright
 
9
 *      notice, this list of conditions and the following disclaimer.
 
10
 *   2. Redistributions in binary form must reproduce the above copyright
 
11
 *      notice, this list of conditions, and the following disclaimer in the
 
12
 *      documentation and/or other materials provided with the distribution.
 
13
 *   3. The name of the University of Tennessee, the ATLAS group,
 
14
 *      or the names of its contributers may not be used to endorse
 
15
 *      or promote products derived from this software without specific
 
16
 *      written permission.
 
17
 *
 
18
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 
19
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 
20
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 
21
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 
22
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 
23
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 
24
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 
25
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 
26
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 
27
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 
28
 * POSSIBILITY OF SUCH DAMAGE. 
 
29
 *
 
30
 */
 
31
 
 
32
#include "atlas_misc.h"
 
33
#include "atlas_lvl3.h"
 
34
#include <stdlib.h>
 
35
 
 
36
#define KBmm Mjoin(PATL,pKBmm)
 
37
#define IBNBmm Mjoin(PATL,IBNBmm)
 
38
#define NBJBmm Mjoin(PATL,MBJBmm)
 
39
#define IBJBmm Mjoin(PATL,IBJBmm)
 
40
 
 
41
void Mjoin(PATL,mmJIK2)
 
42
             (int K, int nMb, int nNb, int nKb, int ib, int jb, int kb, 
 
43
              const SCALAR alpha, const TYPE *pA0, const TYPE *B, int ldb, 
 
44
              TYPE *pB0, int incB, MAT2BLK B2blk, const SCALAR beta, 
 
45
              TYPE *C, int ldc, MATSCAL gescal, NBMM0 NBmm0)
 
46
{
 
47
   const int incK = ATL_MulByNB(K)SHIFT, incC = ATL_MulByNB(ldc-nMb) SHIFT;
 
48
   const int ZEROC = ((gescal == NULL) && SCALAR_IS_ZERO(beta));
 
49
   int i, j = nNb;
 
50
   const TYPE *pA=pA0;
 
51
   const TYPE rbeta = ( (gescal) ? ATL_rone : *beta );
 
52
   TYPE *pB=pB0, *stB=pB0+(ATL_MulByNBNB(nKb)SHIFT);
 
53
 
 
54
   if (nNb)
 
55
   {
 
56
      do  /* Loop over full column panels of B */
 
57
      {
 
58
         if (B)
 
59
         {
 
60
            B2blk(K, NB, B, ldb, pB, alpha);
 
61
            B += incB;
 
62
         }
 
63
         if (nMb)
 
64
         {
 
65
            i = nMb;
 
66
            do /* loop over full row panels of A */
 
67
            {
 
68
               if (gescal) gescal(NB, NB, beta, C, ldc);
 
69
               if (nKb) /* loop over full blocks in panels */
 
70
               {
 
71
                  NBmm0(MB, NB, KB, ATL_rone, pA, KB, pB, KB, rbeta, C, ldc);
 
72
                  pA += NBNB2;
 
73
                  pB += NBNB2;
 
74
                  if (nKb != 1)
 
75
                  {
 
76
                     do
 
77
                     {
 
78
                        NBmm_b1(MB, NB, KB, ATL_rone, pA, KB, pB, KB, ATL_rone, 
 
79
                                C, ldc);
 
80
                        pA += NBNB2;
 
81
                        pB += NBNB2;
 
82
                     }
 
83
                     while (pB != stB);
 
84
                  }
 
85
                  if (kb)
 
86
                  {
 
87
                     KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, ATL_rone, 
 
88
                          C, ldc);
 
89
                     pA += ATL_MulByNB(kb)<<1;
 
90
                  }
 
91
               }
 
92
               else if (kb)
 
93
               {
 
94
                  if (ZEROC) Mjoin(PATL,gezero)(MB, NB, C, ldc);
 
95
                  KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, rbeta, C, ldc);
 
96
                  pA += ATL_MulByNB(kb)<<1;
 
97
               }
 
98
               pB = pB0;
 
99
               C += NB2;
 
100
            }
 
101
            while (--i);
 
102
         }
 
103
         if (ib) 
 
104
         {
 
105
            if (gescal) gescal(ib, NB, beta, C, ldc);
 
106
            IBNBmm(ib, K, pA, pB, rbeta, C, ldc);
 
107
         }
 
108
         if (!B)
 
109
         {
 
110
            pB0 += incK;
 
111
            pB = pB0;
 
112
            stB += incK;
 
113
         }
 
114
         C += incC;
 
115
         pA = pA0;
 
116
      }
 
117
      while (--j);
 
118
   }
 
119
   if (jb)
 
120
   {
 
121
      if (B) B2blk(K, jb, B, ldb, pB, alpha);
 
122
      for (i=nMb; i; i--)
 
123
      {
 
124
         if (gescal) gescal(NB, jb, beta, C, ldc);
 
125
         NBJBmm(jb, K, pA, pB, rbeta, C, ldc);
 
126
         pA += incK;
 
127
         C += NB2;
 
128
      }
 
129
      if (ib)
 
130
      {
 
131
         if (gescal) gescal(ib, jb, beta, C, ldc);
 
132
         IBJBmm(ib, jb, K, pA, pB, rbeta, C, ldc);
 
133
      }
 
134
   }
 
135
}
 
136
 
 
137
int Mjoin(PATL,mmJIK)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, 
 
138
                      const int M0, const int N, const int K, 
 
139
                      const SCALAR alpha, const TYPE *A, const int lda, 
 
140
                      const TYPE *B, const int ldb, const SCALAR beta, 
 
141
                      TYPE *C, const int ldc)
 
142
/*
 
143
 * Outer three loops for matmul with outer loop over columns of B
 
144
 */
 
145
{
 
146
   int M = M0;
 
147
   int nMb, nNb, nKb, ib, jb, kb, ib2, h, i, j, k, m, n, incA, incB, incC;
 
148
   int AlphaIsOne;
 
149
   const int incK = ATL_MulByNB(K);
 
150
   void *vB=NULL, *vC=NULL;
 
151
   TYPE *pA, *pB, *pC;
 
152
   const TYPE one[2] = {1.0,0.0}, zero[2] = {0.0,0.0};
 
153
   MAT2BLK A2blk, B2blk;
 
154
   MATSCAL gescal;
 
155
   NBMM0 NBmm0;
 
156
 
 
157
   nMb = ATL_DivByNB(M);
 
158
   nNb = ATL_DivByNB(N);
 
159
   nKb = ATL_DivByNB(K);
 
160
   ib = M - ATL_MulByNB(nMb);
 
161
   jb = N - ATL_MulByNB(nNb);
 
162
   kb = K - ATL_MulByNB(nKb);
 
163
 
 
164
   pC = C;
 
165
   if (beta[1] == ATL_rzero)
 
166
   {
 
167
      gescal = NULL;
 
168
      if (*beta == ATL_rone) NBmm0 = Mjoin(PATL,CNBmm_b1);
 
169
      else if (*beta == ATL_rzero) NBmm0 = Mjoin(PATL,CNBmm_b0);
 
170
      else NBmm0 = Mjoin(PATL,CNBmm_bX);
 
171
   }
 
172
   else
 
173
   {
 
174
      NBmm0 = Mjoin(PATL,CNBmm_b1);
 
175
      gescal = Mjoin(PATL,gescal_bX);
 
176
   }
 
177
/*
 
178
 * Special case for when what we are really doing is 
 
179
 *    C <- beta*C + alpha * A * A'   or   C <- beta*C + alpha * A' * A
 
180
 */
 
181
   if ( A == B && M == N && TA != TB && (SCALAR_IS_ONE(alpha) || M <= NB)
 
182
        && TA != AtlasConjTrans && TB != AtlasConjTrans )
 
183
   {
 
184
      AlphaIsOne = SCALAR_IS_ONE(alpha);
 
185
      i = ATL_MulBySize(M * K);
 
186
      if (!AlphaIsOne && pC == C && !SCALAR_IS_ZERO(beta)) 
 
187
         i += ATL_MulBySize(M*N);
 
188
      if (i <= ATL_MaxMalloc) vB = malloc(i + ATL_Cachelen);
 
189
      if (vB)
 
190
      {
 
191
         pA = ATL_AlignPtr(vB);
 
192
         if (TA == AtlasNoTrans)
 
193
            Mjoin(PATL,row2blkT2_a1)(M, K, A, lda, pA, alpha);
 
194
         else Mjoin(PATL,col2blk_a1)(K, M, A, lda, pA, alpha);
 
195
/*
 
196
 *       Can't write directly to C if alpha is not one
 
197
 */
 
198
         if (!AlphaIsOne)
 
199
         {
 
200
            if (SCALAR_IS_ZERO(beta)) h = ldc;
 
201
            else if (pC == C)
 
202
            {
 
203
               pC = pA + (M * K SHIFT);
 
204
               h = M;
 
205
            }
 
206
            else h = NB;
 
207
            Mjoin(PATL,mmJIK2)(K, nMb, nNb, nKb, ib, jb, kb, one, pA, NULL, 
 
208
                               ldb, pA, 0, NULL, zero, pC, h, 
 
209
                               Mjoin(PATL,gescal_b0), Mjoin(PATL,CNBmm_b0));
 
210
 
 
211
            if (alpha[1] == ATL_rzero) 
 
212
               Mjoin(PATL,gescal_bXi0)(M, N, alpha, pC, h);
 
213
            else Mjoin(PATL,gescal_bX)(M, N, alpha, pC, h);
 
214
 
 
215
            if (C != pC)
 
216
            {
 
217
               if (beta[1] == ATL_rzero)
 
218
               {
 
219
                  if (*beta == ATL_rone) 
 
220
                     Mjoin(PATL,putblk_b1)(M, N, pC, C, ldc, beta);
 
221
                  else if (*beta == ATL_rnone)
 
222
                     Mjoin(PATL,putblk_bn1)(M, N, pC, C, ldc, beta);
 
223
                  else if (*beta == ATL_rzero)
 
224
                     Mjoin(PATL,putblk_b0)(M, N, pC, C, ldc, beta);
 
225
                  else Mjoin(PATL,putblk_bXi0)(M, N, pC, C, ldc, beta);
 
226
               }
 
227
               else Mjoin(PATL,putblk_bX)(M, N, pC, C, ldc, beta);
 
228
            }
 
229
         }
 
230
         else Mjoin(PATL,mmJIK2)(K, nMb, nNb, nKb, ib, jb, kb, alpha, pA, NULL, 
 
231
                                 ldb, pA, 0, NULL, beta, C, ldc, gescal, NBmm0);
 
232
         free(vB);
 
233
         if (vC) free(vC);
 
234
         return(0);
 
235
      }
 
236
   }
 
237
   i = ATL_Cachelen + ATL_MulBySize(M*K + incK);
 
238
   if (i <= ATL_MaxMalloc) vB = malloc(i);
 
239
   if (!vB)
 
240
   {
 
241
      if (TA != AtlasNoTrans && TB != AtlasNoTrans) return(1);
 
242
      if (ib) n = nMb + 1;
 
243
      else n = nMb;
 
244
      for (j=2; !vB; j++)
 
245
      {
 
246
         k = n / j;
 
247
         if (k < 1) break;
 
248
         if (k*j < n) k++;
 
249
         h = ATL_Cachelen + ATL_MulBySize((k+1)*incK);
 
250
         if (h <= ATL_MaxMalloc) vB = malloc(h);
 
251
      }
 
252
      if (!vB) return(-1);
 
253
      n = k;
 
254
      m = ATL_MulByNB(n);
 
255
      ib2 = 0;
 
256
   }
 
257
   else
 
258
   {
 
259
      n = nMb;
 
260
      m = M;
 
261
      ib2 = ib;
 
262
   }
 
263
   pB = ATL_AlignPtr(vB);
 
264
   if (TA == AtlasNoTrans)
 
265
   {
 
266
      incA = m SHIFT;
 
267
      if (alpha[1] == ATL_rzero)
 
268
      {
 
269
         if (*alpha == ATL_rone) A2blk = Mjoin(PATL,row2blkT2_a1);
 
270
         else A2blk = Mjoin(PATL,row2blkT2_aXi0);
 
271
      }
 
272
      else A2blk = Mjoin(PATL,row2blkT2_aX);
 
273
   }
 
274
   else if (TA == AtlasConjTrans)
 
275
   {
 
276
      incA = lda*m SHIFT;
 
277
      if (alpha[1] == ATL_rzero)
 
278
      {
 
279
         if (*alpha == ATL_rone) A2blk = Mjoin(PATL,col2blkConj2_a1);
 
280
         else A2blk = Mjoin(PATL,col2blkConj2_aXi0);
 
281
      }
 
282
      else A2blk = Mjoin(PATL,col2blkConj2_aX);
 
283
   }
 
284
   else
 
285
   {
 
286
      incA = lda*m SHIFT;
 
287
      if (alpha[1] == ATL_rzero)
 
288
      {
 
289
         if (*alpha == ATL_rone) A2blk = Mjoin(PATL,col2blk2_a1);
 
290
         else A2blk = Mjoin(PATL,col2blk2_aXi0);
 
291
      }
 
292
      else A2blk = Mjoin(PATL,col2blk2_aX);
 
293
   }
 
294
   if (TB == AtlasNoTrans)
 
295
   {
 
296
      incB = ATL_MulByNB(ldb) SHIFT;
 
297
      B2blk = Mjoin(PATL,col2blk_a1);
 
298
   }
 
299
   else if (TB == AtlasConjTrans)
 
300
   {
 
301
      incB = NB2;
 
302
      B2blk = Mjoin(PATL,row2blkC_a1);
 
303
   }
 
304
   else
 
305
   {
 
306
      incB = NB2;
 
307
      B2blk = Mjoin(PATL,row2blkT_a1);
 
308
   }
 
309
   incC = m SHIFT;
 
310
 
 
311
   pA = pB + (incK SHIFT);
 
312
   do
 
313
   {
 
314
      if (TA == AtlasNoTrans) A2blk(m, K, A, lda, pA, alpha);
 
315
      else A2blk(K, m, A, lda, pA, alpha);
 
316
      Mjoin(PATL,mmJIK2)(K, n, nNb, nKb, ib2, jb, kb, alpha, pA, B, ldb, pB,
 
317
                         incB, B2blk, beta, C, ldc, gescal, NBmm0);
 
318
      M -= m;
 
319
      nMb -= n;
 
320
      if (M <= m)
 
321
      {
 
322
         ib2 = ib;
 
323
         m = M;
 
324
         n = nMb;
 
325
      }
 
326
      C += incC;
 
327
      A += incA;
 
328
   }
 
329
   while (M);
 
330
   free(vB);
 
331
   if (vC) free(vC);
 
332
   return(0);
 
333
}