1
#include "atlas_misc.h"
2
#include "atlas_threads.h"
3
#include "atlas_tlvl3.h"
6
/*
 * T2c: map a transpose enum to its BLAS character -- 'N' when ta_ equals
 * AtlasNoTrans, 'T' otherwise.  The whole expansion is parenthesized so the
 * conditional operator cannot be split apart by operators at the call site
 * (without it, `1 + T2c(x)` would parse as `(1 + (x == AtlasNoTrans)) ? ...`).
 */
#define T2c(ta_) (((ta_) == AtlasNoTrans) ? 'N' : 'T')
9
int Mjoin(PATL,threadMM)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
10
size_t M, size_t N, size_t K)
12
* This dummy routine is used when crossover is not tuned
22
if (M >= (NB<<(ATL_NTHRPOW2+2)))
24
else if (minD >= 8 && maxD >= 2*NB)
28
int Mjoin(PATL,GemmWillThread)(ATL_CINT M, ATL_CINT N, ATL_CINT K);
29
return(Mjoin(PATL,GemmWillThread)(M, N, K));
33
int Mjoin(PATL,threadMM)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
34
size_t M, size_t N, size_t K)
36
* RETURNS: number of threads matmul should use to parallelize the problem
39
size_t i, j, smp2, bip2, xo, xom, D;
42
if (M < 256 && N < 256 && K < 256) /* small matrix */
45
* For really small problems, table lookups are too expensive, so do a quick
51
if (j <= NB+NB || i < NB)
52
return(1); /* quick return */
54
* Make choice based on most restricted dimension
56
if (M < N && M < K) /* M most restricted dim */
58
else if (K < M && K < N) /* K most restricted dim */
60
else if (M == N && M == K)
62
else /* N is most restricted dim */
66
* The following three shapes model recursive factorizations where
67
* two dimensions are cut during the recursion, and a third remains large
69
else if (N <= 256 && K <= 256) /* recursive shape that doesn't cut M */
70
{ /* LU uses this shape */
76
i = (i+i+i+j)>>2; /* 3/4 MIN, 1/4 MAX */
77
for (bip2=1; bip2 < i; bip2 <<= 1);
78
smp2 = (bip2 == i) ? bip2 : (bip2>>1);
79
i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
81
if (i & (1<<j)) break;
83
if (TA == AtlasNoTrans)
84
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SnkLm_XO : ATL_tmmNT_SnkLm_XO;
86
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SnkLm_XO : ATL_tmmTT_SnkLm_XO;
88
printf("sNKlM_%c%c, M=%d, N=%d, K=%d rD=%d, D=%d\n",
89
T2c(TA), T2c(TB), M, N, K, j, D);
92
else if (M <= 256 && N <= 256) /* recursive shape that doesn't cut K */
93
{ /* QR uses, maybe in LARFT? */
99
i = (i+i+i+j)>>2; /* 3/4 MIN, 1/4 MAX */
100
for (bip2=1; bip2 < i; bip2 <<= 1);
101
smp2 = (bip2 == i) ? bip2 : (bip2>>1);
102
i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
103
for (j=0; j < 9; j++)
104
if (i & (1<<j)) break;
106
if (TA == AtlasNoTrans)
107
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SmnLk_XO : ATL_tmmNT_SmnLk_XO;
109
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SmnLk_XO : ATL_tmmTT_SmnLk_XO;
111
printf("sMNlK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
112
T2c(TA), T2c(TB), M, N, K, j, D);
115
else if (M <= 256 && K <= 256) /* recursive shape that doesn't cut N */
116
{ /* UNCONFIRMED: QR variant uses */
122
i = (i+i+i+j)>>2; /* 3/4 MIN, 1/4 MAX */
123
for (bip2=1; bip2 < i; bip2 <<= 1);
124
smp2 = (bip2 == i) ? bip2 : (bip2>>1);
125
for (j=0; j < 9; j++)
126
if (i & (1<<j)) break;
128
if (TA == AtlasNoTrans)
129
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SmkLn_XO : ATL_tmmNT_SmkLn_XO;
131
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SmkLn_XO : ATL_tmmTT_SmkLn_XO;
133
printf("sNlMK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
134
T2c(TA), T2c(TB), M, N, K, j, D);
138
* The three following shapes model static blocking, where two dimensions
139
* are full, and the third is blocked
141
else if (K <= 256) /* K dim small, as in right-looking LU/QR */
148
for (bip2=1; bip2 < i; bip2 <<= 1);
149
smp2 = (bip2 == i) ? bip2 : (bip2>>1);
150
i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
151
for (j=0; j < 9; j++)
152
if (i & (1<<j)) break;
153
if (TA == AtlasNoTrans)
154
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SkLmn_XO : ATL_tmmNT_SkLmn_XO;
156
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SkLmn_XO : ATL_tmmTT_SkLmn_XO;
158
printf("sKlMN_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
159
T2c(TA), T2c(TB), M, N, K, j, D);
162
else if (M <= 256) /* M dim small */
169
for (bip2=1; bip2 < i; bip2 <<= 1);
170
smp2 = (bip2 == i) ? bip2 : (bip2>>1);
171
i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
172
for (j=0; j < 9; j++)
173
if (i & (1<<j)) break;
174
if (TA == AtlasNoTrans)
175
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SmLnk_XO : ATL_tmmNT_SmLnk_XO;
177
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SmLnk_XO : ATL_tmmTT_SmLnk_XO;
179
printf("sMlNK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
180
T2c(TA), T2c(TB), M, N, K, j, D);
183
else if (N <= 256) /* N dim small */
190
for (bip2=1; bip2 < i; bip2 <<= 1);
191
smp2 = (bip2 == i) ? bip2 : (bip2>>1);
192
i = (bip2-i < i-smp2 && i > 16) ? bip2 : smp2;
193
for (j=0; j < 9; j++)
194
if (i & (1<<j)) break;
195
if (TA == AtlasNoTrans)
196
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SnLmk_XO : ATL_tmmNT_SnLmk_XO;
198
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SnLmk_XO : ATL_tmmTT_SnLmk_XO;
200
printf("sNlMK_%c%c, M=%d, N=%d, K=%d, rD=%d, D=%d\n",
201
T2c(TA), T2c(TB), M, N, K, j, D);
204
else /* all dim > 256, call it square */
206
SQUARE: /* near-square shape, N <= 256 if jumped here */
209
if (TA == AtlasNoTrans)
210
xop = (TB == AtlasNoTrans) ? ATL_tmmNN_SQmnk_XO : ATL_tmmNT_SQmnk_XO;
212
xop = (TB == AtlasNoTrans) ? ATL_tmmTN_SQmnk_XO : ATL_tmmTT_SQmnk_XO;
214
printf("SQ_%c%c, M=%d, N=%d, K=%d, D=%d\n",
215
T2c(TA), T2c(TB), M, N, K, D);
220
for (k=ATL_PDIM-1; k >= 0; k--)
221
if (xop[k] && D >= xop[k])
222
return((k == ATL_PDIM-1) ? ATL_NTHREADS : (2<<k));