1
#include "atlas_misc.h"
2
#include "atlas_threads.h"
3
#include "atlas_tlvl3.h"
5
/* Prototypes for the typeless tGEMM helper routines */
7
/*
 * Typeless tGEMM helper.  Within this file, tgemm_rec hands this routine to
 * ATL_goparallel together with the array of ATL_TMMNODE_t nodes.
 *   lp : launch structure (ATL_LAUNCHSTRUCT_t)
 *   vp : opaque pointer
 * NOTE(review): presumably the per-thread GEMM work function, with vp
 * pointing at the calling thread's node -- confirm against the definition,
 * which is not visible here.
 */
void ATL_DoWorkMM(ATL_LAUNCHSTRUCT_t *lp, void *vp);
8
/*
 * Typeless tGEMM helper taking an opaque pointer and returning an int.
 * NOTE(review): the name suggests it reports whether the matmul node behind
 * vp has been initialized / is in use -- confirm against the definition,
 * which is not visible here.
 */
int ATL_StructIsInitMM(void *vp);
9
/*
 * Typeless tGEMM helper operating on an array of matmul nodes.
 *   ptmms : node array
 *   P     : node count; tgemm_rec passes the value returned by
 *           ATL_thrdecompMM_rMNK, and only calls this when that value is
 *           smaller than ATL_NTHREADS.
 * NOTE(review): presumably compacts the P nodes in use into the leading
 * entries of ptmms -- confirm against the definition.
 */
void ATL_linearize_mmnodes(ATL_TMMNODE_t *ptmms, const int P);
10
/*
 * Typeless helper that fills in the matmul node array for a GEMM problem.
 * Parameter roles below are taken from the call made by tgemm_rec:
 *   ptmms            : node array to fill in
 *   TA, TB           : transpose settings for A and B
 *   Mblks, mr        : M/MB and M%MB (full blocks and remainder along M)
 *   Nblks, nr        : N/NB and N%NB
 *   Kblks, kr        : K/KB and K%KB
 *   A, lda           : A operand and its leading dimension
 *   B, ldb           : B operand and its leading dimension
 *   C, ldc           : C operand and its leading dimension
 *   P                : tgemm_rec passes ATL_NTHREADS
 *   indx, COPYC      : tgemm_rec passes 0 for both
 * RETURNS: an int that tgemm_rec stores as np, compares against
 *          ATL_NTHREADS, and passes as the first argument to ATL_goparallel.
 * NOTE(review): presumably the number of nodes/threads actually assigned
 * work by a recursive decomposition over M, N and K -- confirm against the
 * definition, which is not visible here.
 */
int ATL_thrdecompMM_rMNK
11
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
12
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
13
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
14
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
16
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
17
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
18
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
19
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
21
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
22
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
23
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
24
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
26
(ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
27
ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
28
const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
29
const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
30
/*
 * Type-specific (PATL-prefixed) initializer for the matmul node array.
 * Parameter roles below are taken from the call made by tgemm_rec:
 *   TA, TB      : transpose settings for A and B
 *   alpha, beta : pointers to the GEMM scalars
 *   one, zero   : pointers to TYPE constants holding one and zero
 *   btp         : thread array (ATL_thread_t); tgemm_rec passes NULL
 *   ptmms       : node array; tgemm_rec passes its local array of
 *                 ATL_NTHREADS nodes
 * NOTE(review): the exact fields written in each node are defined
 * elsewhere -- confirm against the definition.
 */
void Mjoin(PATL,InitTMMNodes)
31
(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, const TYPE *alpha,
32
const TYPE *beta, const TYPE *one, const TYPE *zero,
33
ATL_thread_t *btp, ATL_TMMNODE_t *ptmms);
34
/*
 * Type-specific helper operating on a pair of matmul nodes (me, him).
 * NOTE(review): presumably part of combining per-thread results for C when
 * K has been divided -- confirm against the definition, which is not
 * visible here.
 */
void Mjoin(PATL,HandleNewCp)(ATL_TMMNODE_t *me, ATL_TMMNODE_t *him);
35
/*
 * Type-specific helper operating on a pair of matmul nodes (me, him) and
 * returning an int.
 * NOTE(review): presumably merges him's C workspace into me's and reports
 * whether the combine took place -- confirm the meaning of the return value
 * against the definition, which is not visible here.
 */
int Mjoin(PATL,CombineCw)(ATL_TMMNODE_t *me, ATL_TMMNODE_t *him);
36
/*
 * Type-specific combine callback.  tgemm_rec passes this routine as the
 * last argument of ATL_goparallel when its DividedK flag is nonzero, and
 * passes NULL otherwise.
 *   vp      : opaque pointer
 *   myrank  : rank of the calling side
 *   herank  : rank of the partner
 * NOTE(review): vp is presumably the node array given to ATL_goparallel,
 * with the two ranks indexing into it -- confirm against the definition.
 */
void Mjoin(PATL,CombineStructsMM)(void *vp, const int myrank, const int herank);
37
int Mjoin(PATL,tgemm_rec)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
38
ATL_CINT M, ATL_CINT N, ATL_CINT K, const SCALAR alpha,
39
const TYPE *A, ATL_CINT lda, const TYPE *B, ATL_CINT ldb,
40
const SCALAR beta, TYPE *C, ATL_CINT ldc)
42
#ifdef ATL_SERIAL_COMBINE
43
ATL_combnode_t *combb=NULL, *combp;
45
ATL_TMMNODE_t mms[ATL_NTHREADS];
46
int i, np, DividedK=0;
48
TYPE ONE=ATL_rone, ZERO=ATL_rzero;
50
TYPE ONE[2] = {ATL_rone, ATL_rzero}, ZERO[2] = {ATL_rzero, ATL_rzero};
55
if (K < 1 || SCALAR_IS_ZERO(alpha))
57
if (!SCALAR_IS_ONE(beta))
58
Mjoin(PATL,gescal)(M, N, beta, C, ldc);
61
Mjoin(PATL,InitTMMNodes)(TA, TB, SADD alpha, SADD beta, SADD ONE,
62
SADD ZERO, NULL, mms);
63
np = ATL_thrdecompMM_rMNK(mms, TA, TB, M/MB, M%MB, N/NB, N%NB, K/KB, K%KB,
64
A, lda, B, ldb, C, ldc, ATL_NTHREADS, 0, 0);
65
if (np < ATL_NTHREADS)
66
ATL_linearize_mmnodes(mms, np);
68
fprintf(stderr, "np=%d\n\n", np);
72
Mjoin(PATL,gemm)(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
76
/* If we are debugging, set up serial combine queue */
78
#ifdef ATL_SERIAL_COMBINE
79
for (i=0; i < ATL_NTHREADS; i++)
81
if (mms[i].K) /* if this struct being used */
83
if (!mms[i].ownC) /* I need a workspace for C */
85
mms[i].Cw = calloc(mms[i].ldcw * mms[i].N, ATL_sizeof);
86
ATL_assert(mms[i].Cw);
87
combb = ATL_NewCombnode(mms[i].M, mms[i].N, mms[i].Cw,
88
mms[i].ldcw, mms[i].C, mms[i].ldc,
95
ATL_goparallel(np, ATL_DoWorkMM, mms,
96
DividedK ? Mjoin(PATL,CombineStructsMM) : NULL);
98
/* If we are debugging, serially combine all workspaces back to original C */
100
#ifdef ATL_SERIAL_COMBINE
103
Mjoin(PATL,geadd)(combb->M, combb->N, ONE, combb->W, combb->ldw,
104
ONE, combb->D, combb->ldd);