~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to src/threads/blas/level3/ATL_threadMM.c

  • Committer: Package Import Robot
  • Author(s): Sébastien Villemot
  • Date: 2013-06-11 15:58:16 UTC
  • mfrom: (1.1.3 upstream)
  • mto: (2.2.21 experimental)
  • mto: This revision was merged to the branch mainline in revision 26.
  • Revision ID: package-import@ubuntu.com-20130611155816-b72z8f621tuhbzn0
Tags: upstream-3.10.1
Import upstream version 3.10.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include "atlas_misc.h"
 
2
#include "atlas_threads.h"
 
3
#include "atlas_tlvl3.h"
 
4
 
 
5
#ifdef DEBUG
 
6
#define T2c(ta_) ((ta_) == AtlasNoTrans) ? 'N' : 'T'
 
7
#endif
 
8
#ifndef ATL_TXOVER_H
 
9
int Mjoin(PATL,threadMM)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
10
                         size_t M, size_t N, size_t K)
 
11
/*
 
12
 * This dummy routine used when crossover is not tuned
 
13
 */
 
14
{
 
15
#if 0
 
16
   size_t minD, maxD;
 
17
 
 
18
   minD = Mmin(M,N);
 
19
   minD = Mmin(minD,K);
 
20
   maxD = Mmax(M,N);
 
21
   maxD = Mmax(maxD,K);
 
22
   if (M >= (NB<<(ATL_NTHRPOW2+2)))
 
23
      return(2);
 
24
   else if (minD >= 8 && maxD >= 2*NB)
 
25
      return(1);
 
26
   return(0);
 
27
#else
 
28
   int Mjoin(PATL,GemmWillThread)(ATL_CINT M, ATL_CINT N, ATL_CINT K);
 
29
   return(Mjoin(PATL,GemmWillThread)(M, N, K));
 
30
#endif
 
31
}
 
32
#else
 
33
int Mjoin(PATL,threadMM)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
34
                         size_t M, size_t N, size_t K)
 
35
/*
 
36
 * RETURNS: number of threads matmul should use to paralellize the problem
 
37
 */
 
38
{
 
39
   size_t i, j, smp2, bip2, xo, xom, D;
 
40
   const int *xop;
 
41
   int k;
 
42
   if (M < 256 && N < 256 && K < 256)   /* small matrix */
 
43
   {
 
44
/*
 
45
 *    For really small problems, table lookups too expensive, so do a quick
 
46
 *    return
 
47
 */
 
48
      j = Mmax(M,N);
 
49
      i = Mmin(M,N);
 
50
      i = Mmin(i,K);
 
51
      if (j <= NB+NB || i < NB)
 
52
         return(1);    /* quick return */
 
53
/*
 
54
 *    Make choice based on most restricted dimension
 
55
 */
 
56
      if (M < N && M < K)   /* M most restricted dim */
 
57
         goto SMALLM;
 
58
      else if (K < M && K < N)  /* K most restricted dim */
 
59
         goto SMALLK;
 
60
      else if (M == N && M == K)
 
61
         goto SQUARE;
 
62
      else  /* N is most restricted dim */
 
63
         goto SMALLN;
 
64
   }
 
65
/*
 
66
 * The following three shapes model recursive factorizations where
 
67
 * two dimensions are cut during the recursion, and a third remains large
 
68
 */
 
69
   else if (N <= 256 && K <= 256)  /* recursive shape that doesn't cut M */
 
70
   {                               /* LU uses this shape */
 
71
      i = Mmin(N, K);
 
72
      j = Mmax(N, K);
 
73
      if (i >= NB)
 
74
         i = (i+j)>>1;
 
75
      else if (i >= 8)
 
76
         i = (i+i+i+j)>>2;  /* 3/4 MIN, 1/4 MAX */
 
77
      for (bip2=1; bip2 < i; bip2 <<= 1);
 
78
      smp2 = (bip2 == i) ? bip2 : (bip2>>1);
 
79
      i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
 
80
      for (j=0; j < 9; j++)
 
81
         if (i & (1<<j)) break;
 
82
      D = M;
 
83
      if (TA == AtlasNoTrans)
 
84
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SnkLm_XO : ATL_tmmNT_SnkLm_XO;
 
85
      else
 
86
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SnkLm_XO : ATL_tmmTT_SnkLm_XO;
 
87
      #ifdef DEBUG
 
88
         printf("sNKlM_%c%c, M=%d, N=%d, K=%d rD=%d, D=%d\n",
 
89
                T2c(TA), T2c(TB), M, N, K, j, D);
 
90
      #endif
 
91
   }
 
92
   else if (M <= 256 && N <= 256)  /* recursive shape that doesn't cut K */
 
93
   {                               /* QR uses, maybe in LARFT? */
 
94
      i = Mmin(M, N);
 
95
      j = Mmax(M, N);
 
96
      if (i >= NB)
 
97
         i = (i+j)>>1;
 
98
      else if (i >= 8)
 
99
         i = (i+i+i+j)>>2;  /* 3/4 MIN, 1/4 MAX */
 
100
      for (bip2=1; bip2 < i; bip2 <<= 1);
 
101
      smp2 = (bip2 == i) ? bip2 : (bip2>>1);
 
102
      i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
 
103
      for (j=0; j < 9; j++)
 
104
         if (i & (1<<j)) break;
 
105
      D = K;
 
106
      if (TA == AtlasNoTrans)
 
107
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SmnLk_XO : ATL_tmmNT_SmnLk_XO;
 
108
      else
 
109
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SmnLk_XO : ATL_tmmTT_SmnLk_XO;
 
110
      #ifdef DEBUG
 
111
         printf("sMNlK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
 
112
                T2c(TA), T2c(TB), M, N, K, j, D);
 
113
      #endif
 
114
   }
 
115
   else if (M <= 256 && K <= 256) /* recursive shape that doesn't cut N */
 
116
   {                              /* UNCONFIRMED: QR variant uses */
 
117
      i = Mmin(M, K);
 
118
      j = Mmax(M, K);
 
119
      if (i >= NB)
 
120
         i = (i+j)>>1;
 
121
      else if (i >= 8)
 
122
         i = (i+i+i+j)>>2;  /* 3/4 MIN, 1/4 MAX */
 
123
      for (bip2=1; bip2 < i; bip2 <<= 1);
 
124
      smp2 = (bip2 == i) ? bip2 : (bip2>>1);
 
125
      for (j=0; j < 9; j++)
 
126
         if (i & (1<<j)) break;
 
127
      D = N;
 
128
      if (TA == AtlasNoTrans)
 
129
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SmkLn_XO : ATL_tmmNT_SmkLn_XO;
 
130
      else
 
131
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SmkLn_XO : ATL_tmmTT_SmkLn_XO;
 
132
      #ifdef DEBUG
 
133
         printf("sNlMK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
 
134
                T2c(TA), T2c(TB), M, N, K, j, D);
 
135
      #endif
 
136
   }
 
137
/*
 
138
 * The three following shapes model static blocking, where two dimensions
 
139
 * are full, and the third is blocked
 
140
 */
 
141
   else if (K <= 256)           /* K dim small, as in right-looking LU/QR */
 
142
   {
 
143
SMALLK:
 
144
      D = Mmin(M,N);
 
145
      if (D >= NB+NB)
 
146
         D = (M+N)>>1;
 
147
      i = K;
 
148
      for (bip2=1; bip2 < i; bip2 <<= 1);
 
149
      smp2 = (bip2 == i) ? bip2 : (bip2>>1);
 
150
      i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
 
151
      for (j=0; j < 9; j++)
 
152
         if (i & (1<<j)) break;
 
153
      if (TA == AtlasNoTrans)
 
154
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SkLmn_XO : ATL_tmmNT_SkLmn_XO;
 
155
      else
 
156
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SkLmn_XO : ATL_tmmTT_SkLmn_XO;
 
157
      #ifdef DEBUG
 
158
         printf("sKlMN_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
 
159
                 T2c(TA), T2c(TB), M, N, K, j, D);
 
160
      #endif
 
161
   }
 
162
   else if (M <= 256)          /* M dim small */
 
163
   {
 
164
SMALLM:
 
165
      D = Mmin(N,K);
 
166
      if (D >= NB+NB)
 
167
         D = (N+K)>>1;
 
168
      i = M;
 
169
      for (bip2=1; bip2 < i; bip2 <<= 1);
 
170
      smp2 = (bip2 == i) ? bip2 : (bip2>>1);
 
171
      i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
 
172
      for (j=0; j < 9; j++)
 
173
         if (i & (1<<j)) break;
 
174
      if (TA == AtlasNoTrans)
 
175
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SmLnk_XO : ATL_tmmNT_SmLnk_XO;
 
176
      else
 
177
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SmLnk_XO : ATL_tmmTT_SmLnk_XO;
 
178
      #ifdef DEBUG
 
179
         printf("sMlNK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
 
180
                T2c(TA), T2c(TB), M, N, K, j, D);
 
181
      #endif
 
182
   }
 
183
   else if (N <= 256)          /* N dim small */
 
184
   {                           /* QR uses this */
 
185
SMALLN:
 
186
      D = Mmin(M,K);
 
187
      if (D >= NB+NB)
 
188
         D = (M+K)>>1;
 
189
      i = N;
 
190
      for (bip2=1; bip2 < i; bip2 <<= 1);
 
191
      smp2 = (bip2 == i) ? bip2 : (bip2>>1);
 
192
      i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
 
193
      for (j=0; j < 9; j++)
 
194
         if (i & (1<<j)) break;
 
195
      if (TA == AtlasNoTrans)
 
196
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SnLmk_XO : ATL_tmmNT_SnLmk_XO;
 
197
      else
 
198
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SnLmk_XO : ATL_tmmTT_SnLmk_XO;
 
199
      #ifdef DEBUG
 
200
         printf("sNlMK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
 
201
                T2c(TA), T2c(TB), M, N, K, j, D);
 
202
      #endif
 
203
   }
 
204
   else                        /* all dim > 256, call it square */
 
205
   {
 
206
SQUARE:   /* near-square shape, N <= 256 if jumped here */
 
207
      D = (M+N+K+1)/3;
 
208
      j = 0;
 
209
      if (TA == AtlasNoTrans)
 
210
         xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SQmnk_XO : ATL_tmmNT_SQmnk_XO;
 
211
      else
 
212
         xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SQmnk_XO : ATL_tmmTT_SQmnk_XO;
 
213
      #ifdef DEBUG
 
214
         printf("SQ_%c%c, M=%d, N=%d, K=%d, D=%d\n",
 
215
                T2c(TA), T2c(TB), M, N, K, D);
 
216
      #endif
 
217
   }
 
218
 
 
219
   xop += j*ATL_PDIM;
 
220
   for (k=ATL_PDIM-1; k >= 0; k--)
 
221
      if (xop[k] && D >= xop[k])
 
222
         return((k == ATL_PDIM-1) ? ATL_NTHREADS : (2<<k));
 
223
   return(1);
 
224
}
 
225
#endif