~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to src/threads/blas/level3/ATL_tgemm_rec.c

  • Committer: Package Import Robot
  • Author(s): Sébastien Villemot
  • Date: 2013-06-11 15:58:16 UTC
  • mfrom: (1.1.3 upstream)
  • mto: (2.2.21 experimental)
  • mto: This revision was merged to the branch mainline in revision 26.
  • Revision ID: package-import@ubuntu.com-20130611155816-b72z8f621tuhbzn0
Tags: upstream-3.10.1
Import upstream version 3.10.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include "atlas_misc.h"
 
2
#include "atlas_threads.h"
 
3
#include "atlas_tlvl3.h"
 
4
/*
 
5
 * prototype the typeless tGEMM helper routines
 
6
 */
 
7
void ATL_DoWorkMM(ATL_LAUNCHSTRUCT_t *lp, void *vp);
 
8
int ATL_StructIsInitMM(void *vp);
 
9
void ATL_linearize_mmnodes(ATL_TMMNODE_t *ptmms, const int P);
 
10
int ATL_thrdecompMM_rMNK
 
11
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
12
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
13
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
14
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
15
int ATL_thrdecompMM_K
 
16
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
17
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
18
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
19
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
20
int ATL_thrdecompMM_N
 
21
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
22
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
23
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
24
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
25
int ATL_thrdecompMM_M
 
26
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
27
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
28
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
29
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
30
void Mjoin(PATL,InitTMMNodes)
 
31
   (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, const TYPE *alpha,
 
32
    const TYPE *beta, const TYPE *one, const TYPE *zero,
 
33
    ATL_thread_t *btp, ATL_TMMNODE_t *ptmms);
 
34
void Mjoin(PATL,HandleNewCp)(ATL_TMMNODE_t *me, ATL_TMMNODE_t *him);
 
35
int Mjoin(PATL,CombineCw)(ATL_TMMNODE_t *me, ATL_TMMNODE_t *him);
 
36
void Mjoin(PATL,CombineStructsMM)(void *vp, const int myrank, const int herank);
 
37
int Mjoin(PATL,tgemm_rec)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
38
                       ATL_CINT M, ATL_CINT N, ATL_CINT K, const SCALAR alpha,
 
39
                       const TYPE *A, ATL_CINT lda, const TYPE *B, ATL_CINT ldb,
 
40
                       const SCALAR beta, TYPE *C, ATL_CINT ldc)
 
41
{
 
42
   #ifdef ATL_SERIAL_COMBINE
 
43
      ATL_combnode_t *combb=NULL, *combp;
 
44
   #endif
 
45
   ATL_TMMNODE_t mms[ATL_NTHREADS];
 
46
   int i, np, DividedK=0;
 
47
   #ifdef TREAL
 
48
      TYPE ONE=ATL_rone, ZERO=ATL_rzero;
 
49
   #else
 
50
      TYPE ONE[2] = {ATL_rone, ATL_rzero}, ZERO[2] = {ATL_rzero, ATL_rzero};
 
51
   #endif
 
52
 
 
53
   if (M < 1 || N < 1)
 
54
      return(0);
 
55
   if (K < 1 || SCALAR_IS_ZERO(alpha))
 
56
   {
 
57
      if (!SCALAR_IS_ONE(beta))
 
58
         Mjoin(PATL,gescal)(M, N, beta, C, ldc);
 
59
      return(0);
 
60
   }
 
61
   Mjoin(PATL,InitTMMNodes)(TA, TB, SADD alpha, SADD beta, SADD ONE,
 
62
                            SADD ZERO, NULL, mms);
 
63
   np = ATL_thrdecompMM_rMNK(mms, TA, TB, M/MB, M%MB, N/NB, N%NB, K/KB, K%KB,
 
64
                             A, lda, B, ldb, C, ldc, ATL_NTHREADS, 0, 0);
 
65
   if (np < ATL_NTHREADS)
 
66
      ATL_linearize_mmnodes(mms, np);
 
67
#ifdef DEBUG
 
68
fprintf(stderr, "np=%d\n\n", np);
 
69
#endif
 
70
   if (np < 2)
 
71
   {
 
72
      Mjoin(PATL,gemm)(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
 
73
      return(1);
 
74
   }
 
75
/*
 
76
 * If we are debugging, set up serial combine queue
 
77
 */
 
78
   #ifdef ATL_SERIAL_COMBINE
 
79
      for (i=0; i < ATL_NTHREADS; i++)
 
80
      {
 
81
         if (mms[i].K)   /* if this struct being used */
 
82
         {
 
83
            if (!mms[i].ownC)   /* I need a workspace for C */
 
84
            {
 
85
               mms[i].Cw = calloc(mms[i].ldcw * mms[i].N, ATL_sizeof);
 
86
               ATL_assert(mms[i].Cw);
 
87
               combb = ATL_NewCombnode(mms[i].M, mms[i].N, mms[i].Cw,
 
88
                                       mms[i].ldcw, mms[i].C, mms[i].ldc,
 
89
                                       combb);
 
90
            }
 
91
         }
 
92
      }
 
93
   #endif
 
94
 
 
95
   ATL_goparallel(np, ATL_DoWorkMM, mms,
 
96
                  DividedK ? Mjoin(PATL,CombineStructsMM) : NULL);
 
97
/*
 
98
 * If we are debugging, serially combine all workspaces back to original C
 
99
 */
 
100
   #ifdef ATL_SERIAL_COMBINE
 
101
      while(combb)
 
102
      {
 
103
         Mjoin(PATL,geadd)(combb->M, combb->N, ONE, combb->W, combb->ldw,
 
104
                           ONE, combb->D, combb->ldd);
 
105
         free(combb->W);
 
106
         combp = combb;
 
107
         combb = combb->next;
 
108
         free(combp);
 
109
      }
 
110
   #endif
 
111
   return(np);
 
112
}