2
# This file is part of TensorToolbox.
4
# TensorToolbox is free software: you can redistribute it and/or modify
5
# it under the terms of the GNU Lesser General Public License as published by
6
# the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
9
# TensorToolbox is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
# GNU Lesser General Public License for more details.
14
# You should have received a copy of the GNU Lesser General Public License
15
# along with TensorToolbox. If not, see <http://www.gnu.org/licenses/>.
18
# Copyright (C) 2014-2015 The Technical University of Denmark
19
# Scientific Computing Section
20
# Department of Applied Mathematics and Computer Science
22
# Author: Daniele Bigoni
28
from aux import bcolors, print_ok, print_fail, print_summary
30
# Run the weighted TT GMRES test on a multi-dimensional Laplace problem.
#
# The visible code assembles the Laplace operator and a right-hand side both
# as full (sparse/dense) objects and in TT format, solves the full system with
# spla.gmres and the TT system with mla.gmres, compares the two solutions in
# the 2-norm and reports the outcome through print_ok / print_fail and
# print_summary.
#
# Parameters:
#   maxprocs -- not referenced in the lines visible here; the comment in the
#               __main__ section describes it as the number of processors.
#   PLOTTING -- when true and d == 2, the solutions and their absolute
#               difference are drawn as 3-D surfaces.
#   loglev   -- logging level handed to logging.basicConfig.
#
# NOTE(review): several names used below (np, N, d, h, i, j, I_sp, CPtmp,
# b1D, FULL_b, eps_round, eps_gmres, nsucc, nfail) are not assigned in the
# lines visible here, and some loop/branch headers are absent as well --
# confirm against the complete file before changing any logic.
def run(maxprocs, PLOTTING=False, loglev=logging.WARNING):
32
logging.basicConfig(level=loglev)
35
import numpy.linalg as npla
39
import TensorToolbox as DT
40
import TensorToolbox.multilinalg as mla
43
from matplotlib import pyplot as plt
48
####################################################################################
49
# Test GMRES method on simple multidim laplace equation
50
####################################################################################
52
import scipy.sparse as sp
53
import scipy.sparse.linalg as spla
55
# End points of the 1-D grid; passed to np.linspace further down.
span = np.array([0.,1.])
62
# sys.stdout.write("GMRES: Laplace N=%4d , d=%3d [START] \r" % (N,d))
65
# Construct d-D Laplace (with 2nd order finite diff)
66
# 1-D operator: tridiagonal (1, -2, 1) stencil scaled by -1/h**2.
D = -1./h**2. * ( np.diag(np.ones((N-1)),-1) + np.diag(np.ones((N-1)),1) + np.diag(-2.*np.ones((N)),0) )
67
# Overwrite the first two entries of the first row with (1, 0) and the last
# two entries of the last row with (0, 1) -- presumably to impose the
# Dirichlet boundary conditions mentioned below.
D[0,0:2] = np.array([1.,0.])
68
D[-1,-2:] = np.array([0.,1.])
69
D_sp = sp.coo_matrix(D)
72
# Accumulator for the full operator: an empty sparse N**d x N**d matrix.
FULL_LAP = sp.coo_matrix((N**d,N**d))
74
# One Kronecker-product term: start from a 1x1 identity and multiply in D_sp
# when i == j and I_sp otherwise, then add the term to FULL_LAP.
tmp = sp.identity((1))
76
if i != j: tmp = sp.kron(tmp,I_sp)
77
else: tmp = sp.kron(tmp,D_sp)
78
FULL_LAP = FULL_LAP + tmp
80
# Construction of TT Laplace operator
85
# Uninitialised buffer with d rows of length N**2.
CPi = np.empty((d,N**2))
86
for alpha in range(d):
93
CP_lap = DT.Candecomp(CPtmp)
94
# Build the TT matrix from the CP operator with nrows=N and ncols=N, flagged
# sparse in every one of the d dimensions, then call rounding with eps_round.
TT_LAP = DT.TTmat(CP_lap,nrows=N,ncols=N,is_sparse=[True]*d)
96
TT_LAP.rounding(eps_round)
100
# Construct Right hand-side (b=1, Dirichlet BC = 0)
101
# N equally spaced grid points between span[0] and span[1].
X = np.linspace(span[0],span[1],N)
105
# Construct the d-D right handside
108
tmp = np.kron(tmp,b1D)
110
# Construct the TT right handside
113
CPi = np.empty((1,N))
116
CP_b = DT.Candecomp(CPtmp)
117
# Uniform weights 1/N: one length-N vector per dimension.
W = [np.ones(N,dtype=float)/float(N) for i in range(d)]
118
TT_b = DT.WTTvec(CP_b,W)
120
TT_b.rounding(eps_round)
123
# Solve full system using spla.gmres (SciPy sparse GMRES)
124
# NOTE(review): recent SciPy releases deprecate the `tol` keyword of gmres in
# favour of `rtol` -- confirm against the SciPy version in use.
(FULL_RES,FULL_info) = spla.gmres(FULL_LAP,FULL_b,tol=eps_gmres)
127
# First of three panels (subplot 131): surface of the full solution reshaped
# to (N, N). Presumably guarded by the same PLOTTING and d == 2 check that is
# used for the other panels -- TODO confirm.
from mpl_toolkits.mplot3d import Axes3D
128
from matplotlib import cm
129
(XX,YY) = np.meshgrid(X,X)
130
fig = plt.figure(figsize=(18,7))
131
plt.suptitle("GMRES")
134
ax = fig.add_subplot(131,projection='3d')
135
ax.plot_surface(XX,YY,FULL_RES.reshape((N,N)),rstride=1, cstride=1, cmap=cm.coolwarm,
136
linewidth=0, antialiased=False)
137
plt.show(block=False)
140
# Solve the same system in TT format with mla.gmres, starting from
# x0 = DT.zerosvec(d,N), with restart=10 and eps=eps_gmres.
x0 = DT.zerosvec(d,N)
141
(TT_RES,conv,TT_info) = mla.gmres(TT_LAP,TT_b,x0=x0,restart=10,eps=eps_gmres,ext_info=True)
142
# Second panel (subplot 132): surface of the TT solution.
if PLOTTING and d == 2:
144
ax = fig.add_subplot(132,projection='3d')
145
ax.plot_surface(XX,YY,TT_RES.to_tensor(),rstride=1, cstride=1, cmap=cm.coolwarm,
146
linewidth=0, antialiased=False)
147
plt.show(block=False)
150
# Third panel (subplot 133): absolute difference between the TT solution and
# the full solution.
if PLOTTING and d == 2:
152
ax = fig.add_subplot(133,projection='3d')
153
ax.plot_surface(XX,YY,np.abs(TT_RES.to_tensor()-FULL_RES.reshape((N,N))),rstride=1, cstride=1, cmap=cm.coolwarm,
154
linewidth=0, antialiased=False)
155
plt.show(block=False)
157
# 2-norm of the difference between the flattened TT solution and the full
# GMRES solution.
err2 = npla.norm(TT_RES.to_tensor().flatten()-FULL_RES,2)
159
# Report the outcome; the condition choosing between print_ok and print_fail
# is not among the lines visible here.
print_ok("7.1 Weighted GMRES: Laplace N=%4d , d=%3d , 2-err=%f" % (N,d,err2))
162
print_fail("7.1 Weighted GMRES: Laplace N=%4d , d=%3d , 2-err=%f" % (N,d,err2))
165
# Final tally of successes and failures for this test group.
print_summary("WTT GMRES", nsucc, nfail)
170
# Script entry point: read the optional processor count from the command line
# and run the test with plotting enabled and INFO-level logging.
if __name__ == "__main__":
171
# Number of processors to be used, defined as an additional argument
172
# $ python TestTT.py N
173
# Mind that the program in this case will run slower than the non-parallel case
174
# due to the overhead for the creation and deletion of threads.
175
# Exactly one command-line argument: parse it as the integer processor count.
# NOTE(review): the value used when no argument is given is not visible in
# these lines -- confirm against the complete file.
if len(sys.argv) == 2:
176
maxprocs = int(sys.argv[1])
180
run(maxprocs,PLOTTING=True, loglev=logging.INFO)