~daniele-bigoni/tensortoolbox/tt-docs

« back to all changes in this revision

Viewing changes to TensorToolbox/unittests/TestWTT_7.py

  • Committer: Daniele Bigoni
  • Date: 2015-01-19 11:10:20 UTC
  • Revision ID: dabi@dtu.dk-20150119111020-p0uckg4ab3xqzf47
merged with research

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#
 
2
# This file is part of TensorToolbox.
 
3
#
 
4
# TensorToolbox is free software: you can redistribute it and/or modify
 
5
# it under the terms of the LGNU Lesser General Public License as published by
 
6
# the Free Software Foundation, either version 3 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# TensorToolbox is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# LGNU Lesser General Public License for more details.
 
13
#
 
14
# You should have received a copy of the LGNU Lesser General Public License
 
15
# along with TensorToolbox.  If not, see <http://www.gnu.org/licenses/>.
 
16
#
 
17
# DTU UQ Library
 
18
# Copyright (C) 2014-2015 The Technical University of Denmark
 
19
# Scientific Computing Section
 
20
# Department of Applied Mathematics and Computer Science
 
21
#
 
22
# Author: Daniele Bigoni
 
23
#
 
24
 
 
25
import logging
 
26
import sys
 
27
 
 
28
from aux import bcolors, print_ok, print_fail, print_summary
 
29
 
 
30
def run(maxprocs, PLOTTING=False, loglev=logging.WARNING):
 
31
 
 
32
    logging.basicConfig(level=loglev)
 
33
 
 
34
    import numpy as np
 
35
    import numpy.linalg as npla
 
36
    import itertools
 
37
    import time
 
38
 
 
39
    import TensorToolbox as DT
 
40
    import TensorToolbox.multilinalg as mla
 
41
 
 
42
    if PLOTTING:
 
43
        from matplotlib import pyplot as plt
 
44
 
 
45
    nsucc = 0
 
46
    nfail = 0
 
47
 
 
48
    ####################################################################################
 
49
    # Test GMRES method on simple multidim laplace equation
 
50
    ####################################################################################
 
51
 
 
52
    import scipy.sparse as sp
 
53
    import scipy.sparse.linalg as spla
 
54
 
 
55
    span = np.array([0.,1.])
 
56
    d = 2
 
57
    N = 64
 
58
    h = 1/float(N-1)
 
59
    eps_gmres = 1e-3
 
60
    eps_round = 1e-6
 
61
 
 
62
    # sys.stdout.write("GMRES: Laplace  N=%4d   , d=%3d      [START] \r" % (N,d))
 
63
    # sys.stdout.flush()
 
64
 
 
65
    # Construct d-D Laplace (with 2nd order finite diff)
 
66
    D = -1./h**2. * ( np.diag(np.ones((N-1)),-1) + np.diag(np.ones((N-1)),1) + np.diag(-2.*np.ones((N)),0) )
 
67
    D[0,0:2] = np.array([1.,0.])
 
68
    D[-1,-2:] = np.array([0.,1.])
 
69
    D_sp = sp.coo_matrix(D)
 
70
    I_sp = sp.identity(N)
 
71
    I = np.eye(N)
 
72
    FULL_LAP = sp.coo_matrix((N**d,N**d))
 
73
    for i in range(d):
 
74
        tmp = sp.identity((1))
 
75
        for j in range(d):
 
76
            if i != j: tmp = sp.kron(tmp,I_sp)
 
77
            else: tmp = sp.kron(tmp,D_sp)
 
78
        FULL_LAP = FULL_LAP + tmp
 
79
 
 
80
    # Construction of TT Laplace operator
 
81
    CPtmp = []
 
82
    D_flat = D.flatten()
 
83
    I_flat = I.flatten()
 
84
    for i in range(d):
 
85
        CPi = np.empty((d,N**2))
 
86
        for alpha in range(d):
 
87
            if i != alpha:
 
88
                CPi[alpha,:] = I_flat
 
89
            else:
 
90
                CPi[alpha,:] = D_flat
 
91
        CPtmp.append(CPi)
 
92
 
 
93
    CP_lap = DT.Candecomp(CPtmp)
 
94
    TT_LAP = DT.TTmat(CP_lap,nrows=N,ncols=N,is_sparse=[True]*d)
 
95
    TT_LAP.build()
 
96
    TT_LAP.rounding(eps_round)
 
97
    CPtmp = None
 
98
    CP_lap = None
 
99
 
 
100
    # Construct Right hand-side (b=1, Dirichlet BC = 0)
 
101
    X = np.linspace(span[0],span[1],N)
 
102
    b1D = np.ones(N)
 
103
    b1D[0] = 0.
 
104
    b1D[-1] = 0.
 
105
    # Construct the d-D right handside
 
106
    tmp = np.array([1.])
 
107
    for j in range(d):
 
108
        tmp = np.kron(tmp,b1D)
 
109
    FULL_b = tmp
 
110
    # Construct the TT right handside
 
111
    CPtmp = []
 
112
    for i in range(d):
 
113
        CPi = np.empty((1,N))
 
114
        CPi[0,:] = b1D
 
115
        CPtmp.append(CPi)
 
116
    CP_b = DT.Candecomp(CPtmp)
 
117
    W = [np.ones(N,dtype=float)/float(N) for i in range(d)]
 
118
    TT_b = DT.WTTvec(CP_b,W)
 
119
    TT_b.build()
 
120
    TT_b.rounding(eps_round)
 
121
 
 
122
 
 
123
    # Solve full system using npla.solve
 
124
    (FULL_RES,FULL_info) = spla.gmres(FULL_LAP,FULL_b,tol=eps_gmres)
 
125
 
 
126
    if PLOTTING:
 
127
        from mpl_toolkits.mplot3d import Axes3D
 
128
        from matplotlib import cm
 
129
        (XX,YY) = np.meshgrid(X,X)
 
130
        fig = plt.figure(figsize=(18,7))
 
131
        plt.suptitle("GMRES")
 
132
        if d == 2:
 
133
            # Plot function
 
134
            ax = fig.add_subplot(131,projection='3d')
 
135
            ax.plot_surface(XX,YY,FULL_RES.reshape((N,N)),rstride=1, cstride=1, cmap=cm.coolwarm,
 
136
                            linewidth=0, antialiased=False)
 
137
            plt.show(block=False)
 
138
 
 
139
    # Solve TT cg
 
140
    x0 = DT.zerosvec(d,N)
 
141
    (TT_RES,conv,TT_info) = mla.gmres(TT_LAP,TT_b,x0=x0,restart=10,eps=eps_gmres,ext_info=True)
 
142
    if PLOTTING and d == 2:
 
143
        # Plot function
 
144
        ax = fig.add_subplot(132,projection='3d')
 
145
        ax.plot_surface(XX,YY,TT_RES.to_tensor(),rstride=1, cstride=1, cmap=cm.coolwarm,
 
146
                        linewidth=0, antialiased=False)
 
147
        plt.show(block=False)
 
148
 
 
149
    # Error
 
150
    if PLOTTING and d == 2:
 
151
        # Plot function
 
152
        ax = fig.add_subplot(133,projection='3d')
 
153
        ax.plot_surface(XX,YY,np.abs(TT_RES.to_tensor()-FULL_RES.reshape((N,N))),rstride=1, cstride=1, cmap=cm.coolwarm,
 
154
                        linewidth=0, antialiased=False)
 
155
        plt.show(block=False)
 
156
 
 
157
    err2 = npla.norm(TT_RES.to_tensor().flatten()-FULL_RES,2)
 
158
    if err2 < 1e-2:
 
159
        print_ok("7.1 Weighted GMRES: Laplace  N=%4d   , d=%3d  , 2-err=%f" % (N,d,err2))
 
160
        nsucc += 1
 
161
    else:
 
162
        print_fail("7.1 Weighted GMRES: Laplace  N=%4d   , d=%3d  , 2-err=%f" % (N,d,err2))
 
163
        nfail += 1
 
164
 
 
165
    print_summary("WTT GMRES", nsucc, nfail)
 
166
    
 
167
    return (nsucc,nfail)
 
168
 
 
169
 
 
170
if __name__ == "__main__":
 
171
    # Number of processors to be used, defined as an additional arguement 
 
172
    # $ python TestTT.py N
 
173
    # Mind that the program in this case will run slower than the non-parallel case
 
174
    # due to the overhead for the creation and deletion of threads.
 
175
    if len(sys.argv) == 2:
 
176
        maxprocs = int(sys.argv[1])
 
177
    else:
 
178
        maxprocs = None
 
179
 
 
180
    run(maxprocs,PLOTTING=True, loglev=logging.INFO)