2
# This file is part of TensorToolbox.
4
# TensorToolbox is free software: you can redistribute it and/or modify
5
# it under the terms of the GNU Lesser General Public License as published by
6
# the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
9
# TensorToolbox is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
# GNU Lesser General Public License for more details.
14
# You should have received a copy of the GNU Lesser General Public License
15
# along with TensorToolbox. If not, see <http://www.gnu.org/licenses/>.
18
# Copyright (C) 2014-2015 The Technical University of Denmark
19
# Scientific Computing Section
20
# Department of Applied Mathematics and Computer Science
22
# Author: Daniele Bigoni
28
from aux import bcolors, print_ok, print_fail, print_summary
30
def run(maxprocs, PLOTTING=False, loglev=logging.WARNING):
    """Test the weighted-TT steepest-descent (SD) solver on a d-dimensional
    Laplace problem and compare it against a full sparse CG solve.

    NOTE(review): this chunk is an extraction-mangled fragment.  The bare
    integer statements interleaved below appear to be line numbers from the
    original file, all indentation has been stripped, and many intervening
    lines are missing (loops over N/d, definitions of h, eps_round, eps_cg,
    CPtmp, FULL_b, I_sp, b1D, nsucc, nfail, ...).  The code is kept
    byte-identical here; only comments are added.  Do not expect this chunk
    to parse or run as-is.

    Visible parameters:
    maxprocs -- number of worker processes (its use is not visible in this chunk)
    PLOTTING -- when True and d == 2, surface-plot the full and TT solutions
    loglev   -- logging level handed to logging.basicConfig
    """
32
logging.basicConfig(level=loglev)
35
import numpy.linalg as npla
39
import TensorToolbox as DT
40
import TensorToolbox.multilinalg as mla
43
from matplotlib import pyplot as plt
48
####################################################################################
49
# Test Steepest Descent method on simple multidim laplace equation
50
####################################################################################
52
import scipy.sparse as sp
53
import scipy.sparse.linalg as spla
55
# Problem domain [0, 1]; grid size N, dimension d, spacing h are defined in
# missing lines.
span = np.array([0.,1.])
62
# sys.stdout.write("Steepest Descent: Laplace N=%4d , d=%3d [START] \n" % (N,d))
67
# Construct d-D Laplace (with 2nd order finite diff)
68
# 1-D second-order finite-difference Laplacian, scaled by -1/h^2; the first
# and last rows are overwritten below with identity rows to impose Dirichlet
# boundary conditions.
D = -1./h**2. * ( np.diag(np.ones((N-1)),-1) + np.diag(np.ones((N-1)),1) + np.diag(-2.*np.ones((N)),0) )
69
D[0,0:2] = np.array([1.,0.])
70
D[-1,-2:] = np.array([0.,1.])
71
D_sp = sp.coo_matrix(D)
74
# Full (N^d x N^d) operator, assembled as a Kronecker sum: for each j the
# term I x ... x D x ... x I is accumulated into FULL_LAP.  The enclosing
# loops over i and j are among the missing lines (I_sp presumably the NxN
# sparse identity -- confirm against the original file).
FULL_LAP = sp.coo_matrix((N**d,N**d))
76
tmp = sp.identity((1))
78
if i != j: tmp = sp.kron(tmp,I_sp)
79
else: tmp = sp.kron(tmp,D_sp)
80
FULL_LAP = FULL_LAP + tmp
82
# NOTE(review): Python 2 print statement -- this file predates Python 3.
# Presumably inside an `except MemoryError` handler (the try/except lines
# are missing).
print "FULL CG: Memory Error"
85
# Construction of TT Laplace operator
90
CPi = np.empty((d,N**2))
91
for alpha in range(d):
98
# Canonical (CANDECOMP) representation of the Laplace operator, converted to
# a TT matrix and rounded to tolerance eps_round (defined in missing lines).
CP_lap = DT.Candecomp(CPtmp)
99
TT_LAP = DT.TTmat(CP_lap,nrows=N,ncols=N,is_sparse=[True]*d)
100
TT_LAP.build(eps_round)
101
TT_LAP.rounding(eps_round)
105
# Construct Right hand-side (b=1, Dirichlet BC = 0)
106
X = np.linspace(span[0],span[1],N)
112
# Construct the d-D right hand-side
115
tmp = np.kron(tmp,b1D)
118
# Construct the TT right hand-side
121
CPi = np.empty((1,N))
124
CP_b = DT.Candecomp(CPtmp)
125
# Uniform quadrature-style weights, one vector of length N per dimension,
# used to build a *weighted* TT vector.
W = [np.ones(N,dtype=float)/float(N) for i in range(d)]
126
TT_b = DT.WTTvec(CP_b,W)
128
TT_b.rounding(eps_round)
131
# Solve full system using npla.solve
132
# Reference solution via sparse CG (FULL_b and eps_cg defined in missing lines).
(FULL_RES,FULL_CONV) = spla.cg(FULL_LAP,FULL_b,tol=eps_cg)
135
from mpl_toolkits.mplot3d import Axes3D
136
from matplotlib import cm
137
(XX,YY) = np.meshgrid(X,X)
138
fig = plt.figure(figsize=(18,7))
142
# Left panel: full (reference) solution as a surface (only meaningful for d == 2).
ax = fig.add_subplot(131,projection='3d')
143
ax.plot_surface(XX,YY,FULL_RES.reshape((N,N)),rstride=1, cstride=1, cmap=cm.coolwarm,
144
linewidth=0, antialiased=False)
145
plt.show(block=False)
148
# TT steepest-descent solve, starting from the zero TT vector.
x0 = DT.zerosvec(d,N)
149
(TT_RES,TT_conv,TT_info) = mla.sd(TT_LAP,TT_b,x0=x0,maxit=10000,eps=eps_cg,ext_info=True,eps_round=eps_round)
150
if PLOTTING and d == 2:
152
# Middle panel: TT solution.
ax = fig.add_subplot(132,projection='3d')
153
ax.plot_surface(XX,YY,TT_RES.to_tensor(),rstride=1, cstride=1, cmap=cm.coolwarm,
154
linewidth=0, antialiased=False)
155
plt.show(block=False)
158
if PLOTTING and d == 2:
160
# Right panel: pointwise absolute error between TT and full solutions.
ax = fig.add_subplot(133,projection='3d')
161
ax.plot_surface(XX,YY,np.abs(TT_RES.to_tensor()-FULL_RES.reshape((N,N))),rstride=1, cstride=1, cmap=cm.coolwarm,
162
linewidth=0, antialiased=False)
163
plt.show(block=False)
165
# 2-norm discrepancy between the flattened TT solution and the reference.
err2 = npla.norm(TT_RES.to_tensor().flatten()-FULL_RES,2)
167
# Success/failure reporting; the if/else on err2 (and the nsucc/nfail
# counters it updates) is among the missing lines.
print_ok("4.1 WSD: Laplace N=%4d , d=%3d , 2-err=%f" % (N,d,err2))
170
print_fail("4.1 WSD: Laplace N=%4d , d=%3d , 2-err=%f" % (N,d,err2))
173
print_summary("WTT SD", nsucc, nfail)
177
if __name__ == "__main__":
# NOTE(review): mangled fragment -- indentation has been stripped, and the
# branch that sets a default `maxprocs` when no CLI argument is given
# appears to be missing (original line numbers jump from 183 to 187).
178
# Number of processors to be used, defined as an additional argument
179
# $ python TestTT.py N
180
# Mind that the program in this case will run slower than the non-parallel case
181
# due to the overhead for the creation and deletion of threads.
182
if len(sys.argv) == 2:
183
maxprocs = int(sys.argv[1])
187
run(maxprocs,PLOTTING=True, loglev=logging.INFO)