~ubuntu-branches/ubuntu/trusty/r-cran-rcpparmadillo/trusty-proposed

« back to all changes in this revision

Viewing changes to inst/include/armadillo_bits/mul_syrk.hpp

  • Committer: Package Import Robot
  • Author(s): Dirk Eddelbuettel
  • Date: 2013-08-12 19:10:20 UTC
  • mfrom: (1.1.6)
  • Revision ID: package-import@ubuntu.com-20130812191020-4i0swxrz8v6i503v
Tags: 0.3.910.0-1
New upstream release

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// Copyright (C) 2013 Conrad Sanderson
 
2
// Copyright (C) 2013 NICTA (www.nicta.com.au)
 
3
// 
 
4
// This Source Code Form is subject to the terms of the Mozilla Public
 
5
// License, v. 2.0. If a copy of the MPL was not distributed with this
 
6
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
7
 
 
8
 
 
9
//! \addtogroup syrk
 
10
//! @{
 
11
 
 
12
 
 
13
 
 
14
class syrk_helper
 
15
  {
 
16
  public:
 
17
  
 
18
  template<typename eT>
 
19
  inline
 
20
  static
 
21
  void
 
22
  inplace_copy_upper_tri_to_lower_tri(Mat<eT>& C)
 
23
    {
 
24
    // under the assumption that C is a square matrix
 
25
    
 
26
    const uword N = C.n_rows;
 
27
    
 
28
    for(uword k=0; k < N; ++k)
 
29
      {
 
30
      eT* colmem = C.colptr(k);
 
31
      
 
32
      uword i, j;
 
33
      for(i=(k+1), j=(k+2); j < N; i+=2, j+=2)
 
34
        {
 
35
        const eT tmp_i = C.at(k,i);
 
36
        const eT tmp_j = C.at(k,j);
 
37
        
 
38
        colmem[i] = tmp_i;
 
39
        colmem[j] = tmp_j;
 
40
        }
 
41
      
 
42
      if(i < N)
 
43
        {
 
44
        colmem[i] = C.at(k,i);
 
45
        }
 
46
      }
 
47
    }
 
48
  };
 
49
 
 
50
 
 
51
 
 
52
//! partial emulation of BLAS function syrk(), specialised for A being a vector
 
53
template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
 
54
class syrk_vec
 
55
  {
 
56
  public:
 
57
  
 
58
  template<typename eT, typename TA>
 
59
  arma_hot
 
60
  inline
 
61
  static
 
62
  void
 
63
  apply
 
64
    (
 
65
          Mat<eT>& C,
 
66
    const TA&      A,
 
67
    const eT       alpha = eT(1),
 
68
    const eT       beta  = eT(0)
 
69
    )
 
70
    {
 
71
    arma_extra_debug_sigprint();
 
72
    
 
73
    const uword A_n1 = (do_trans_A == false) ? A.n_rows : A.n_cols;
 
74
    const uword A_n2 = (do_trans_A == false) ? A.n_cols : A.n_rows;
 
75
    
 
76
    const eT* A_mem = A.memptr();
 
77
    
 
78
    if(A_n1 == 1)
 
79
      {
 
80
      const eT acc1 = op_dot::direct_dot(A_n2, A_mem, A_mem);
 
81
      
 
82
           if( (use_alpha == false) && (use_beta == false) )  { C[0] =       acc1;             }
 
83
      else if( (use_alpha == true ) && (use_beta == false) )  { C[0] = alpha*acc1;             }
 
84
      else if( (use_alpha == false) && (use_beta == true ) )  { C[0] =       acc1 + beta*C[0]; }
 
85
      else if( (use_alpha == true ) && (use_beta == true ) )  { C[0] = alpha*acc1 + beta*C[0]; }
 
86
      }
 
87
    else
 
88
    for(uword k=0; k < A_n1; ++k)
 
89
      {
 
90
      const eT A_k = A_mem[k];
 
91
      
 
92
      uword i,j;
 
93
      for(i=(k), j=(k+1); j < A_n1; i+=2, j+=2)
 
94
        {
 
95
        const eT acc1 = A_k * A_mem[i];
 
96
        const eT acc2 = A_k * A_mem[j];
 
97
        
 
98
        if( (use_alpha == false) && (use_beta == false) )
 
99
          {
 
100
          C.at(k, i) = acc1;
 
101
          C.at(k, j) = acc2;
 
102
          
 
103
          C.at(i, k) = acc1;
 
104
          C.at(j, k) = acc2;
 
105
          }
 
106
        else
 
107
        if( (use_alpha == true ) && (use_beta == false) )
 
108
          {
 
109
          const eT val1 = alpha*acc1;
 
110
          const eT val2 = alpha*acc2;
 
111
          
 
112
          C.at(k, i) = val1;
 
113
          C.at(k, j) = val2;
 
114
          
 
115
          C.at(i, k) = val1;
 
116
          C.at(j, k) = val2;
 
117
          }
 
118
        else
 
119
        if( (use_alpha == false) && (use_beta == true) )
 
120
          {
 
121
          C.at(k, i) = acc1 + beta*C.at(k, i);
 
122
          C.at(k, j) = acc2 + beta*C.at(k, j);
 
123
          
 
124
          if(i != k) { C.at(i, k) = acc1 + beta*C.at(i, k); }
 
125
                       C.at(j, k) = acc2 + beta*C.at(j, k);
 
126
          }
 
127
        else
 
128
        if( (use_alpha == true ) && (use_beta == true) )
 
129
          {
 
130
          const eT val1 = alpha*acc1;
 
131
          const eT val2 = alpha*acc2;
 
132
          
 
133
          C.at(k, i) = val1 + beta*C.at(k, i);
 
134
          C.at(k, j) = val2 + beta*C.at(k, j);
 
135
          
 
136
          if(i != k)  { C.at(i, k) = val1 + beta*C.at(i, k); }
 
137
                        C.at(j, k) = val2 + beta*C.at(j, k);
 
138
          }
 
139
        }
 
140
      
 
141
      if(i < A_n1)
 
142
        {
 
143
        const eT acc1 = A_k * A_mem[i];
 
144
        
 
145
        if( (use_alpha == false) && (use_beta == false) )
 
146
          {
 
147
          C.at(k, i) = acc1;
 
148
          C.at(i, k) = acc1;
 
149
          }
 
150
        else
 
151
        if( (use_alpha == true) && (use_beta == false) )
 
152
          {
 
153
          const eT val1 = alpha*acc1;
 
154
          
 
155
          C.at(k, i) = val1;
 
156
          C.at(i, k) = val1;
 
157
          }
 
158
        else
 
159
        if( (use_alpha == false) && (use_beta == true) )
 
160
          {
 
161
                        C.at(k, i) = acc1 + beta*C.at(k, i);
 
162
          if(i != k)  { C.at(i, k) = acc1 + beta*C.at(i, k); }
 
163
          }
 
164
        else
 
165
        if( (use_alpha == true) && (use_beta == true) )
 
166
          {
 
167
          const eT val1 = alpha*acc1;
 
168
          
 
169
                        C.at(k, i) = val1 + beta*C.at(k, i);
 
170
          if(i != k)  { C.at(i, k) = val1 + beta*C.at(i, k); }
 
171
          }
 
172
        }
 
173
      }
 
174
    }
 
175
  
 
176
  };
 
177
 
 
178
 
 
179
 
 
180
//! partial emulation of BLAS function syrk()
 
181
template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
 
182
class syrk_emul
 
183
  {
 
184
  public:
 
185
  
 
186
  template<typename eT, typename TA>
 
187
  arma_hot
 
188
  inline
 
189
  static
 
190
  void
 
191
  apply
 
192
    (
 
193
          Mat<eT>& C,
 
194
    const TA&      A,
 
195
    const eT       alpha = eT(1),
 
196
    const eT       beta  = eT(0)
 
197
    )
 
198
    {
 
199
    arma_extra_debug_sigprint();
 
200
    
 
201
    // do_trans_A == false  ->   C = alpha * A   * A^T + beta*C
 
202
    // do_trans_A == true   ->   C = alpha * A^T * A   + beta*C
 
203
    
 
204
    if(do_trans_A == false)
 
205
      {
 
206
      Mat<eT> AA;
 
207
      
 
208
      op_strans::apply_noalias(AA, A);
 
209
      
 
210
      syrk_emul<true, use_alpha, use_beta>::apply(C, AA, alpha, beta);
 
211
      }
 
212
    else
 
213
    if(do_trans_A == true)
 
214
      {
 
215
      const uword A_n_rows = A.n_rows;
 
216
      const uword A_n_cols = A.n_cols;
 
217
      
 
218
      for(uword col_A=0; col_A < A_n_cols; ++col_A)
 
219
        {
 
220
        // col_A is interpreted as row_A when storing the results in matrix C
 
221
        
 
222
        const eT* A_coldata = A.colptr(col_A);
 
223
        
 
224
        for(uword k=col_A; k < A_n_cols; ++k)
 
225
          {
 
226
          const eT acc = op_dot::direct_dot_arma(A_n_rows, A_coldata, A.colptr(k));
 
227
          
 
228
          if( (use_alpha == false) && (use_beta == false) )
 
229
            {
 
230
            C.at(col_A, k) = acc;
 
231
            C.at(k, col_A) = acc;
 
232
            }
 
233
          else
 
234
          if( (use_alpha == true ) && (use_beta == false) )
 
235
            {
 
236
            const eT val = alpha*acc;
 
237
            
 
238
            C.at(col_A, k) = val;
 
239
            C.at(k, col_A) = val;
 
240
            }
 
241
          else
 
242
          if( (use_alpha == false) && (use_beta == true ) )
 
243
            {
 
244
                              C.at(col_A, k) = acc + beta*C.at(col_A, k);
 
245
            if(col_A != k)  { C.at(k, col_A) = acc + beta*C.at(k, col_A); }
 
246
            }
 
247
          else
 
248
          if( (use_alpha == true ) && (use_beta == true ) )
 
249
            {
 
250
            const eT val = alpha*acc;
 
251
            
 
252
                              C.at(col_A, k) = val + beta*C.at(col_A, k);
 
253
            if(col_A != k)  { C.at(k, col_A) = val + beta*C.at(k, col_A); }
 
254
            }
 
255
          }
 
256
        }
 
257
      }
 
258
    }
 
259
  
 
260
  };
 
261
 
 
262
 
 
263
 
 
264
template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
 
265
class syrk
 
266
  {
 
267
  public:
 
268
  
 
269
  template<typename eT, typename TA>
 
270
  inline
 
271
  static
 
272
  void
 
273
  apply_blas_type( Mat<eT>& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0) )
 
274
    {
 
275
    arma_extra_debug_sigprint();
 
276
    
 
277
    if(A.is_vec())
 
278
      {
 
279
      // work around poor handling of vectors by syrk() in ATLAS 3.8.4 and standard BLAS
 
280
      
 
281
      syrk_vec<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
 
282
      
 
283
      return;
 
284
      }
 
285
    
 
286
    const uword threshold = (is_cx<eT>::yes ? 16u : 48u);
 
287
    
 
288
    if( A.n_elem <= threshold )
 
289
      {
 
290
      syrk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
 
291
      }
 
292
    else
 
293
      {
 
294
      #if defined(ARMA_USE_ATLAS)
 
295
        {
 
296
        if(use_beta == true)
 
297
          {
 
298
          // use a temporary matrix, as we can't assume that matrix C is already symmetric
 
299
          Mat<eT> D(C.n_rows, C.n_cols);
 
300
          
 
301
          syrk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha);
 
302
          
 
303
          // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1
 
304
          arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem);
 
305
          
 
306
          return;
 
307
          }
 
308
        
 
309
        atlas::cblas_syrk<eT>
 
310
          (
 
311
          atlas::CblasColMajor,
 
312
          atlas::CblasUpper,
 
313
          (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
 
314
          C.n_cols,
 
315
          (do_trans_A) ? A.n_rows : A.n_cols,
 
316
          (use_alpha) ? alpha : eT(1),
 
317
          A.mem,
 
318
          (do_trans_A) ? A.n_rows : C.n_cols,
 
319
          (use_beta) ? beta : eT(0),
 
320
          C.memptr(),
 
321
          C.n_cols
 
322
          );
 
323
        
 
324
        syrk_helper::inplace_copy_upper_tri_to_lower_tri(C);
 
325
        }
 
326
      #elif defined(ARMA_USE_BLAS)
 
327
        {
 
328
        if(use_beta == true)
 
329
          {
 
330
          // use a temporary matrix, as we can't assume that matrix C is already symmetric
 
331
          Mat<eT> D(C.n_rows, C.n_cols);
 
332
          
 
333
          syrk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha);
 
334
          
 
335
          // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1
 
336
          arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem);
 
337
          
 
338
          return;
 
339
          }
 
340
        
 
341
        arma_extra_debug_print("blas::syrk()");
 
342
        
 
343
        const char uplo = 'U';
 
344
        
 
345
        const char trans_A = (do_trans_A) ? 'T' : 'N';
 
346
        
 
347
        const blas_int n = C.n_cols;
 
348
        const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
 
349
        
 
350
        const eT local_alpha = (use_alpha) ? alpha : eT(1);
 
351
        const eT local_beta  = (use_beta)  ? beta  : eT(0);
 
352
        
 
353
        const blas_int lda = (do_trans_A) ? k : n;
 
354
        
 
355
        arma_extra_debug_print( arma_boost::format("blas::syrk(): trans_A = %c") % trans_A );
 
356
        
 
357
        blas::syrk<eT>
 
358
          (
 
359
          &uplo,
 
360
          &trans_A,
 
361
          &n,
 
362
          &k,
 
363
          &local_alpha,
 
364
          A.mem,
 
365
          &lda,
 
366
          &local_beta,
 
367
          C.memptr(),
 
368
          &n // &ldc
 
369
          );
 
370
        
 
371
        syrk_helper::inplace_copy_upper_tri_to_lower_tri(C);
 
372
        }
 
373
      #else
 
374
        {
 
375
        syrk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
 
376
        }
 
377
      #endif
 
378
      }
 
379
    }
 
380
  
 
381
  
 
382
  
 
383
  template<typename eT, typename TA>
 
384
  inline
 
385
  static
 
386
  void
 
387
  apply( Mat<eT>& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0) )
 
388
    {
 
389
    if(is_cx<eT>::no)
 
390
      {
 
391
      if(A.is_vec())
 
392
        {
 
393
        syrk_vec<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
 
394
        }
 
395
      else
 
396
        {
 
397
        syrk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta);
 
398
        }
 
399
      }
 
400
    else
 
401
      {
 
402
      // handling of complex matrix by syrk_emul() is not yet implemented
 
403
      return;
 
404
      }
 
405
    }
 
406
  
 
407
  
 
408
  
 
409
  template<typename TA>
 
410
  arma_inline
 
411
  static
 
412
  void
 
413
  apply
 
414
    (
 
415
          Mat<float>& C,
 
416
    const TA&         A,
 
417
    const float alpha = float(1),
 
418
    const float beta  = float(0)
 
419
    )
 
420
    {
 
421
    syrk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta);
 
422
    }
 
423
  
 
424
  
 
425
  
 
426
  template<typename TA>
 
427
  arma_inline
 
428
  static
 
429
  void
 
430
  apply
 
431
    (
 
432
          Mat<double>& C,
 
433
    const TA&          A,
 
434
    const double alpha = double(1),
 
435
    const double beta  = double(0)
 
436
    )
 
437
    {
 
438
    syrk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta);
 
439
    }
 
440
  
 
441
  
 
442
  
 
443
  template<typename TA>
 
444
  arma_inline
 
445
  static
 
446
  void
 
447
  apply
 
448
    (
 
449
          Mat< std::complex<float> >& C,
 
450
    const TA&                         A,
 
451
    const std::complex<float> alpha = std::complex<float>(1),
 
452
    const std::complex<float> beta  = std::complex<float>(0)
 
453
    )
 
454
    {
 
455
    arma_ignore(C);
 
456
    arma_ignore(A);
 
457
    arma_ignore(alpha);
 
458
    arma_ignore(beta);
 
459
    
 
460
    // handling of complex matrix by syrk() is not yet implemented
 
461
    return;
 
462
    }
 
463
  
 
464
  
 
465
  
 
466
  template<typename TA>
 
467
  arma_inline
 
468
  static
 
469
  void
 
470
  apply
 
471
    (
 
472
          Mat< std::complex<double> >& C,
 
473
    const TA&                          A,
 
474
    const std::complex<double> alpha = std::complex<double>(1),
 
475
    const std::complex<double> beta  = std::complex<double>(0)
 
476
    )
 
477
    {
 
478
    arma_ignore(C);
 
479
    arma_ignore(A);
 
480
    arma_ignore(alpha);
 
481
    arma_ignore(beta);
 
482
    
 
483
    // handling of complex matrix by syrk() is not yet implemented
 
484
    return;
 
485
    }
 
486
  
 
487
  };
 
488
 
 
489
 
 
490
 
 
491
//! @}