~airpollution/fluidity/fluidity_airpollution

« back to all changes in this revision

Viewing changes to adjoint/Forward_Main_Loop.F90

  • Committer: ziyouzhj
  • Date: 2013-12-09 16:51:29 UTC
  • Revision ID: ziyouzhj@gmail.com-20131209165129-ucoetc3u0atyy05c
airpolution

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
!    Copyright (C) 2006 Imperial College London and others.
 
2
!    
 
3
!    Please see the AUTHORS file in the main source directory for a full list
 
4
!    of copyright holders.
 
5
!
 
6
!    Prof. C Pain
 
7
!    Applied Modelling and Computation Group
 
8
!    Department of Earth Science and Engineering
 
9
!    Imperial College London
 
10
!
 
11
!    amcgsoftware@imperial.ac.uk
 
12
!    
 
13
!    This library is free software; you can redistribute it and/or
 
14
!    modify it under the terms of the GNU Lesser General Public
 
15
!    License as published by the Free Software Foundation; either
 
16
!    version 2.1 of the License, or (at your option) any later version.
 
17
!
 
18
!    This library is distributed in the hope that it will be useful,
 
19
!    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
20
!    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 
21
!    Lesser General Public License for more details.
 
22
!
 
23
!    You should have received a copy of the GNU Lesser General Public
 
24
!    License along with this library; if not, write to the Free Software
 
25
!    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
 
26
!    USA
 
27
#include "fdebug.h"
 
28
 
 
29
module forward_main_loop
    ! Re-runs the forward model by solving, in order, the forward equations
    ! registered with the libadjoint adjointer, and evaluates/records the
    ! adjoint functionals along the way. When built without HAVE_ADJOINT only
    ! forward_main_loop_register_diagnostic is compiled, and its body is empty.
#ifdef HAVE_ADJOINT
#include "libadjoint/adj_fortran.h"
    use libadjoint
    use libadjoint_data_callbacks
#endif
    use state_module
    use diagnostic_variables
    use fields
    use spud
    use global_parameters, only: OPTION_PATH_LEN, running_adjoint
    use adjoint_global_variables
    use adjoint_functional_evaluation
    use populate_state_module
    use boundary_conditions_from_options
    use signal_vars
    use mangle_options_tree
    use mangle_dirichlet_rows_module
    use sparse_tools
    use sparse_matrices_fields
    use write_state_module
    use diagnostic_fields_wrapper
    use solvers
    use diagnostic_fields_new, only: calculate_diagnostic_variables_new => calculate_diagnostic_variables, &
                                   & check_diagnostic_dependencies
    implicit none

    private
    public :: forward_main_loop_register_diagnostic
#ifdef HAVE_ADJOINT
    public :: compute_forward, calculate_functional_values, register_functional_callbacks
#endif

    contains
#ifdef HAVE_ADJOINT

    ! Replay the forward model from the equations registered with the adjointer.
    !
    ! For every timestep known to the adjointer, each forward equation
    ! lhs . x = rhs is fetched and solved:
    !   - by copying rhs when lhs is the identity,
    !   - by a matrix-vector product when lhs is flagged ADJ_MATRIX_INVERTED,
    !   - otherwise by petsc_solve.
    ! The solution is inserted into state(1) and recorded with the adjointer.
    !
    ! At the end of each timestep:
    !   - prescribed and diagnostic fields are updated,
    !   - functional values are accumulated (from timestep 1 on),
    !   - diagnostics are written,
    !   - a dump is written if do_write_state asks for one,
    !   - adj_record_anything_necessary is called for every functional.
    !
    ! state: the model states; solved fields are inserted into state(1).
    !
    ! Returns early, without finishing the run, if sig_int or sig_hup is set.
    subroutine compute_forward(state)
      type(state_type), dimension(:), intent(inout) :: state

      type(adj_vector) :: rhs
      type(adj_vector) :: soln
      type(adj_matrix) :: lhs
      type(adj_variable) :: fwd_var

      integer :: equation
      integer :: ierr
      integer :: s_idx
      integer :: dump_no

      real :: finish_time, dt
      integer :: end_timestep, start_timestep, no_timesteps, timestep
      real :: start_time, end_time

      character(len=OPTION_PATH_LEN) :: simulation_base_name, functional_name
      type(stat_type) :: new_stat

      character(len=ADJ_NAME_LEN) :: variable_name, field_name
      type(scalar_field) :: sfield_soln, sfield_rhs
      type(vector_field) :: vfield_soln, vfield_rhs
      type(csr_matrix) :: csr_mat
      type(block_csr_matrix) :: block_csr_mat
      character(len=ADJ_DICT_LEN) :: path
      type(adj_storage_data) :: storage
      logical :: has_path
      integer :: nfunctionals
      integer :: j

      ierr = adj_adjointer_check_consistency(adjointer)
      call adj_chkierr(ierr)

      call get_option("/timestepping/timestep", dt)
      call get_option("/simulation_name", simulation_base_name)
      running_adjoint = .false.

      ! Switch the html output on if you are interested in what the adjointer has registered
      if (have_option("/adjoint/debug/html_output")) then
        ierr = adj_adjointer_to_html(adjointer, "adjointer_forward.html", ADJ_FORWARD)
        call adj_chkierr(ierr)
        ierr = adj_adjointer_to_html(adjointer, "adjointer_adjoint.html", ADJ_ADJOINT)
        call adj_chkierr(ierr)
      end if

      ! new_stat is never assigned here; presumably its default initialisation
      ! resets default_stat before the forward stat file is opened.
      default_stat = new_stat
      call initialise_walltime
      call initialise_diagnostics(trim(simulation_base_name) // '_forward', state)

      ierr = adj_timestep_count(adjointer, no_timesteps)
      call adj_chkierr(ierr)

      dump_no = 0
      ! Counted once here rather than again on every timestep.
      nfunctionals = option_count("/adjoint/functional")

      do timestep=0,no_timesteps-1
        ierr = adj_timestep_get_times(adjointer, timestep, start_time, end_time)
        call adj_chkierr(ierr)
        current_time = start_time
        call set_option("/timestepping/current_time", current_time)

        ierr = adj_timestep_start_equation(adjointer, timestep, start_timestep)
        call adj_chkierr(ierr)

        ierr = adj_timestep_end_equation(adjointer, timestep, end_timestep)
        call adj_chkierr(ierr)

        do equation=start_timestep,end_timestep
          ierr = adj_get_forward_equation(adjointer, equation, lhs, rhs, fwd_var)
          call adj_chkierr(ierr)

          ! Now solve lhs . forward = rhs for the variable fwd_var.
          ! variable_name should be something like Fluid::Velocity; the field
          ! name is taken as the text after the "::" separator.
          ierr = adj_variable_get_name(fwd_var, variable_name)
          call adj_chkierr(ierr)
          s_idx = scan(trim(variable_name), ":")
          field_name = variable_name(s_idx + 2:len_trim(variable_name))

          ! If an option path was registered for this variable, it is used for
          ! the solution's boundary conditions and as the solver option path.
          ierr = adj_dict_find(adj_path_lookup, trim(variable_name), path)
          has_path = (ierr == ADJ_OK)

          select case(rhs%klass)
            case(ADJ_SCALAR_FIELD)
              call field_from_adj_vector(rhs, sfield_rhs)
              call allocate(sfield_soln, sfield_rhs%mesh, trim(field_name))
              call zero(sfield_soln)

              if (has_path) then
                sfield_soln%option_path = trim(path)
                ! We need to populate the BC values:
                call insert(state(1), sfield_soln, trim(sfield_soln%name))
                call populate_boundary_conditions(state)
                call set_boundary_conditions_values(state, shift_time=.false.)
                sfield_soln = extract_scalar_field(state(1), trim(sfield_soln%name))
              end if

              select case(lhs%klass)
                case(ADJ_IDENTITY_MATRIX)
                  ! Identity operator: the solution is the rhs itself.
                  call set(sfield_soln, sfield_rhs)
                case(ADJ_CSR_MATRIX)
                  call matrix_from_adj_matrix(lhs, csr_mat)
                  if (iand(lhs%flags, ADJ_MATRIX_INVERTED) == ADJ_MATRIX_INVERTED) then
                    ! lhs is flagged as already inverted: apply it to rhs instead of solving.
                    call mult(sfield_soln, csr_mat, sfield_rhs)
                  else
                    if (.not. has_path) then
                      ierr = adj_dict_find(adj_solver_path_lookup, trim(variable_name), path)
                      call adj_chkierr(ierr)
                    end if

                    ! Temporarily borrow the solution's boundary conditions so that
                    ! set_dirichlet_consistent can act on the rhs.
                    sfield_rhs%bc => sfield_soln%bc
                    call set_dirichlet_consistent(sfield_rhs)
                    sfield_rhs%bc => null()

                    call petsc_solve(sfield_soln, csr_mat, sfield_rhs, option_path=path)
                    call compute_inactive_rows(sfield_soln, csr_mat, sfield_rhs)
                  endif
                case(ADJ_BLOCK_CSR_MATRIX)
                  FLAbort("Cannot map between scalar fields with a block_csr_matrix .. ")
                case default
                  FLAbort("Unknown lhs%klass")
              end select

              call insert(state(1), sfield_soln, trim(sfield_soln%name))

              ! Hand the solution to the adjointer (compare and overwrite enabled).
              soln = field_to_adj_vector(sfield_soln)
              ierr = adj_storage_memory_incref(soln, storage)
              call adj_chkierr(ierr)

              ierr = adj_storage_set_compare(storage, .true., 1.0d-10)
              call adj_chkierr(ierr)
              ierr = adj_storage_set_overwrite(storage, .true.)
              call adj_chkierr(ierr)

              ierr = adj_record_variable(adjointer, fwd_var, storage)
              call adj_chkierr(ierr)
              call deallocate(sfield_soln)
            case(ADJ_VECTOR_FIELD)
              call field_from_adj_vector(rhs, vfield_rhs)
              call allocate(vfield_soln, vfield_rhs%dim, vfield_rhs%mesh, trim(field_name))
              call zero(vfield_soln)

              if (has_path) then
                vfield_soln%option_path = trim(path)
                ! We need to populate the BC values:
                call insert(state(1), vfield_soln, trim(vfield_soln%name))
                call populate_boundary_conditions(state)
                call set_boundary_conditions_values(state, shift_time=.false.)
                vfield_soln = extract_vector_field(state(1), trim(vfield_soln%name))
              end if

              select case(lhs%klass)
                case(ADJ_IDENTITY_MATRIX)
                  ! Identity operator: the solution is the rhs itself.
                  call set(vfield_soln, vfield_rhs)
                case(ADJ_CSR_MATRIX)
                  call matrix_from_adj_matrix(lhs, csr_mat)
                  if (iand(lhs%flags, ADJ_MATRIX_INVERTED) == ADJ_MATRIX_INVERTED) then
                    ! lhs is flagged as already inverted: apply it to rhs instead of solving.
                    call mult(vfield_soln, csr_mat, vfield_rhs)
                  else
                    if (.not. has_path) then
                      ierr = adj_dict_find(adj_solver_path_lookup, trim(variable_name), path)
                      call adj_chkierr(ierr)
                    end if

                    ! Temporarily borrow the solution's boundary conditions so that
                    ! set_dirichlet_consistent can act on the rhs.
                    vfield_rhs%bc => vfield_soln%bc
                    call set_dirichlet_consistent(vfield_rhs)
                    vfield_rhs%bc => null()

                    call petsc_solve(vfield_soln, csr_mat, vfield_rhs, option_path=path)
                    !call compute_inactive_rows(vfield_soln, csr_mat, vfield_rhs)
                  endif
                case(ADJ_BLOCK_CSR_MATRIX)
                  call matrix_from_adj_matrix(lhs, block_csr_mat)
                  if (iand(lhs%flags, ADJ_MATRIX_INVERTED) == ADJ_MATRIX_INVERTED) then
                    ! lhs is flagged as already inverted: apply it to rhs instead of solving.
                    call mult(vfield_soln, block_csr_mat, vfield_rhs)
                  else
                    if (.not. has_path) then
                      ierr = adj_dict_find(adj_solver_path_lookup, trim(variable_name), path)
                      call adj_chkierr(ierr)
                    end if

                    call petsc_solve(vfield_soln, block_csr_mat, vfield_rhs, option_path=path)
                    !call compute_inactive_rows(vfield_soln, block_csr_mat, vfield_rhs)
                  endif
                case default
                  FLAbort("Unknown lhs%klass")
              end select

              call insert(state(1), vfield_soln, trim(vfield_soln%name))

              ! Hand the solution to the adjointer (compare and overwrite enabled).
              soln = field_to_adj_vector(vfield_soln)
              ierr = adj_storage_memory_incref(soln, storage)
              call adj_chkierr(ierr)

              ierr = adj_storage_set_compare(storage, .true., 1.0d-10)
              call adj_chkierr(ierr)
              ierr = adj_storage_set_overwrite(storage, .true.)
              call adj_chkierr(ierr)

              ierr = adj_record_variable(adjointer, fwd_var, storage)
              call adj_chkierr(ierr)
              call deallocate(vfield_soln)
            case default
              FLAbort("Unknown rhs%klass")
          end select

          ! Destroy lhs and rhs
          call femtools_vec_destroy_proc(rhs)
          if (lhs%klass /= ADJ_IDENTITY_MATRIX) then
            call femtools_mat_destroy_proc(lhs)
          endif

          if (sig_int .or. sig_hup) then
            ewrite(-1,*) "Forward timeloop received signal, quitting"
            return
          end if
        end do ! end of the equation loop

        call set_prescribed_field_values(state, exclude_interpolated=.true., &
                                         exclude_nonreprescribed=.true., time=current_time)
        call calculate_diagnostic_variables(state, exclude_nonrecalculated = .true.)
        call calculate_diagnostic_variables_new(state, exclude_nonrecalculated = .true.)
        ! The first timestep is the initialisation of the model.
        ! We skip the evaluation of the functional at timestep zero to get the correct value.
        if (timestep > 0) then
          call calculate_functional_values(timestep-1)
          ! The last timestep is the dummy timestep added at the end to act
          ! as a container for the last equation.
          if (start_time == end_time) then
            assert(timestep == no_timesteps-1)
          end if
        end if
        ! NOTE(review): once the equation loop above completes, `equation` holds
        ! end_timestep+1, so the value passed here is end_timestep+2. Confirm
        ! this is the intended timestep index for the diagnostics.
        call write_diagnostics(state, current_time, dt, equation+1)
        if (do_write_state(current_time, timestep)) then
          call write_state(dump_no, state)
        endif

        current_time = end_time

        do j=0,nfunctionals-1
          call get_option("/adjoint/functional[" // int2str(j) // "]/name", functional_name)
          if (timestep == 0) then
            call adj_record_anything_necessary(adjointer, python_timestep=1, timestep_to_record=0, &
                                               functional=trim(functional_name), states=state)
          else
            call adj_record_anything_necessary(adjointer, python_timestep=timestep, timestep_to_record=timestep, &
                                               functional=trim(functional_name), states=state)
          end if
        end do
      end do ! end of the timestep loop

      call get_option("/timestepping/finish_time", finish_time)
      assert(current_time == finish_time)
    end subroutine compute_forward

    ! Register libadjoint_evaluate_functional with the adjointer as the
    ! evaluation callback for every functional that has a functional_value option.
    subroutine register_functional_callbacks()
      integer :: no_functionals, functional
      character(len=ADJ_NAME_LEN) :: functional_name
      integer :: ierr

      no_functionals = option_count("/adjoint/functional")
      do functional=0,no_functionals-1
        if (have_option("/adjoint/functional[" // int2str(functional) // "]/functional_value")) then
          call get_option("/adjoint/functional[" // int2str(functional) // "]/name", functional_name)
          ! Register the callback to compute J
          ierr = adj_register_functional_callback(adjointer, trim(functional_name), c_funloc(libadjoint_evaluate_functional))
          call adj_chkierr(ierr)
        end if
      end do
    end subroutine register_functional_callbacks

    ! Handles every functional that has a functional_value option:
    !   - evaluates its contribution J for the given adjointer timestep,
    !   - stores J in the "<name>_component" diagnostic,
    !   - adds J to the running total held in the "<name>" diagnostic.
    !
    ! timestep: adjointer timestep whose functional contribution is evaluated.
    subroutine calculate_functional_values(timestep)
      integer, intent(in) :: timestep
      integer :: functional, no_functionals
      character(len=OPTION_PATH_LEN) :: functional_name
      real, dimension(:), pointer :: fn_value
      real :: J
      integer :: ierr

      no_functionals = option_count("/adjoint/functional")
      do functional=0,no_functionals-1
        if (have_option("/adjoint/functional[" // int2str(functional) // "]/functional_value")) then
          call get_option("/adjoint/functional[" // int2str(functional) // "]/name", functional_name)
          ierr = adj_evaluate_functional(adjointer, timestep, functional_name, J)
          call adj_chkierr(ierr)
          ! So we've computed the component of the functional associated with this timestep.
          ! We also want to sum them all up ...
          call set_diagnostic(name=trim(functional_name) // "_component", statistic="value", value=(/J/))
          fn_value => get_diagnostic(name=trim(functional_name), statistic="value")
          J = J + fn_value(1)
          call set_diagnostic(name=trim(functional_name), statistic="value", value=(/J/))
        end if
      end do
    end subroutine calculate_functional_values

#endif

    ! Register a diagnostic variable for each functional.
    ! For every functional with a functional_value option, two single-value
    ! diagnostics are registered, "<name>_component" and "<name>", and both
    ! are initialised to zero. Does nothing unless built with HAVE_ADJOINT.
    subroutine forward_main_loop_register_diagnostic
      integer :: functional, no_functionals
      character(len=OPTION_PATH_LEN) :: functional_name

#ifdef HAVE_ADJOINT
      no_functionals = option_count("/adjoint/functional")

      do functional=0,no_functionals-1
        ! Register a diagnostic for each functional
        if (have_option("/adjoint/functional[" // int2str(functional) // "]/functional_value")) then
          call get_option("/adjoint/functional[" // int2str(functional) // "]/name", functional_name)
          call register_diagnostic(dim=1, name=trim(functional_name) // "_component", statistic="value")
          call register_diagnostic(dim=1, name=trim(functional_name), statistic="value")
          ! The functional value will be accumulated, so initialise it with zero.
          call set_diagnostic(name=trim(functional_name) // "_component", statistic="value", value=(/0.0/))
          call set_diagnostic(name=trim(functional_name), statistic="value", value=(/0.0/))
        end if
      end do
#endif
    end subroutine forward_main_loop_register_diagnostic

end module forward_main_loop