dolfinx-adjoint 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,37 @@
1
+ """Top-level package for dxa."""
2
+
3
+ from importlib.metadata import metadata
4
+
5
+ # Start annotation at import
6
+ import pyadjoint as _pyad
7
+
8
+ from .assembly import assemble_scalar, error_norm
9
+ from .solvers import LinearProblem, NonlinearProblem
10
+ from .types import Constant, Function
11
+ from .types.function import assign
12
+
13
+ meta = metadata("dolfinx_adjoint")
14
+ __version__ = meta.get("Version")
15
+ __license__ = meta.get("License")
16
+ __author__ = meta.get("Author")
17
+ __email__ = meta.get("Author-email")
18
+ __program_name__ = meta.get("Name")
19
+
20
+
21
+ _pyad.set_working_tape(_pyad.Tape())
22
+ _pyad.continue_annotation()
23
+
24
+ __all__ = [
25
+ "Constant",
26
+ "Function",
27
+ "LinearProblem",
28
+ "NonlinearProblem",
29
+ "assemble_scalar",
30
+ "assign",
31
+ "error_norm",
32
+ "__version__",
33
+ "__author__",
34
+ "__license__",
35
+ "__email__",
36
+ "__program_name__",
37
+ ]
@@ -0,0 +1,93 @@
1
+ import typing
2
+
3
+ from mpi4py import MPI
4
+
5
+ import dolfinx
6
+ import numpy
7
+ import numpy.typing as npt
8
+ import ufl
9
+ from pyadjoint.overloaded_type import create_overloaded_object
10
+ from pyadjoint.tape import annotate_tape, get_working_tape, stop_annotating
11
+
12
+ from .blocks.assembly import AssembleBlock
13
+
14
+
15
def assemble_scalar(form: ufl.Form, **kwargs):
    """Assemble a scalar value from a form and optionally record it on the tape.

    The form is compiled, assembled on each process and summed over the
    communicator of the form's mesh. The result is wrapped in an overloaded
    object so that it can take part in further annotated operations.

    Args:
        form: Symbolic form (UFL) to assemble.
        kwargs: Keyword arguments controlling assembly and annotation:

            - ``"ad_block_tag"``: tag for the block in the adjoint tape.
            - ``"annotate"``: whether the assembly is annotated in the adjoint
              tape (handled by ``pyadjoint.tape.annotate_tape``).
            - ``"jit_options"``: JIT compilation options.
            - ``"form_compiler_options"``: form compiler options.
            - ``"entity_maps"``: entity maps for assembling with arguments and
              coefficients from meshes that are related to each other.

    Returns:
        The assembled, globally reduced value as an overloaded object.
    """
    ad_block_tag = kwargs.pop("ad_block_tag", None)
    annotate = annotate_tape(kwargs)

    # Keep the compilation options: they are needed both for the forward
    # assembly and for the block, which recompiles forms during replay and
    # derivative evaluation. Dropping them here would make the taped
    # operation differ from the forward one.
    jit_options = kwargs.pop("jit_options", None)
    form_compiler_options = kwargs.pop("form_compiler_options", None)
    entity_maps = kwargs.pop("entity_maps", None)

    with stop_annotating():
        compiled_form = dolfinx.fem.form(
            form,
            jit_options=jit_options,
            form_compiler_options=form_compiler_options,
            entity_maps=entity_maps,
        )

        local_output = dolfinx.fem.assemble_scalar(compiled_form)
        comm = compiled_form.mesh.comm
        output = comm.allreduce(local_output, op=MPI.SUM)
        assert isinstance(output, float)

    output = create_overloaded_object(output)

    if annotate:
        block = AssembleBlock(
            form,
            ad_block_tag=ad_block_tag,
            jit_options=jit_options,
            form_compiler_options=form_compiler_options,
            entity_maps=entity_maps,
        )

        tape = get_working_tape()
        tape.add_block(block)

        block.add_output(output.block_variable)

    return output
54
+
55
+
56
def error_norm(
    u_ex: ufl.core.expr.Expr,
    u: ufl.core.expr.Expr,
    norm_type: typing.Literal["L2", "H1"] = "L2",
    jit_options: typing.Optional[dict] = None,
    form_compiler_options: typing.Optional[dict] = None,
    entity_map: typing.Optional[dict[dolfinx.mesh.Mesh, npt.NDArray[numpy.int32]]] = None,
    ad_block_tag: typing.Optional[str] = None,
    annotate: bool = True,
) -> float:
    """Compute the error norm between the exact solution and the computed solution.

    The squared norm is assembled with :func:`assemble_scalar` and the square
    root of the result is returned.

    Args:
        u_ex: The exact solution as a UFL expression.
        u: The computed solution as a UFL expression.
        norm_type: The type of norm to compute, either "L2" or "H1". Only
            "H1" adds the gradient term; any other value yields the L2 norm.
        jit_options: Optional JIT compilation options.
        form_compiler_options: Optional form compiler options.
        entity_map: Optional entity maps, forwarded to the assembly as
            ``entity_maps``.
        ad_block_tag: Optional tag for the block in the adjoint tape.
        annotate: Whether to annotate the assembly in the adjoint tape.

    Returns:
        The square root of the assembled squared norm.
    """
    diff = u_ex - u
    # L2 contribution, shared by both norm types.
    norm = ufl.inner(diff, diff) * ufl.dx
    if norm_type == "H1":
        # The H1 norm adds the L2 norm of the gradient of the difference.
        norm += ufl.inner(ufl.grad(diff), ufl.grad(diff)) * ufl.dx
    return numpy.sqrt(
        assemble_scalar(
            norm,
            jit_options=jit_options,
            form_compiler_options=form_compiler_options,
            entity_maps=entity_map,
            ad_block_tag=ad_block_tag,
            annotate=annotate,
        )
    )
@@ -0,0 +1,7 @@
1
+ from .assembly import AssembleBlock
2
+ from .function_assigner import FunctionAssignBlock
3
+
4
+ __all__ = [
5
+ "AssembleBlock",
6
+ "FunctionAssignBlock",
7
+ ]
@@ -0,0 +1,77 @@
1
+ import dolfinx
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+
5
+
6
class _SpecialVector(dolfinx.la.Vector):
    """A `dolfinx.la.Vector` that supports ``+=`` and remembers its function space.

    Besides ``__iadd__``, the class exposes ``function_space``, ``x`` and
    ``name`` so that it offers those attributes next to the vector interface.
    """

    def __init__(self, x, function_space: dolfinx.fem.FunctionSpace):
        """Wrap ``x`` and store the function space the vector belongs to."""
        super().__init__(x)
        self._function_space = function_space

    @property
    def name(self):
        """Fixed name of this vector type."""
        return "SpecialVector"

    @property
    def x(self):
        """The vector itself, so that ``vector.x.array`` resolves to ``vector.array``."""
        return self

    @property
    def function_space(self) -> dolfinx.fem.FunctionSpace:
        """The function space passed at construction."""
        return self._function_space

    def __iadd__(self, other):
        """Add the entries of ``other`` to this vector in place."""
        own_values = self.array
        own_values += other.array
        return self
28
+
29
+
30
def _vector(
    map, bs: int, function_space: dolfinx.fem.FunctionSpace, dtype: npt.DTypeLike = np.float64
) -> _SpecialVector:
    """Create a distributed vector tied to a function space.

    Args:
        map: Index map that describes the size and distribution of the
            vector.
        bs: Block size.
        function_space: Function space stored on the returned vector.
        dtype: The scalar type.

    Returns:
        A distributed vector.

    Raises:
        NotImplementedError: If ``dtype`` matches none of the supported
            scalar types.
    """
    # Supported scalar types, in lookup order, paired with the suffix of the
    # corresponding ``dolfinx.cpp.la.Vector_*`` class.
    supported = (
        (np.float32, "float32"),
        (np.float64, "float64"),
        (np.complex64, "complex64"),
        (np.complex128, "complex128"),
        (np.int8, "int8"),
        (np.int32, "int32"),
        (np.int64, "int64"),
    )
    for scalar_type, suffix in supported:
        if np.issubdtype(dtype, scalar_type):
            # Resolve the class only for the matching type.
            vtype = getattr(dolfinx.cpp.la, f"Vector_{suffix}")
            break
    else:
        raise NotImplementedError(f"Type {dtype} not supported.")

    return _SpecialVector(vtype(map, bs), function_space)
62
+
63
+
64
def _create_vector(L: dolfinx.fem.Form, space: dolfinx.fem.FunctionSpace) -> _SpecialVector:
    """Create a vector that a given linear form can be assembled into.

    Args:
        L: A linear form.
        space: Function space to attach to the vector; its C++ object must
            equal the first function space of ``L``.

    Returns:
        A vector that the form can be assembled into.
    """
    form_space = L.function_spaces[0]
    # Can just take the first dofmap here, since all dof maps have the same
    # index map in mixed-topology meshes
    layout = form_space.dofmaps(0)  # type: ignore
    assert space._cpp_object == form_space, "Function space mismatch when creating vector."
    return _vector(layout.index_map, layout.index_map_bs, dtype=L.dtype, function_space=space)
@@ -0,0 +1,322 @@
1
+ import typing
2
+
3
+ from mpi4py import MPI
4
+
5
+ import dolfinx
6
+ import ufl
7
+ from pyadjoint import Block, create_overloaded_object
8
+ from ufl.formatting.ufl2unicode import ufl2unicode
9
+
10
+ from ._vector import _create_vector, _SpecialVector, _vector # noqa: F401
11
+
12
+
13
def assemble_compiled_form(
    form: dolfinx.fem.Form, tensor: typing.Optional[typing.Union[dolfinx.la.Vector, _SpecialVector, float]] = None
) -> typing.Union[dolfinx.la.Vector, _SpecialVector, float]:
    """Assemble a compiled form of rank 0 or rank 1.

    Args:
        form: Compiled form to assemble.
        tensor: Optional vector into which a rank 1 form is accumulated. The
            caller is responsible for zeroing it if required. Ignored for
            rank 0 forms.

    Returns:
        For a rank 1 form, the assembled vector: either the input tensor or a
        new zero-initialised vector created from the form's first function
        space. For a rank 0 form, the scalar summed over the communicator of
        the form's mesh.

    Raises:
        TypeError: If the form has rank 1 and ``tensor`` is neither ``None``
            nor a ``dolfinx.la.Vector``.
        NotImplementedError: If the form's rank is not 0 or 1.
    """

    if form.rank == 1:
        if tensor is None:
            # No output vector supplied: build one from the layout of the
            # test space, as promised by the optional ``tensor`` argument.
            dofmap = form.function_spaces[0].dofmaps(0)
            tensor = dolfinx.la.vector(dofmap.index_map, dofmap.index_map_bs, dtype=form.dtype)
            tensor.array[:] = 0.0
        elif not isinstance(tensor, dolfinx.la.Vector):
            raise TypeError(f"Expected a dolfinx.la.Vector for rank 1 assembly, got {type(tensor).__name__}.")
        dolfinx.fem.assemble._assemble_vector_array(tensor.array, form)
        # Accumulate ghost contributions on the owning process, then push the
        # owned values back out to the ghosts.
        tensor.scatter_reverse(dolfinx.la.InsertMode.add)
        tensor.scatter_forward()
    elif form.rank == 0:
        local_val = dolfinx.fem.assemble_scalar(form)
        comm = form.mesh.comm
        tensor = comm.allreduce(local_val, op=MPI.SUM)

    else:
        raise NotImplementedError("Only assembly of rank 0 and rank 1 forms is currently supported.")
    return tensor
41
+
42
+
43
class AssembleBlock(Block):
    """Block for assembling a symbolic UFL form into a tensor.

    On construction the form is compiled and every coefficient of the form is
    registered as a dependency of the block.

    Args:
        form: The UFL form to assemble.
        ad_block_tag: Tag for the block in the adjoint tape.
        jit_options: Dictionary of options for JIT compilation.
        form_compiler_options: Dictionary of options for the form compiler.
        entity_maps: Entity maps used when compiling forms for assembly.
    """

    def __init__(
        self,
        form: ufl.Form,
        ad_block_tag: typing.Optional[str] = None,
        jit_options: typing.Optional[dict] = None,
        form_compiler_options: typing.Optional[dict] = None,
        entity_maps: typing.Optional[typing.Sequence[dolfinx.mesh.EntityMap]] = None,
    ):
        super().__init__(ad_block_tag=ad_block_tag)

        # Store the options for code generation; they are reused whenever a
        # derived form is compiled.
        self._jit_options = jit_options
        self._form_compiler_options = form_compiler_options
        self._entity_maps = entity_maps

        # Store compiled and original form
        self.form = form
        self.compiled_form = dolfinx.fem.form(
            form, jit_options=jit_options, form_compiler_options=form_compiler_options, entity_maps=entity_maps
        )

        # NOTE: For shape optimization the mesh has to be added as a
        # dependency as well; this is not done yet.
        for coefficient in self.form.coefficients():
            self.add_dependency(coefficient, no_duplicates=True)

    def __str__(self):
        return f"assemble({ufl2unicode(self.form)})"

    @staticmethod
    def _form_arity(form) -> int:
        """Return the number of arguments (test/trial functions) in ``form``."""
        from ufl.algorithms.analysis import extract_arguments

        return len(extract_arguments(form))

    def _compile(self, form):
        """Compile ``form`` with the options stored on the block."""
        return dolfinx.fem.form(
            form,
            jit_options=self._jit_options,
            form_compiler_options=self._form_compiler_options,
            entity_maps=self._entity_maps,
        )

    def compute_action_adjoint(
        self,
        adj_input: typing.Union[float, dolfinx.la.Vector],
        arity_form: int,
        form: typing.Optional[ufl.Form] = None,
        c_rep: typing.Optional[typing.Union[ufl.Coefficient, ufl.Constant]] = None,
        space: typing.Optional[dolfinx.fem.FunctionSpace] = None,
        dform: typing.Optional[dolfinx.fem.Form] = None,
    ):
        """This computes the action of the adjoint of the derivative of `form` wrt `c_rep` on `adj_input`.

        In other words, it returns:

        .. math::

            \\left\\langle\\left(\\frac{\\partial form}{\\partial c_{rep}}\\right)^*, adj_{input} \\right\\rangle

        - If `form` has arity 0, then :math:`\\frac{\\partial form}{\\partial c_{rep}}` is a 1-form
          and `adj_input` a float, we can simply use the `*` operator.

        - If `form` has arity 1 then :math:`\\frac{\\partial form}{\\partial c_{rep}}` is a 2-form
          and we can symbolically take its adjoint and then apply the action on `adj_input`, to finally
          assemble the result. This case is not implemented yet.

        Args:
            adj_input: The input to the adjoint operation, typically a scalar or vector.
            arity_form: The arity of the form, i.e., 0 for scalar, 1 for vector, 2 for matrix etc.
            form: The UFL form to differentiate if `dform` is not provided.
            c_rep: The coefficient or constant with respect to which the derivative is taken.
            space: The function space associated with the `c_rep` to form an `ufl.Argument` in.
                If omitted, it is inferred from the single argument of `dform`.
            dform: Pre-computed derivative form, :math:`\\frac{\\partial form}{\\partial c_{rep}}`.

        Returns:
            A tuple of the assembled vector scaled by `adj_input` and the
            (symbolic) derivative form that was assembled.

        Raises:
            ValueError: If `arity_form` is not 0.
        """
        if arity_form != 0:
            # TODO: arity 1 requires the action of the adjoint of a 2-form
            # (symbolically, or numerically for spatial coordinates).
            raise ValueError("Forms with arity > 0 are not handled yet!")

        assert arity_form == self.compiled_form.rank, "Inconsistent arity of input form and block form."
        if dform is None:
            assert space is not None
            dc = ufl.TestFunction(space)
            dform = ufl.derivative(form, c_rep, dc)

        assert isinstance(dform, ufl.Form), "dform must be a UFL form."
        compiled_adjoint = self._compile(dform)
        if space is None:
            # If space is not supplied infer it from the form
            assert len(dform.arguments()) == 1
            space = dform.arguments()[0].ufl_function_space()

        # A fresh vector is created for every call; caching per space is a
        # possible optimisation that is not enabled.
        vector = _create_vector(compiled_adjoint, space)
        vector.array[:] = 0.0
        assemble_compiled_form(compiled_adjoint, vector)
        # Return a vector scaled by the scalar `adj_input`, cast to the
        # scalar type of the vector.
        vector.array[:] *= vector.x.array.dtype.type(adj_input)
        vector.scatter_forward()

        return vector, dform

    def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
        """Return the block form with each dependency replaced by its saved value.

        Only dependencies that appear among the coefficients of the form are
        replaced. The arguments are not used.
        """
        form_coefficients = self.form.coefficients()
        replaced_coeffs = {}
        for block_variable in self.get_dependencies():
            coeff = block_variable.output
            if coeff in form_coefficients:
                replaced_coeffs[coeff] = block_variable.saved_output

        return ufl.replace(self.form, replaced_coeffs)

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        """Evaluate the adjoint contribution with respect to one dependency.

        Args:
            inputs: Unused.
            adj_inputs: Adjoint inputs; only the first entry is used.
            block_variable: Block variable of the dependency to differentiate
                with respect to.
            idx: Unused.
            prepared: The form returned by :meth:`prepare_evaluate_adj`.

        Returns:
            The vector returned by :meth:`compute_action_adjoint`.

        Raises:
            NotImplementedError: If the dependency is not a
                ``dolfinx.fem.Function``.
        """
        form = prepared
        adj_input = adj_inputs[0]
        c = block_variable.output
        c_rep = block_variable.saved_output

        arity_form = self._form_arity(form)

        if not isinstance(c, dolfinx.fem.Function):
            # Constants and meshes need a dedicated function space, which is
            # not available yet. Fail explicitly instead of hitting an
            # unbound local variable below.
            raise NotImplementedError(
                f"Adjoint of assembly with respect to {type(c).__name__} is not supported yet."
            )
        space = c.function_space

        return self.compute_action_adjoint(adj_input, arity_form, form, c_rep, space)[0]

    def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
        """Prepare the tangent linear evaluation; same as the adjoint preparation."""
        return self.prepare_evaluate_adj(inputs, tlm_inputs, self.get_dependencies())

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        """Evaluate the tangent linear model for the block output.

        The derivative of the prepared form is summed over all dependencies
        that carry a ``tlm_value`` and the result is assembled.

        Returns:
            ``0.0`` if no dependency has a ``tlm_value``, otherwise the
            assembled derivative.
        """
        form = prepared
        dform = 0.0

        arity_form = self._form_arity(form)
        for bv in self.get_dependencies():
            c_rep = bv.saved_output
            tlm_value = bv.tlm_value
            if tlm_value is None:
                continue
            if isinstance(c_rep, dolfinx.mesh.Mesh):
                X = ufl.SpatialCoordinate(c_rep)
                dform += ufl.derivative(form, X, tlm_value)
            else:
                dform += ufl.derivative(form, c_rep, tlm_value)
        if not isinstance(dform, float):
            dform = ufl.algorithms.expand_derivatives(dform)
            compiled_form = self._compile(dform)
            dform = assemble_compiled_form(compiled_form)
            if arity_form == 1 and dform != 0:
                # Then dform is a Vector
                # NOTE(review): this expects the assembled object to expose
                # ``.function``; confirm that the arity 1 path is exercised.
                dform = dform.function
        return dform

    def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies):
        """Prepare the Hessian evaluation; same as the adjoint preparation."""
        return self.prepare_evaluate_adj(inputs, adj_inputs, relevant_dependencies)

    def evaluate_hessian_component(
        self,
        inputs,
        hessian_inputs,
        adj_inputs,
        block_variable,
        idx,
        relevant_dependencies,
        prepared=None,
    ):
        """Evaluate the Hessian contribution with respect to one dependency.

        The result is the adjoint action on the first Hessian input plus, if
        any relevant dependency carries a ``tlm_value``, the adjoint action
        of the second derivative on the first adjoint input.

        Returns:
            The accumulated vector, or ``None`` if the dependency is not a
            constant, function or mesh.
        """
        form = prepared
        hessian_input = hessian_inputs[0]
        adj_input = adj_inputs[0]

        arity_form = self._form_arity(form)

        c1 = block_variable.output
        c1_rep = block_variable.saved_output

        if isinstance(c1, dolfinx.fem.Constant):
            mesh = form.ufl_domain()
            space = c1._ad_function_space(mesh)
        elif isinstance(c1, dolfinx.fem.Function):
            space = c1.function_space
        elif isinstance(c1, dolfinx.mesh.Mesh):
            c1_rep = ufl.SpatialCoordinate(c1)
            space = c1._ad_function_space()
        else:
            return None
        hessian_outputs, dform = self.compute_action_adjoint(hessian_input, arity_form, form, c1_rep, space)

        ddform = 0
        for _, bv in relevant_dependencies:
            c2_rep = bv.saved_output
            tlm_input = bv.tlm_value

            if tlm_input is None:
                continue

            if isinstance(c2_rep, dolfinx.mesh.Mesh):
                X = ufl.SpatialCoordinate(c2_rep)
                ddform += ufl.derivative(dform, X, tlm_input)
            else:
                ddform += ufl.derivative(dform, c2_rep, tlm_input)

        # ``ddform`` is still the integer 0 when no dependency had a TLM
        # value; only a real form can be expanded and assembled.
        if isinstance(ddform, ufl.Form):
            ddform = ufl.algorithms.expand_derivatives(ddform)

            if not ddform.empty():
                # FIXME: Compare ddform with legacy dolfin_adjoint here
                # (function space of ddform versus that of the Hessian).
                adj_action = self.compute_action_adjoint(adj_input, arity_form, dform=ddform)[0]
                try:
                    hessian_outputs += adj_action
                except TypeError:
                    hessian_outputs.array[:] += adj_action.array[:]
        return hessian_outputs

    def prepare_recompute_component(self, inputs, relevant_outputs):
        """Prepare recomputation; same as the adjoint preparation."""
        return self.prepare_evaluate_adj(inputs, None, None)

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Re-assemble the prepared form and return it as an overloaded object."""
        form = prepared

        compiled_form = self._compile(form)
        local_output = dolfinx.fem.assemble_scalar(compiled_form)
        comm = compiled_form.mesh.comm
        output = comm.allreduce(local_output, op=MPI.SUM)
        output = create_overloaded_object(output)
        return output