~daniele-bigoni/tensortoolbox/tt-docs

« back to all changes in this revision

Viewing changes to TensorToolbox/unittests/TestWTT_4.py

  • Committer: Daniele Bigoni
  • Date: 2015-01-19 11:10:20 UTC
  • Revision ID: dabi@dtu.dk-20150119111020-p0uckg4ab3xqzf47
merged with research

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#
 
2
# This file is part of TensorToolbox.
 
3
#
 
4
# TensorToolbox is free software: you can redistribute it and/or modify
 
5
# it under the terms of the LGNU Lesser General Public License as published by
 
6
# the Free Software Foundation, either version 3 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# TensorToolbox is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# LGNU Lesser General Public License for more details.
 
13
#
 
14
# You should have received a copy of the LGNU Lesser General Public License
 
15
# along with TensorToolbox.  If not, see <http://www.gnu.org/licenses/>.
 
16
#
 
17
# DTU UQ Library
 
18
# Copyright (C) 2014-2015 The Technical University of Denmark
 
19
# Scientific Computing Section
 
20
# Department of Applied Mathematics and Computer Science
 
21
#
 
22
# Author: Daniele Bigoni
 
23
#
 
24
 
 
25
import logging
 
26
import sys
 
27
 
 
28
from aux import bcolors, print_ok, print_fail, print_summary
 
29
 
 
30
def run(maxprocs, PLOTTING=False, loglev=logging.WARNING):
 
31
 
 
32
    logging.basicConfig(level=loglev)
 
33
 
 
34
    import numpy as np
 
35
    import numpy.linalg as npla
 
36
    import itertools
 
37
    import time
 
38
 
 
39
    import TensorToolbox as DT
 
40
    import TensorToolbox.multilinalg as mla
 
41
 
 
42
    if PLOTTING:
 
43
        from matplotlib import pyplot as plt
 
44
 
 
45
    nsucc = 0
 
46
    nfail = 0
 
47
 
 
48
    ####################################################################################
 
49
    # Test Steepest Descent method on simple multidim laplace equation
 
50
    ####################################################################################
 
51
 
 
52
    import scipy.sparse as sp
 
53
    import scipy.sparse.linalg as spla
 
54
 
 
55
    span = np.array([0.,1.])
 
56
    d = 2
 
57
    N = 64
 
58
    h = 1/float(N-1)
 
59
    eps_cg = 1e-3
 
60
    eps_round = 1e-6
 
61
 
 
62
    # sys.stdout.write("Steepest Descent: Laplace  N=%4d   , d=%3d      [START] \n" % (N,d))
 
63
    # sys.stdout.flush()
 
64
 
 
65
    dofull = True
 
66
    try:
 
67
        # Construct d-D Laplace (with 2nd order finite diff)
 
68
        D = -1./h**2. * ( np.diag(np.ones((N-1)),-1) + np.diag(np.ones((N-1)),1) + np.diag(-2.*np.ones((N)),0) )
 
69
        D[0,0:2] = np.array([1.,0.])
 
70
        D[-1,-2:] = np.array([0.,1.])
 
71
        D_sp = sp.coo_matrix(D)
 
72
        I_sp = sp.identity(N)
 
73
        I = np.eye(N)
 
74
        FULL_LAP = sp.coo_matrix((N**d,N**d))
 
75
        for i in range(d):
 
76
            tmp = sp.identity((1))
 
77
            for j in range(d):
 
78
                if i != j: tmp = sp.kron(tmp,I_sp)
 
79
                else: tmp = sp.kron(tmp,D_sp)
 
80
            FULL_LAP = FULL_LAP + tmp
 
81
    except MemoryError:
 
82
        print "FULL CG: Memory Error"
 
83
        dofull = False
 
84
 
 
85
    # Construction of TT Laplace operator
 
86
    CPtmp = []
 
87
    D_flat = D.flatten()
 
88
    I_flat = I.flatten()
 
89
    for i in range(d):
 
90
        CPi = np.empty((d,N**2))
 
91
        for alpha in range(d):
 
92
            if i != alpha:
 
93
                CPi[alpha,:] = I_flat
 
94
            else:
 
95
                CPi[alpha,:] = D_flat
 
96
        CPtmp.append(CPi)
 
97
 
 
98
    CP_lap = DT.Candecomp(CPtmp)
 
99
    TT_LAP = DT.TTmat(CP_lap,nrows=N,ncols=N,is_sparse=[True]*d)
 
100
    TT_LAP.build(eps_round)
 
101
    TT_LAP.rounding(eps_round)
 
102
    CPtmp = None
 
103
    CP_lap = None
 
104
 
 
105
    # Construct Right hand-side (b=1, Dirichlet BC = 0)
 
106
    X = np.linspace(span[0],span[1],N)
 
107
    b1D = np.ones(N)
 
108
    b1D[0] = 0.
 
109
    b1D[-1] = 0.
 
110
 
 
111
    if dofull:
 
112
        # Construct the d-D right handside
 
113
        tmp = np.array([1.])
 
114
        for j in range(d):
 
115
            tmp = np.kron(tmp,b1D)
 
116
        FULL_b = tmp
 
117
 
 
118
    # Construct the TT right handside
 
119
    CPtmp = []
 
120
    for i in range(d):
 
121
        CPi = np.empty((1,N))
 
122
        CPi[0,:] = b1D
 
123
        CPtmp.append(CPi)
 
124
    CP_b = DT.Candecomp(CPtmp)
 
125
    W = [np.ones(N,dtype=float)/float(N) for i in range(d)]
 
126
    TT_b = DT.WTTvec(CP_b,W)
 
127
    TT_b.build()
 
128
    TT_b.rounding(eps_round)
 
129
 
 
130
    if dofull:
 
131
        # Solve full system using npla.solve
 
132
        (FULL_RES,FULL_CONV) = spla.cg(FULL_LAP,FULL_b,tol=eps_cg)
 
133
 
 
134
    if PLOTTING:
 
135
        from mpl_toolkits.mplot3d import Axes3D
 
136
        from matplotlib import cm
 
137
        (XX,YY) = np.meshgrid(X,X)
 
138
        fig = plt.figure(figsize=(18,7))
 
139
        plt.suptitle("SD")
 
140
        if d == 2:
 
141
            # Plot function
 
142
            ax = fig.add_subplot(131,projection='3d')
 
143
            ax.plot_surface(XX,YY,FULL_RES.reshape((N,N)),rstride=1, cstride=1, cmap=cm.coolwarm,
 
144
                            linewidth=0, antialiased=False)
 
145
            plt.show(block=False)
 
146
 
 
147
    # Solve TT cg
 
148
    x0 = DT.zerosvec(d,N)
 
149
    (TT_RES,TT_conv,TT_info) = mla.sd(TT_LAP,TT_b,x0=x0,maxit=10000,eps=eps_cg,ext_info=True,eps_round=eps_round)
 
150
    if PLOTTING and d == 2:
 
151
        # Plot function
 
152
        ax = fig.add_subplot(132,projection='3d')
 
153
        ax.plot_surface(XX,YY,TT_RES.to_tensor(),rstride=1, cstride=1, cmap=cm.coolwarm,
 
154
                        linewidth=0, antialiased=False)
 
155
        plt.show(block=False)
 
156
 
 
157
    # Error
 
158
    if PLOTTING and d == 2:
 
159
        # Plot function
 
160
        ax = fig.add_subplot(133,projection='3d')
 
161
        ax.plot_surface(XX,YY,np.abs(TT_RES.to_tensor()-FULL_RES.reshape((N,N))),rstride=1, cstride=1, cmap=cm.coolwarm,
 
162
                        linewidth=0, antialiased=False)
 
163
        plt.show(block=False)
 
164
 
 
165
    err2 = npla.norm(TT_RES.to_tensor().flatten()-FULL_RES,2)
 
166
    if err2 < 1e-2:
 
167
        print_ok("4.1 WSD: Laplace  N=%4d   , d=%3d  , 2-err=%f" % (N,d,err2))
 
168
        nsucc += 1
 
169
    else:
 
170
        print_fail("4.1 WSD: Laplace  N=%4d   , d=%3d  , 2-err=%f" % (N,d,err2))
 
171
        nfail += 1
 
172
 
 
173
    print_summary("WTT SD", nsucc, nfail)
 
174
    
 
175
    return (nsucc,nfail)
 
176
 
 
177
if __name__ == "__main__":
 
178
    # Number of processors to be used, defined as an additional arguement 
 
179
    # $ python TestTT.py N
 
180
    # Mind that the program in this case will run slower than the non-parallel case
 
181
    # due to the overhead for the creation and deletion of threads.
 
182
    if len(sys.argv) == 2:
 
183
        maxprocs = int(sys.argv[1])
 
184
    else:
 
185
        maxprocs = None
 
186
 
 
187
    run(maxprocs,PLOTTING=True, loglev=logging.INFO)