! Copyright (C) 2006 Imperial College London and others.
!
! Please see the AUTHORS file in the main source directory for a full list
! of copyright holders.
!
! Applied Modelling and Computation Group
! Department of Earth Science and Engineering
! Imperial College London
!
! amcgsoftware@imperial.ac.uk
!
! This library is free software; you can redistribute it and/or
! modify it under the terms of the GNU Lesser General Public
! License as published by the Free Software Foundation; either
! version 2.1 of the License, or (at your option) any later version.
!
! This library is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
! Lesser General Public License for more details.
!
! You should have received a copy of the GNU Lesser General Public
! License along with this library; if not, write to the Free Software
! Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307

module forward_main_loop
31
#include "libadjoint/adj_fortran.h"
33
use libadjoint_data_callbacks
36
use diagnostic_variables
39
use global_parameters, only: OPTION_PATH_LEN, running_adjoint
40
use adjoint_global_variables
41
use adjoint_functional_evaluation
42
use populate_state_module
43
use boundary_conditions_from_options
45
use mangle_options_tree
46
use mangle_dirichlet_rows_module
48
use sparse_matrices_fields
49
use write_state_module
50
use diagnostic_fields_wrapper
52
use diagnostic_fields_new, only: calculate_diagnostic_variables_new => calculate_diagnostic_variables, &
53
& check_diagnostic_dependencies
57
public :: forward_main_loop_register_diagnostic
59
public :: compute_forward, calculate_functional_values, register_functional_callbacks
65
! Replay the forward model that has been registered with the libadjoint
! adjointer. For every timestep (count from adj_timestep_count) and every
! equation in it, the forward equation lhs . soln = rhs is fetched from the
! adjointer and solved. The solve is an identity copy, a matrix-vector
! product when lhs is flagged as already inverted, or a PETSc linear solve.
! The solution is inserted into state(1) and recorded back in the adjointer.
! After each timestep the routine updates prescribed and diagnostic fields,
! evaluates the functionals (from timestep 1 on), writes diagnostics and,
! when due, writes a state dump.
!
! NOTE(review): as shown here this routine contains bare numeric lines. It
! also appears to lack several statements its constructs require, e.g. the
! `end if` / `else` / `case` / `end select` lines closing the blocks opened
! below. It uses ierr, equation, s_idx, has_path and j without declaring
! them locally. Confirm against the canonical source before changing logic.
subroutine compute_forward(state)
66
! Model states; solved forward fields are inserted into state(1).
type(state_type), dimension(:), intent(inout) :: state
68
! Forward equation fetched from the adjointer: lhs . soln = rhs for fwd_var.
type(adj_vector) :: rhs
69
type(adj_vector) :: soln
70
type(adj_matrix) :: lhs
71
type(adj_variable) :: fwd_var
78
real :: finish_time, dt
79
integer :: end_timestep, start_timestep, no_timesteps, timestep
80
real :: start_time, end_time
82
character(len=OPTION_PATH_LEN) :: simulation_base_name, functional_name
83
type(stat_type) :: new_stat
85
character(len=ADJ_NAME_LEN) :: variable_name, field_name, material_phase_name
86
type(scalar_field) :: sfield_soln, sfield_rhs
87
type(vector_field) :: vfield_soln, vfield_rhs
88
type(csr_matrix) :: csr_mat
89
type(block_csr_matrix) :: block_csr_mat
90
! Option path for the solved field, or the solver options used by petsc_solve
! (looked up in adj_path_lookup / adj_solver_path_lookup below).
character(len=ADJ_DICT_LEN) :: path
91
type(adj_storage_data) :: storage
93
integer :: nfunctionals
96
! Check the equations registered with the adjointer are consistent;
! adj_chkierr handles any error status.
ierr = adj_adjointer_check_consistency(adjointer)
97
call adj_chkierr(ierr)
99
call get_option("/timestepping/timestep", dt)
100
call get_option("/simulation_name", simulation_base_name)
101
running_adjoint = .false.
103
! Switch the html output on if you are interested in what the adjointer has registered
104
if (have_option("/adjoint/debug/html_output")) then
105
ierr = adj_adjointer_to_html(adjointer, "adjointer_forward.html", ADJ_FORWARD)
106
call adj_chkierr(ierr)
107
ierr = adj_adjointer_to_html(adjointer, "adjointer_adjoint.html", ADJ_ADJOINT)
108
call adj_chkierr(ierr)
111
! NOTE(review): new_stat is never assigned; this presumably resets
! default_stat to a default-initialised stat_type — confirm.
default_stat = new_stat
112
call initialise_walltime
113
! Diagnostics for this replay use the simulation name with a '_forward' suffix.
call initialise_diagnostics(trim(simulation_base_name) // '_forward', state)
115
ierr = adj_timestep_count(adjointer, no_timesteps)
116
call adj_chkierr(ierr)
119
nfunctionals = option_count("/adjoint/functional")
121
do timestep=0,no_timesteps-1
122
! Time window of this timestep; keep the options tree's current_time in sync.
ierr = adj_timestep_get_times(adjointer, timestep, start_time, end_time)
123
call adj_chkierr(ierr)
124
current_time = start_time
125
call set_option("/timestepping/current_time", current_time)
127
! Range of equation indices belonging to this timestep.
ierr = adj_timestep_start_equation(adjointer, timestep, start_timestep)
128
call adj_chkierr(ierr)
130
ierr = adj_timestep_end_equation(adjointer, timestep, end_timestep)
131
call adj_chkierr(ierr)
133
do equation=start_timestep,end_timestep
134
ierr = adj_get_forward_equation(adjointer, equation, lhs, rhs, fwd_var)
135
call adj_chkierr(ierr)
137
! Now solve lhs . soln = rhs for the forward variable fwd_var
138
! NOTE(review): the ierr returned by adj_variable_get_name is not checked.
ierr = adj_variable_get_name(fwd_var, variable_name)
139
! Split variable_name at its first ':' into material phase and field name.
! NOTE(review): field_name skips two characters after the first ':', so this
! assumes a '::' separator. A name without ':' gives s_idx = 0: an empty
! material_phase_name and a field_name missing its first character.
s_idx = scan(trim(variable_name), ":")
140
material_phase_name = variable_name(1:s_idx - 1)
141
field_name = variable_name(s_idx + 2:len_trim(variable_name))
142
! Look up an option path registered for this variable (ierr == ADJ_OK if found).
ierr = adj_dict_find(adj_path_lookup, trim(variable_name), path)
143
! NOTE(review): has_path is read below but never assigned in the visible
! code; presumably it records whether this lookup succeeded.
if (ierr == ADJ_OK) then
149
! variable_name should be something like Fluid::Velocity
150
select case(rhs%klass)
151
case(ADJ_SCALAR_FIELD)
152
call field_from_adj_vector(rhs, sfield_rhs)
153
! Fresh zeroed solution field on the rhs mesh.
call allocate(sfield_soln, sfield_rhs%mesh, trim(field_name))
154
call zero(sfield_soln)
157
! Presumably lets populate_boundary_conditions find this field's BC options.
sfield_soln%option_path = trim(path)
158
! We need to populate the BC values:
159
! insert the field into state(1), populate its BCs, then re-extract it so
! sfield_soln is the copy held in state.
call insert(state(1), sfield_soln, trim(sfield_soln%name))
160
call populate_boundary_conditions(state)
161
call set_boundary_conditions_values(state, shift_time=.false.)
162
sfield_soln = extract_scalar_field(state(1), trim(sfield_soln%name))
165
select case(lhs%klass)
166
case(ADJ_IDENTITY_MATRIX)
167
! Identity lhs: the solution is the rhs itself.
call set(sfield_soln, sfield_rhs)
169
call matrix_from_adj_matrix(lhs, csr_mat)
170
! lhs flagged as already inverted: the solution is a matrix-vector product.
if (iand(lhs%flags, ADJ_MATRIX_INVERTED) == ADJ_MATRIX_INVERTED) then
171
call mult(sfield_soln, csr_mat, sfield_rhs)
173
! Otherwise solve with PETSc, taking solver options from adj_solver_path_lookup
! when no option path was found for the variable.
if (.not. has_path) then
174
ierr = adj_dict_find(adj_solver_path_lookup, trim(variable_name), path)
175
call adj_chkierr(ierr)
178
! Temporarily point the rhs at the solution's BCs so set_dirichlet_consistent
! can apply them, then disassociate so the rhs no longer refers to them.
sfield_rhs%bc => sfield_soln%bc
179
call set_dirichlet_consistent(sfield_rhs)
180
sfield_rhs%bc => null()
182
call petsc_solve(sfield_soln, csr_mat, sfield_rhs, option_path=path)
183
call compute_inactive_rows(sfield_soln, csr_mat, sfield_rhs)
185
case(ADJ_BLOCK_CSR_MATRIX)
186
FLAbort("Cannot map between scalar fields with a block_csr_matrix .. ")
188
FLAbort("Unknown lhs%klass")
191
! Store the solved field in state(1) and record it with the adjointer.
call insert(state(1), sfield_soln, trim(sfield_soln%name))
193
soln = field_to_adj_vector(sfield_soln)
194
ierr = adj_storage_memory_incref(soln, storage)
195
call adj_chkierr(ierr)
197
! NOTE(review): presumably asks libadjoint to compare against any previously
! recorded value (tolerance 1.0d-10) and to overwrite it; see the libadjoint
! adj_storage_set_compare / adj_storage_set_overwrite documentation.
ierr = adj_storage_set_compare(storage, .true., 1.0d-10)
198
call adj_chkierr(ierr)
199
ierr = adj_storage_set_overwrite(storage, .true.)
200
call adj_chkierr(ierr)
202
ierr = adj_record_variable(adjointer, fwd_var, storage)
203
call adj_chkierr(ierr)
204
! Release this routine's handle; state(1) keeps its own (presumably
! reference-counted) copy.
call deallocate(sfield_soln)
205
case(ADJ_VECTOR_FIELD)
206
! Same procedure as the scalar case, for vector fields.
call field_from_adj_vector(rhs, vfield_rhs)
207
call allocate(vfield_soln, vfield_rhs%dim, vfield_rhs%mesh, trim(field_name))
208
call zero(vfield_soln)
211
vfield_soln%option_path = trim(path)
212
! We need to populate the BC values:
213
call insert(state(1), vfield_soln, trim(vfield_soln%name))
214
call populate_boundary_conditions(state)
215
call set_boundary_conditions_values(state, shift_time=.false.)
216
vfield_soln = extract_vector_field(state(1), trim(vfield_soln%name))
219
select case(lhs%klass)
220
case(ADJ_IDENTITY_MATRIX)
221
call set(vfield_soln, vfield_rhs)
223
call matrix_from_adj_matrix(lhs, csr_mat)
224
if (iand(lhs%flags, ADJ_MATRIX_INVERTED) == ADJ_MATRIX_INVERTED) then
225
call mult(vfield_soln, csr_mat, vfield_rhs)
227
if (.not. has_path) then
228
ierr = adj_dict_find(adj_solver_path_lookup, trim(variable_name), path)
229
call adj_chkierr(ierr)
232
vfield_rhs%bc => vfield_soln%bc
233
call set_dirichlet_consistent(vfield_rhs)
234
vfield_rhs%bc => null()
236
call petsc_solve(vfield_soln, csr_mat, vfield_rhs, option_path=path)
237
! compute_inactive_rows is disabled for vector fields:
!call compute_inactive_rows(vfield_soln, csr_mat, vfield_rhs)
239
case(ADJ_BLOCK_CSR_MATRIX)
240
! Block CSR lhs: same inverted/solve logic as the CSR case, but no Dirichlet
! consistency step is applied to the rhs here.
call matrix_from_adj_matrix(lhs, block_csr_mat)
241
if (iand(lhs%flags, ADJ_MATRIX_INVERTED) == ADJ_MATRIX_INVERTED) then
242
call mult(vfield_soln, block_csr_mat, vfield_rhs)
244
if (.not. has_path) then
245
ierr = adj_dict_find(adj_solver_path_lookup, trim(variable_name), path)
246
call adj_chkierr(ierr)
249
call petsc_solve(vfield_soln, block_csr_mat, vfield_rhs, option_path=path)
250
!call compute_inactive_rows(vfield_soln, block_csr_mat, vfield_rhs)
253
FLAbort("Unknown lhs%klass")
256
call insert(state(1), vfield_soln, trim(vfield_soln%name))
258
soln = field_to_adj_vector(vfield_soln)
259
ierr = adj_storage_memory_incref(soln, storage)
260
call adj_chkierr(ierr)
262
ierr = adj_storage_set_compare(storage, .true., 1.0d-10)
263
call adj_chkierr(ierr)
264
ierr = adj_storage_set_overwrite(storage, .true.)
265
call adj_chkierr(ierr)
267
ierr = adj_record_variable(adjointer, fwd_var, storage)
268
call adj_chkierr(ierr)
269
call deallocate(vfield_soln)
271
FLAbort("Unknown rhs%klass")
274
! Destroy lhs and rhs (no destroy call is made for identity matrices)
275
call femtools_vec_destroy_proc(rhs)
276
if (lhs%klass /= ADJ_IDENTITY_MATRIX) then
277
call femtools_mat_destroy_proc(lhs)
280
! Leave early if an interrupt/hangup was flagged (presumably by a signal handler).
if (sig_int .or. sig_hup) then
281
ewrite(-1,*) "Forward timeloop received signal, quitting"
284
end do ! end of the equation loop
286
! Bring prescribed and diagnostic fields up to date at current_time.
call set_prescribed_field_values(state, exclude_interpolated=.true., exclude_nonreprescribed=.true., time=current_time)
287
call calculate_diagnostic_variables(state, exclude_nonrecalculated = .true.)
288
call calculate_diagnostic_variables_new(state, exclude_nonrecalculated = .true.)
289
! The first timestep is the initialisation of the model.
290
! We skip the evaluation of the functional at timestep zero to get the correct value.
291
if (timestep > 0) then
292
call calculate_functional_values(timestep-1)
293
! The last timestep is the dummy timestep added at the end to act
294
! as a container for the last equation
295
! NOTE(review): exact real equality is intended here (the dummy timestep has
! identical start and end times).
if (start_time == end_time) then
296
assert(timestep == no_timesteps-1)
299
! NOTE(review): after the equation loop completes, equation holds
! end_timestep+1, so this passes end_timestep+2 — confirm intended.
call write_diagnostics(state, current_time, dt, equation+1)
300
! Write a state dump when do_write_state says one is due.
if (do_write_state(current_time, timestep)) then
301
call write_state(dump_no, state)
304
! Advance to the end of this timestep.
current_time = end_time
306
! NOTE(review): nfunctionals was already computed before the timestep loop;
! this recount is redundant unless the options tree changes during the run.
nfunctionals = option_count("/adjoint/functional")
307
! Record whatever each functional needs at this timestep.
! NOTE(review): timestep 0 is recorded with python_timestep=1; the reason is
! not visible here — confirm.
do j=0,nfunctionals-1
308
call get_option("/adjoint/functional[" // int2str(j) // "]/name", functional_name)
309
if (timestep == 0) then
310
call adj_record_anything_necessary(adjointer, python_timestep=1, timestep_to_record=0, functional=trim(functional_name), states=state)
312
call adj_record_anything_necessary(adjointer, python_timestep=timestep, timestep_to_record=timestep, functional=trim(functional_name), states=state)
315
end do ! end of the timestep loop
317
! Sanity check: the replay must end at the configured finish time.
! NOTE(review): this is an exact floating-point comparison.
call get_option("/timestepping/finish_time", finish_time)
318
assert(current_time == finish_time)
319
end subroutine compute_forward
321
! Register libadjoint_evaluate_functional with the adjointer as the callback
! that computes J. This is done for every functional under /adjoint/functional
! that has a functional_value option; functionals without it are skipped.
subroutine register_functional_callbacks()
  integer :: no_functionals, functional
  character(len=ADJ_NAME_LEN) :: functional_name
  ! Status code returned by libadjoint; handled by adj_chkierr.
  integer :: ierr

  no_functionals = option_count("/adjoint/functional")
  do functional=0,no_functionals-1
    if (have_option("/adjoint/functional[" // int2str(functional) // "]/functional_value")) then
      call get_option("/adjoint/functional[" // int2str(functional) // "]/name", functional_name)
      ! Register the callback to compute J
      ierr = adj_register_functional_callback(adjointer, trim(functional_name), &
        & c_funloc(libadjoint_evaluate_functional))
      call adj_chkierr(ierr)
    end if
  end do

end subroutine register_functional_callbacks
337
! Evaluate, at the given adjointer timestep, every functional under
! /adjoint/functional that has a functional_value option. Two diagnostics
! are updated per functional (both registered by
! forward_main_loop_register_diagnostic):
!   <name>_component : the contribution of this timestep alone;
!   <name>           : the running total, i.e. the previously stored total
!                      plus this timestep's contribution.
!
! timestep : adjointer timestep whose functional component is evaluated.
subroutine calculate_functional_values(timestep)
  integer, intent(in) :: timestep
  integer :: functional, no_functionals
  character(len=OPTION_PATH_LEN) :: functional_name
  ! Points at the stored running total of the current functional.
  real, dimension(:), pointer :: fn_value
  ! Functional component for this timestep; later reused for the new total.
  real :: J
  ! Status code returned by libadjoint; handled by adj_chkierr.
  integer :: ierr

  no_functionals = option_count("/adjoint/functional")
  do functional=0,no_functionals-1
    if (have_option("/adjoint/functional[" // int2str(functional) // "]/functional_value")) then
      call get_option("/adjoint/functional[" // int2str(functional) // "]/name", functional_name)
      ierr = adj_evaluate_functional(adjointer, timestep, functional_name, J)
      call adj_chkierr(ierr)
      ! So we've computed the component of the functional associated with this timestep.
      call set_diagnostic(name=trim(functional_name) // "_component", statistic="value", value=(/J/))
      ! We also want to sum them all up: add this component to the stored
      ! total instead of overwriting the total with the component.
      fn_value => get_diagnostic(name=trim(functional_name), statistic="value")
      J = J + fn_value(1)
      call set_diagnostic(name=trim(functional_name), statistic="value", value=(/J/))
    end if
  end do

end subroutine calculate_functional_values
363
! Register a diagnostic variable for each functional.
! For every functional under /adjoint/functional that has a functional_value
! option, two single-valued "value" diagnostics are registered and zeroed:
!   <name>_component : per-timestep contribution;
!   <name>           : accumulated total, so it must start from zero.
subroutine forward_main_loop_register_diagnostic
  integer :: functional, no_functionals
  character(len=OPTION_PATH_LEN) :: functional_name

  no_functionals = option_count("/adjoint/functional")

  do functional=0,no_functionals-1
    ! Register a diagnostic for each functional
    if (have_option("/adjoint/functional[" // int2str(functional) // "]/functional_value")) then
      call get_option("/adjoint/functional[" // int2str(functional) // "]/name", functional_name)
      call register_diagnostic(dim=1, name=trim(functional_name) // "_component", statistic="value")
      call register_diagnostic(dim=1, name=trim(functional_name), statistic="value")
      ! The functional value will be accumulated, so initialise it with zero.
      call set_diagnostic(name=trim(functional_name) // "_component", statistic="value", value=(/0.0/))
      call set_diagnostic(name=trim(functional_name), statistic="value", value=(/0.0/))
    end if
  end do

end subroutine forward_main_loop_register_diagnostic
385
end module forward_main_loop