# ~libadjoint/libadjoint/dolfin_predictability

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import dolfin
import copy
import collections

# Keep a handle on dolfin's own assemble so the wrapper below can delegate to it.
dolfin_assemble = dolfin.assemble
def assemble(*args, **kwargs):
  """Assemble with dolfin and remember which form produced the result.

  All arguments are forwarded unchanged to ``dolfin.assemble``.  The first
  positional argument is taken to be the form; it is stored on the assembled
  object as ``.form`` unless the result is a plain ``float``.

  Returns whatever ``dolfin.assemble`` returned.
  """
  # Read the form before delegating: a call without positional arguments
  # fails here, before dolfin is invoked.
  source_form = args[0]
  assembled = dolfin_assemble(*args, **kwargs)

  # A float result is returned untouched; no attribute is attached to it.
  if isinstance(assembled, float):
    return assembled

  assembled.form = source_form
  return assembled

# Original DirichletBC.apply, captured before it is replaced below.
bc_apply = dolfin.DirichletBC.apply
def adjoint_bc_apply(self, *args, **kwargs):
  """Apply this boundary condition and record it on every positional target.

  Each positional argument gets this boundary condition appended to its
  ``bcs`` list (the list is created if the object does not have one yet).
  Keyword arguments are forwarded but not annotated.  The call is then
  delegated to the original ``DirichletBC.apply`` and its result returned.
  """
  for target in args:
    if hasattr(target, 'bcs'):
      target.bcs.append(self)
    else:
      # First boundary condition seen for this object: start its list.
      target.bcs = [self]
  return bc_apply(self, *args, **kwargs)
# Install the recording wrapper in place of the original method.
dolfin.DirichletBC.apply = adjoint_bc_apply

# Original Function.vector, captured before it is replaced below.
function_vector = dolfin.Function.vector
def adjoint_function_vector(self):
  """Return this function's vector, tagged with a back-reference to it.

  Delegates to the original ``Function.vector`` and stores the owning
  function on the returned object as ``.function`` before handing it back.
  """
  tagged_vec = function_vector(self)
  # Back-reference: lets code holding only the vector find its Function.
  tagged_vec.function = self
  return tagged_vec
# Install the tagging wrapper in place of the original method.
dolfin.Function.vector = adjoint_function_vector

def assemble_system(*args, **kwargs):
  """Assemble a system with dolfin and tag both resulting tensors.

  Every argument is forwarded unchanged to ``dolfin.assemble_system``.  On
  each of the two returned objects this records the form it was assembled
  from (``.form``) and the boundary conditions given to the call (``.bcs``).

  Args:
    *args: positional arguments for ``dolfin.assemble_system``.  The first
      two are taken as the left- and right-hand-side forms; the third, when
      present, as the boundary condition(s).
    **kwargs: keyword arguments for ``dolfin.assemble_system``.  ``bcs`` is
      read from here when it was not passed positionally.

  Returns:
    The ``(lhs_out, rhs_out)`` pair returned by ``dolfin.assemble_system``.
    Each element carries its own ``bcs`` list: an empty list when no
    boundary conditions were given, a one-element list for a single
    boundary condition, or a copy of the given list/tuple.
  """
  lhs = args[0]
  rhs = args[1]
  # Boundary conditions may be positional, keyword, or absent altogether;
  # indexing args[2] unconditionally would raise IndexError in the last two.
  bcs = args[2] if len(args) > 2 else kwargs.get('bcs')

  if bcs is None:
    bcs = []
  elif isinstance(bcs, (list, tuple)):
    bcs = list(bcs)
  else:
    bcs = [bcs]

  (lhs_out, rhs_out) = dolfin.assemble_system(*args, **kwargs)
  lhs_out.form = lhs
  rhs_out.form = rhs
  # Give each tensor an independent list.  Sharing one list (or the caller's
  # own list) means a later in-place append to one tensor's ``bcs`` would
  # silently show up on the other tensor and in the caller's argument.
  lhs_out.bcs = list(bcs)
  rhs_out.bcs = list(bcs)
  return (lhs_out, rhs_out)