1
#include "atlas_misc.h"
2
#include "atlas_threads.h"
3
#include "atlas_tlvl3.h"
5
/* Prototype the typeless tGEMM helper routines */
7
/*
 * Work function handed to ATL_goparallel by tgemm_M in this file, together
 * with the array of ATL_TMMNODE_t nodes.
 *   lp: pointer to an ATL_LAUNCHSTRUCT_t
 *   vp: untyped pointer
 * NOTE(review): vp presumably refers to the per-thread ATL_TMMNODE_t work
 * description -- confirm against the definition.
 */
void ATL_DoWorkMM(ATL_LAUNCHSTRUCT_t *lp, void *vp);
8
/*
 * Takes an untyped pointer and returns an int; not called in the visible
 * part of this file.
 * NOTE(review): the name suggests a predicate reporting whether a node
 * structure is in use -- confirm against the definition.
 */
int ATL_StructIsInitMM(void *vp);
9
/*
 * Takes the node array and a count P.  tgemm_M in this file calls it with the
 * number of nodes returned by the decomposition (np) when that number is
 * smaller than ATL_NTHREADS.
 * NOTE(review): presumably packs the in-use nodes contiguously at the front
 * of ptmms -- confirm against the definition.
 */
void ATL_linearize_mmnodes(ATL_TMMNODE_t *ptmms, const int P);
10
/*
 * Typeless decomposition helper: the matrix operands are passed as
 * void pointers, and the results are recorded in the ptmms node array.
 *   ptmms      : array of ATL_TMMNODE_t nodes
 *   TA, TB     : transpose settings for A and B
 *   Mblks, mr  : NOTE(review): presumably full-block count and remainder in
 *   Nblks, nr    each dimension; tgemm_M in this file passes M/MB, M%MB,
 *   Kblks, kr    N/NB, N%NB, K/KB, K%KB in these positions to a sibling
 *                decomposition routine -- confirm
 *   A, lda     : first operand and its leading dimension
 *   B, ldb     : second operand and its leading dimension
 *   C, ldc     : result operand and its leading dimension
 *   P          : NOTE(review): presumably the number of threads available;
 *                the sibling call passes ATL_NTHREADS -- confirm
 *   indx, COPYC: the sibling call passes 0 for both; meaning not visible here
 * Returns an int.  NOTE(review): presumably the number of nodes used, as the
 * sibling call's result is compared against ATL_NTHREADS -- confirm.
 */
int ATL_thrdecompMM_rMNK
11
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
12
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
13
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
14
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
16
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
17
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
18
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
19
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
21
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
22
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
23
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
24
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
26
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
27
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
28
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
29
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
30
/*
 * Type-specific node initializer; its name is built by Mjoin from PATL and
 * InitTMMNodes.  tgemm_M in this file calls it before decomposing the
 * problem, passing:
 *   TA, TB     : the caller's transpose settings
 *   alpha, beta: addresses of the caller's scalars (via SADD)
 *   one, zero  : addresses of local ONE / ZERO constants (via SADD)
 *   btp        : NULL
 *   ptmms      : the local array of ATL_NTHREADS nodes
 * NOTE(review): presumably stores these values in every node of ptmms --
 * confirm against the definition.
 */
void Mjoin(PATL,InitTMMNodes)
31
(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, const TYPE *alpha,
32
const TYPE *beta, const TYPE *one, const TYPE *zero,
33
ATL_thread_t *btp, ATL_TMMNODE_t *ptmms);
34
int Mjoin(PATL,tgemm_M)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
35
ATL_CINT M, ATL_CINT N, ATL_CINT K, const SCALAR alpha,
36
const TYPE *A, ATL_CINT lda, const TYPE *B, ATL_CINT ldb,
37
const SCALAR beta, TYPE *C, ATL_CINT ldc)
39
#ifdef ATL_SERIAL_COMBINE
40
ATL_combnode_t *combb=NULL, *combp;
42
ATL_TMMNODE_t mms[ATL_NTHREADS];
43
int i, np, DividedK=0;
45
TYPE ONE=ATL_rone, ZERO=ATL_rzero;
47
TYPE ONE[2] = {ATL_rone, ATL_rzero}, ZERO[2] = {ATL_rzero, ATL_rzero};
52
if (K < 1 || SCALAR_IS_ZERO(alpha))
54
if (!SCALAR_IS_ONE(beta))
55
Mjoin(PATL,gescal)(M, N, beta, C, ldc);
58
Mjoin(PATL,InitTMMNodes)(TA, TB, SADD alpha, SADD beta, SADD ONE,
59
SADD ZERO, NULL, mms);
60
np = ATL_thrdecompMM_M(mms, TA, TB, M/MB, M%MB, N/NB, N%NB, K/KB, K%KB,
61
A, lda, B, ldb, C, ldc, ATL_NTHREADS, 0, 0);
62
if (np < ATL_NTHREADS)
63
ATL_linearize_mmnodes(mms, np);
65
fprintf(stderr, "np=%d\n\n", np);
69
Mjoin(PATL,gemm)(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
73
/* If we are debugging, set up the serial combine queue */
75
#ifdef ATL_SERIAL_COMBINE
76
for (i=0; i < ATL_NTHREADS; i++)
78
if (mms[i].K) /* if this struct being used */
80
if (!mms[i].ownC) /* I need a workspace for C */
82
mms[i].Cw = calloc(mms[i].ldcw * mms[i].N, ATL_sizeof);
83
ATL_assert(mms[i].Cw);
84
combb = ATL_NewCombnode(mms[i].M, mms[i].N, mms[i].Cw,
85
mms[i].ldcw, mms[i].C, mms[i].ldc,
92
ATL_goparallel(np, ATL_DoWorkMM, mms, NULL);
94
/* If we are debugging, serially combine all workspaces back into the original C */
96
#ifdef ATL_SERIAL_COMBINE
99
Mjoin(PATL,geadd)(combb->M, combb->N, ONE, combb->W, combb->ldw,
100
ONE, combb->D, combb->ldd);