~nickpapior/siesta/tddft-work

« back to all changes in this revision

Viewing changes to Src/m_mat_invert.F90

  • Committer: Rafi Ullah
  • Date: 2017-09-27 12:02:36 UTC
  • mfrom: (611.1.22 trunk)
  • Revision ID: rraffiu@gmail.com-20170927120236-68zal54nt0bu1jxp
Merged with trunk-633

Show diffs side-by-side

added added

removed removed

Lines of Context:
23
23
  integer, public, parameter :: MI_IN_PLACE_RECURSIVE = 2
24
24
  integer, public, parameter :: MI_WORK = 3
25
25
 
26
 
  ! The maximum dimensionality of the problem before we turn to a direct inversion algorithm
27
 
  integer :: N_MAX = 40
28
 
 
29
26
  ! Used for BLAS calls (local variables)
30
27
  complex(dp), parameter :: z0  = dcmplx( 0._dp, 0._dp)
31
28
  complex(dp), parameter :: z1  = dcmplx( 1._dp, 0._dp)
42
39
 
43
40
    integer :: lmethod, lierr
44
41
    
45
 
 
46
42
    if ( present(method) ) then
47
43
       lmethod = method
48
44
    else
59
55
    if ( present(ierr) ) ierr = 0
60
56
 
61
57
    select case ( lmethod )
62
 
    case (  MI_IN_PLACE_LAPACK ) 
 
58
    case ( MI_IN_PLACE_LAPACK ) 
63
59
       call zgetrf(no, no, M, no, ipiv, lierr )
64
60
       if ( lierr /= 0 ) then
65
61
          if ( present(ierr) ) then
97
93
 
98
94
  end subroutine mat_invert
99
95
 
 
96
#ifdef _NOT_USED
100
97
  recursive subroutine mat_invert_recursive(M, zwork, no, ierr)
101
98
    use intrinsic_missing, only : EYE
102
99
    integer, intent(in) :: no ! Size of problem
103
100
    complex(dp), target :: M(no*no), zwork(no*no)
104
101
    integer, intent(out), optional :: ierr
105
102
 
 
103
    ! The maximum dimensionality of the problem before we turn to a direct inversion algorithm
 
104
    integer, parameter :: N_MAX = 64
 
105
 
106
106
    complex(dp), pointer :: A1(:), C2(:), B1(:), A2(:)
107
107
    complex(dp), pointer :: t1(:), t2(:)
108
108
    integer :: sA1, eA1, sC2, eC2
335
335
    end do
336
336
 
337
337
  end subroutine mat_invert_recursive
 
338
#endif
 
339
 
 
340
  recursive subroutine mat_invert_recursive(M, zwork, no, ierr)
 
341
    use intrinsic_missing, only : EYE
 
342
    integer, intent(in) :: no ! Size of problem
 
343
    complex(dp), target :: M(no*no), zwork(no*no)
 
344
    integer, intent(out), optional :: ierr
 
345
 
 
346
    ! The maximum dimensionality of the problem before we turn to a direct inversion algorithm
 
347
    integer, parameter :: N_MAX = 64
 
348
    integer, parameter :: sA1 = 1
 
349
 
 
350
    complex(dp), pointer :: A1(:), C2(:), B1(:), A2(:)
 
351
    integer :: sC2, sB1, sA2
 
352
    integer :: n1, n2, i, j
 
353
 
 
354
    ! M is the matrix we wish to invert using the tri-diagonal
 
355
    ! routine.
 
356
    ! This will enable us to half no in order to "faster" 
 
357
    ! invert the matrix
 
358
 
 
359
    if ( no <= N_MAX ) then
 
360
       call mat_invert(M,zwork,no, method = MI_IN_PLACE_LAPACK , ierr=ierr)
 
361
       return
 
362
    end if
 
363
 
 
364
    if ( present(ierr) ) ierr = 0
 
365
 
 
366
    ! Calculate the partition sizes of the matrix problem
 
367
    n1 = no / 2
 
368
    n2 = no - n1
 
369
 
 
370
    if ( Npiv < max(n1,n2) ) then
 
371
       call die('Error in initialization of the pivoting &
 
372
            &array.')
 
373
    end if
 
374
 
 
375
    ! Point to the partitions
 
376
    sB1 = sA1 + n1 ** 2
 
377
    sC2 = sB1 + n1 * n2
 
378
    sA2 = sC2 + n1 * n2
 
379
    A1 => zwork(sA1:)
 
380
    B1 => zwork(sB1:)
 
381
    C2 => zwork(sC2:)
 
382
    A2 => zwork(sA2:)
 
383
 
 
384
    ! Copy over everything to preserve values for inversion
 
385
!$OMP parallel do default(shared), private(i)
 
386
    do i = 1, no ** 2
 
387
       zwork(i) = M(i)
 
388
    end do
 
389
!$OMP end parallel do
 
390
    
 
391
    ! Calculate Y2/B1
 
392
    call zgesv(n1,n2,A1(1),no,ipiv,C2(1),no,i)
 
393
    if ( i /= 0 ) then
 
394
       if ( present(ierr) ) then
 
395
          ierr = i
 
396
          return
 
397
       else
 
398
          call die('Error on inverting Y2/B1')
 
399
       end if
 
400
    end if
 
401
 
 
402
    ! Calculate X1/C2
 
403
    call zgesv(n2,n1,A2(1),no,ipiv,B1(1),no,i)
 
404
    if ( i /= 0 ) then
 
405
       if ( present(ierr) ) then
 
406
          ierr = i
 
407
          return
 
408
       else
 
409
          call die('Error on inverting X1/C2')
 
410
       end if
 
411
    end if
 
412
 
 
413
    ! Calculate the diagonal inverted matrix
 
414
    ! Calculate: A1 - X1
 
415
#ifdef USE_GEMM3M
 
416
    call zgemm3m( &
 
417
#else
 
418
    call zgemm( &
 
419
#endif
 
420
         'N','N',n1,n1,n2, &
 
421
         zm1,M(sC2),no,B1(1),no,z1,M(sA1),no)
 
422
 
 
423
    ! Calculate: A2 - Y2
 
424
#ifdef USE_GEMM3M
 
425
    call zgemm3m( &
 
426
#else
 
427
    call zgemm( &
 
428
#endif
 
429
         'N','N',n2,n2,n1, &
 
430
         zm1,M(sB1),no,C2(1),no,z1,M(sA2),no)
 
431
    
 
432
    call zgetrf(n1, n1, M(sA1), no, ipiv, i )
 
433
    if ( i /= 0 ) then
 
434
       if ( present(ierr) ) then
 
435
          ierr = i
 
436
          return
 
437
       else
 
438
          call die('Error on LU-decomposition A1')
 
439
       end if
 
440
    end if
 
441
 
 
442
    call zgetrf(n2, n2, M(sA2), no, ipiv, i )
 
443
    if ( i /= 0 ) then
 
444
       if ( present(ierr) ) then
 
445
          ierr = i
 
446
          return
 
447
       else
 
448
          call die('Error on LU-decomposition A2')
 
449
       end if
 
450
    end if
 
451
 
 
452
    ! Now before we use A1 as work array we will copy
 
453
    !  B1 and C2
 
454
    ! because they are used to calculate the off-diagonal
 
455
    ! Note it has to be done in this order
 
456
    !   A1, B1, C2, A2 is the order of matrices in the array
 
457
    call copy(n1,n2,zwork(1),zwork(sB1),no)
 
458
    call copy(n1,n2,zwork(n1*n2+1),zwork(sC2),no)
 
459
    j = n1*n2*2 + 1
 
460
 
 
461
    call zgetri(n1, M(1), no, ipiv, zwork(j), no**2-j, i)
 
462
    if ( i /= 0 ) then
 
463
       if ( present(ierr) ) then
 
464
          ierr = i
 
465
          return
 
466
       else
 
467
          call die('Error on inverting A1')
 
468
       end if
 
469
    end if
 
470
 
 
471
    call zgetri(n2, M(sA2), no, ipiv, zwork(j), no**2-j, i)
 
472
    if ( i /= 0 ) then
 
473
       if ( present(ierr) ) then
 
474
          ierr = i
 
475
          return
 
476
       else
 
477
          call die('Error on inverting A2')
 
478
       end if
 
479
    end if
 
480
 
 
481
    ! Calculate the off-diagonal arrays
 
482
    ! Do matrix-multiplication
 
483
    ! Calculate: X1/C2 * M11
 
484
#ifdef USE_GEMM3M
 
485
    call zgemm3m( &
 
486
#else
 
487
    call zgemm( &
 
488
#endif
 
489
         'N','N',n2,n1,n1, &
 
490
         zm1, zwork(1),n2,M(sA1),no,z0, M(sB1),no)
 
491
 
 
492
    ! Calculate the off-diagonal arrays
 
493
    ! Do matrix-multiplication
 
494
    ! Calculate: Y2/B1 * M22
 
495
#ifdef USE_GEMM3M
 
496
    call zgemm3m( &
 
497
#else
 
498
    call zgemm( &
 
499
#endif
 
500
         'N','N',n1,n2,n2, &
 
501
         zm1, zwork(n1*n2+1),n1,M(sA2),no,z0, M(sC2),no)
 
502
 
 
503
  contains
 
504
 
 
505
    subroutine copy(n1,n2,A,B,LDB)
 
506
      integer, intent(in) :: n1, n2, LDB
 
507
      complex(dp), intent(inout) :: A(n1,n2)
 
508
      complex(dp), intent(in) :: B(LDB,n2)
 
509
 
 
510
      integer :: i, j
 
511
 
 
512
      do j = 1 , n2
 
513
         do i = 1 , n1
 
514
            A(i,j) = B(i,j)
 
515
         end do
 
516
      end do
 
517
      
 
518
    end subroutine copy
 
519
 
 
520
  end subroutine mat_invert_recursive
338
521
 
339
522
end module m_mat_invert