~ubuntu-branches/ubuntu/vivid/atlas/vivid

« back to all changes in this revision

Viewing changes to src/threads/blas/level3/ATL_tgemm_M.c

  • Committer: Package Import Robot
  • Author(s): Sébastien Villemot
  • Date: 2013-06-11 15:58:16 UTC
  • mfrom: (1.1.3 upstream)
  • mto: (2.2.21 experimental)
  • mto: This revision was merged to the branch mainline in revision 26.
  • Revision ID: package-import@ubuntu.com-20130611155816-b72z8f621tuhbzn0
Tags: upstream-3.10.1
Import upstream version 3.10.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include "atlas_misc.h"
 
2
#include "atlas_threads.h"
 
3
#include "atlas_tlvl3.h"
 
4
/*
 
5
 * prototype the typeless tGEMM helper routines
 
6
 */
 
7
void ATL_DoWorkMM(ATL_LAUNCHSTRUCT_t *lp, void *vp);
 
8
int ATL_StructIsInitMM(void *vp);
 
9
void ATL_linearize_mmnodes(ATL_TMMNODE_t *ptmms, const int P);
 
10
int ATL_thrdecompMM_rMNK
 
11
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
12
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
13
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
14
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
15
int ATL_thrdecompMM_K
 
16
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
17
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
18
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
19
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
20
int ATL_thrdecompMM_N
 
21
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
22
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
23
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
24
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
25
int ATL_thrdecompMM_M
 
26
   (ATL_TMMNODE_t *ptmms, const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
27
    ATL_CINT Mblks, const int mr, ATL_CINT Nblks, const int nr, ATL_CINT Kblks,
 
28
    const int kr, const void *A, ATL_INT lda, const void *B, const ATL_INT ldb,
 
29
    const void *C, ATL_CINT ldc, const int P, const int indx, const int COPYC);
 
30
void Mjoin(PATL,InitTMMNodes)
 
31
   (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, const TYPE *alpha,
 
32
    const TYPE *beta, const TYPE *one, const TYPE *zero,
 
33
    ATL_thread_t *btp, ATL_TMMNODE_t *ptmms);
 
34
int Mjoin(PATL,tgemm_M)(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 
35
                       ATL_CINT M, ATL_CINT N, ATL_CINT K, const SCALAR alpha,
 
36
                       const TYPE *A, ATL_CINT lda, const TYPE *B, ATL_CINT ldb,
 
37
                       const SCALAR beta, TYPE *C, ATL_CINT ldc)
 
38
{
 
39
   #ifdef ATL_SERIAL_COMBINE
 
40
      ATL_combnode_t *combb=NULL, *combp;
 
41
   #endif
 
42
   ATL_TMMNODE_t mms[ATL_NTHREADS];
 
43
   int i, np, DividedK=0;
 
44
   #ifdef TREAL
 
45
      TYPE ONE=ATL_rone, ZERO=ATL_rzero;
 
46
   #else
 
47
      TYPE ONE[2] = {ATL_rone, ATL_rzero}, ZERO[2] = {ATL_rzero, ATL_rzero};
 
48
   #endif
 
49
 
 
50
   if (M < 1 || N < 1)
 
51
      return(0);
 
52
   if (K < 1 || SCALAR_IS_ZERO(alpha))
 
53
   {
 
54
      if (!SCALAR_IS_ONE(beta))
 
55
         Mjoin(PATL,gescal)(M, N, beta, C, ldc);
 
56
      return(0);
 
57
   }
 
58
   Mjoin(PATL,InitTMMNodes)(TA, TB, SADD alpha, SADD beta, SADD ONE,
 
59
                            SADD ZERO, NULL, mms);
 
60
   np = ATL_thrdecompMM_M(mms, TA, TB, M/MB, M%MB, N/NB, N%NB, K/KB, K%KB,
 
61
                          A, lda, B, ldb, C, ldc, ATL_NTHREADS, 0, 0);
 
62
   if (np < ATL_NTHREADS)
 
63
      ATL_linearize_mmnodes(mms, np);
 
64
#ifdef DEBUG
 
65
fprintf(stderr, "np=%d\n\n", np);
 
66
#endif
 
67
   if (np < 2)
 
68
   {
 
69
      Mjoin(PATL,gemm)(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
 
70
      return(1);
 
71
   }
 
72
/*
 
73
 * If we are debugging, set up serial combine queue
 
74
 */
 
75
   #ifdef ATL_SERIAL_COMBINE
 
76
      for (i=0; i < ATL_NTHREADS; i++)
 
77
      {
 
78
         if (mms[i].K)   /* if this struct being used */
 
79
         {
 
80
            if (!mms[i].ownC)   /* I need a workspace for C */
 
81
            {
 
82
               mms[i].Cw = calloc(mms[i].ldcw * mms[i].N, ATL_sizeof);
 
83
               ATL_assert(mms[i].Cw);
 
84
               combb = ATL_NewCombnode(mms[i].M, mms[i].N, mms[i].Cw,
 
85
                                       mms[i].ldcw, mms[i].C, mms[i].ldc,
 
86
                                       combb);
 
87
            }
 
88
         }
 
89
      }
 
90
   #endif
 
91
 
 
92
   ATL_goparallel(np, ATL_DoWorkMM, mms, NULL);
 
93
/*
 
94
 * If we are debugging, serially combine all workspaces back to original C
 
95
 */
 
96
   #ifdef ATL_SERIAL_COMBINE
 
97
      while(combb)
 
98
      {
 
99
         Mjoin(PATL,geadd)(combb->M, combb->N, ONE, combb->W, combb->ldw,
 
100
                           ONE, combb->D, combb->ldd);
 
101
         free(combb->W);
 
102
         combp = combb;
 
103
         combb = combb->next;
 
104
         free(combp);
 
105
      }
 
106
   #endif
 
107
   return(np);
 
108
}