jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_utils.py
CHANGED
@@ -6,42 +6,42 @@ from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
-from typing import TYPE_CHECKING, Callable,
+from typing import TYPE_CHECKING, Callable, TypeGuard
+from types import EllipsisType
 import jax
 import jax.numpy as jnp
 from jax import vmap
-import
-from jaxtyping import Float, Array, PyTree
+from jaxtyping import Float, Array
 
 from jinns.loss._boundary_conditions import (
     _compute_boundary_loss,
 )
 from jinns.utils._utils import _subtract_with_check, get_grid
-from jinns.data.
+from jinns.data._utils import make_cartesian_product
 from jinns.parameters._params import _get_vmap_in_axes_params
 from jinns.nn._pinn import PINN
 from jinns.nn._spinn import SPINN
 from jinns.nn._hyperpinn import HyperPINN
-from jinns.data._Batchs import
-from jinns.parameters._params import Params
+from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+from jinns.parameters._params import Params
 
 if TYPE_CHECKING:
-    from jinns.utils._types import
+    from jinns.utils._types import BoundaryConditionFun
+    from jinns.nn._abstract_pinn import AbstractPINN
 
 
 def dynamic_loss_apply(
-    dyn_loss:
-    u:
+    dyn_loss: Callable,
+    u: AbstractPINN,
     batch: (
-        Float[Array, "batch_size 1"]
-        | Float[Array, "batch_size dim"]
-        | Float[Array, "batch_size 1+dim"]
+        Float[Array, " batch_size 1"]
+        | Float[Array, " batch_size dim"]
+        | Float[Array, " batch_size 1+dim"]
     ),
-    params: Params
-    vmap_axes: tuple[int | None
-    loss_weight: float | Float[Array, "dyn_loss_dimension"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
     u_type: PINN | HyperPINN | None = None,
-) ->
+) -> Float[Array, " "]:
     """
     Sometimes when u is a lambda function a or dict we do not have access to
     its type here, hence the last argument
@@ -49,16 +49,18 @@ def dynamic_loss_apply(
     if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
         v_dyn_loss = vmap(
             lambda batch, params: dyn_loss(
-                batch,
+                batch,
+                u,
+                params,  # we must place the params at the end
             ),
             vmap_axes,
             0,
         )
         residuals = v_dyn_loss(batch, params)
-        mse_dyn_loss = jnp.mean(jnp.sum(
+        mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
     elif u_type == SPINN or isinstance(u, SPINN):
         residuals = dyn_loss(batch, u, params)
-        mse_dyn_loss = jnp.mean(jnp.sum(
+        mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
 
@@ -66,18 +68,17 @@ def dynamic_loss_apply(
 
 
 def normalization_loss_apply(
-    u:
+    u: AbstractPINN,
     batches: (
-        tuple[Float[Array, "nb_norm_samples dim"]]
+        tuple[Float[Array, " nb_norm_samples dim"]]
         | tuple[
-            Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+            Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
         ]
     ),
-    params: Params
-    vmap_axes_params: tuple[int | None
-    norm_weights: Float[Array, "nb_norm_samples"],
-
-) -> float:
+    params: Params[Array],
+    vmap_axes_params: tuple[Params[int | None] | None],
+    norm_weights: Float[Array, " nb_norm_samples"],
+) -> Float[Array, " "]:
     """
     Note the squeezing on each result. We expect unidimensional *PINN since
     they represent probability distributions
@@ -92,12 +93,10 @@ def normalization_loss_apply(
             res = v_u(*batches, params)
             assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
             # Monte-Carlo integration using importance sampling
-            mse_norm_loss =
-                jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
-            )
+            mse_norm_loss = jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
         else:
             # NOTE this cartesian product is costly
-
+            batch_cart_prod = make_cartesian_product(
                 batches[0],
                 batches[1],
             ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
@@ -108,11 +107,11 @@ def normalization_loss_apply(
                 ),
                 in_axes=(0,) + vmap_axes_params,
             )
-            res = v_u(
+            res = v_u(batch_cart_prod, params)
             assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
             # For all times t, we perform an integration. Then we average the
             # losses over times.
-            mse_norm_loss =
+            mse_norm_loss = jnp.mean(
                 jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
             )
     elif isinstance(u, SPINN):
@@ -120,8 +119,7 @@ def normalization_loss_apply(
             res = u(*batches, params)
             assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             mse_norm_loss = (
-
-                * jnp.abs(
+                jnp.abs(
                     jnp.mean(
                         res.squeeze(),
                     )
@@ -141,11 +139,11 @@ def normalization_loss_apply(
             )
             assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             # the outer mean() below is for the times stamps
-            mse_norm_loss =
+            mse_norm_loss = jnp.mean(
                 jnp.abs(
                     jnp.mean(
                         res.squeeze(),
-                        axis=(d + 1 for d in range(res.ndim - 2)),
+                        axis=list(d + 1 for d in range(res.ndim - 2)),
                     )
                     * norm_weights
                     - 1
@@ -159,18 +157,33 @@ def normalization_loss_apply(
 
 
 def boundary_condition_apply(
-    u:
+    u: AbstractPINN,
     batch: PDEStatioBatch | PDENonStatioBatch,
-    params: Params
-    omega_boundary_fun:
-    omega_boundary_condition: str,
-    omega_boundary_dim:
-
-
-
+    params: Params[Array],
+    omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
+    omega_boundary_condition: str | dict[str, str],
+    omega_boundary_dim: slice | dict[str, slice],
+) -> Float[Array, " "]:
+    assert batch.border_batch is not None
     vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
-
+    def _check_tuple_of_dict(
+        val,
+    ) -> TypeGuard[
+        tuple[
+            dict[str, BoundaryConditionFun],
+            dict[str, BoundaryConditionFun],
+            dict[str, BoundaryConditionFun],
+        ]
+    ]:
+        return all(isinstance(x, dict) for x in val)
+
+    omega_boundary_dicts = (
+        omega_boundary_condition,
+        omega_boundary_fun,
+        omega_boundary_dim,
+    )
+    if _check_tuple_of_dict(omega_boundary_dicts):
         # We must create the facet tree dictionary as we do not have the
         # enumerate from the for loop to pass the id integer
         if batch.border_batch.shape[-1] == 2:
@@ -186,16 +199,13 @@ def boundary_condition_apply(
                 None
                 if c is None
                 else jnp.mean(
-
-                    * _compute_boundary_loss(
-                        c, f, batch, u, params, fa, d, vmap_in_axes
-                    )
+                    _compute_boundary_loss(c, f, batch, u, params, fa, d, vmap_in_axes)
                 )
             ),
-            omega_boundary_condition,
-            omega_boundary_fun,
+            omega_boundary_dicts[0],  # omega_boundary_condition,
+            omega_boundary_dicts[1],  # omega_boundary_fun,
             facet_tree,
-            omega_boundary_dim,
+            omega_boundary_dicts[2],  # omega_boundary_dim,
             is_leaf=lambda x: x is None,
         )  # when exploring leaves with None value (no condition) the returned
         # mse is None and we get rid of the None leaves of b_losses_by_facet
@@ -206,15 +216,14 @@ def boundary_condition_apply(
         facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
         b_losses_by_facet = jax.tree_util.tree_map(
             lambda fa: jnp.mean(
-
-
-
-                    omega_boundary_fun,
+                _compute_boundary_loss(
+                    omega_boundary_dicts[0],  # type: ignore -> need TypeIs from 3.13
+                    omega_boundary_dicts[1],  # type: ignore -> need TypeIs from 3.13
                     batch,
                     u,
                     params,
                     fa,
-
+                    omega_boundary_dicts[2],  # type: ignore -> need TypeIs from 3.13
                     vmap_in_axes,
                 )
             ),
@@ -227,26 +236,23 @@ def boundary_condition_apply(
 
 
 def observations_loss_apply(
-    u:
-
-    params: Params
-    vmap_axes: tuple[int | None
-    observed_values: Float[Array, "
-
-
-) -> float:
-    # TODO implement for SPINN
+    u: AbstractPINN,
+    batch: Float[Array, " obs_batch_size input_dim"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
+    observed_values: Float[Array, " obs_batch_size observation_dim"],
+    obs_slice: EllipsisType | slice | None,
+) -> Float[Array, " "]:
     if isinstance(u, (PINN, HyperPINN)):
         v_u = vmap(
             lambda *args: u(*args)[u.slice_solution],
             vmap_axes,
             0,
         )
-        val = v_u(
+        val = v_u(batch, params)[:, obs_slice]
         mse_observation_loss = jnp.mean(
             jnp.sum(
-
-                * _subtract_with_check(
+                _subtract_with_check(
                     observed_values, val, cause="user defined observed_values"
                 )
                 ** 2,
@@ -261,15 +267,15 @@ def observations_loss_apply(
 
 
 def initial_condition_apply(
-    u:
-    omega_batch: Float[Array, "dimension"],
-    params: Params
-    vmap_axes: tuple[int | None
+    u: AbstractPINN,
+    omega_batch: Float[Array, " dimension"],
+    params: Params[Array],
+    vmap_axes: tuple[int, Params[int | None] | None],
     initial_condition_fun: Callable,
-
-) ->
+    t0: Float[Array, " 1"],
+) -> Float[Array, " "]:
     n = omega_batch.shape[0]
-    t0_omega_batch = jnp.concatenate([jnp.
+    t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
     if isinstance(u, (PINN, HyperPINN)):
         v_u_t0 = vmap(
             lambda t0_x, params: _subtract_with_check(
@@ -285,7 +291,7 @@ def initial_condition_apply(
         # dimension as params to be able to vmap.
         # Recall that by convention:
        # param_batch_dict = times_batch_size * omega_batch_size
-        mse_initial_condition = jnp.mean(jnp.sum(
+        mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
     elif isinstance(u, SPINN):
         values = lambda t_x: u(
             t_x,
@@ -298,107 +304,7 @@ def initial_condition_apply(
             v_ini,
             cause="Output of initial_condition_fun",
         )
-        mse_initial_condition = jnp.mean(jnp.sum(
+        mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
     return mse_initial_condition
-
-
-def constraints_system_loss_apply(
-    u_constraints_dict: Dict,
-    batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
-    params_dict: ParamsDict,
-    loss_weights: Dict[str, float | Array],
-    loss_weight_struct: PyTree,
-):
-    """
-    Same function for systemlossODE and systemlossPDE!
-    """
-    # Transpose so we have each u_dict as outer structure and the
-    # associated loss_weight as inner structure
-    loss_weights_T = jax.tree_util.tree_transpose(
-        jax.tree_util.tree_structure(loss_weight_struct),
-        jax.tree_util.tree_structure(loss_weights["initial_condition"]),
-        loss_weights,
-    )
-
-    if isinstance(params_dict.nn_params, dict):
-
-        def apply_u_constraint(
-            u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
-        ):
-            res_dict_for_u = u_constraint.evaluate(
-                Params(
-                    nn_params=nn_params,
-                    eq_params=eq_params,
-                ),
-                append_obs_batch(batch, obs_batch_u),
-            )[1]
-            res_dict_ponderated = jax.tree_util.tree_map(
-                lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
-            )
-            return res_dict_ponderated
-
-        # Note in the case of multiple PINNs, batch.obs_batch_dict is a dict
-        # with keys corresponding to the PINN and value correspondinf to an
-        # original obs_batch_dict. Hence the tree mapping also interates over
-        # batch.obs_batch_dict
-        res_dict = jax.tree_util.tree_map(
-            apply_u_constraint,
-            u_constraints_dict,
-            params_dict.nn_params,
-            (
-                params_dict.eq_params
-                if params_dict.eq_params.keys() == params_dict.nn_params.keys()
-                else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
-            ),  # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
-            loss_weights_T,
-            batch.obs_batch_dict,
-            is_leaf=lambda x: (
-                not isinstance(x, dict)  # to not traverse more than the first
-                # outer dict of the pytrees passed to the function. This will
-                # work because u_constraints_dict is a dict of Losses, and it
-                # thus stops the traversing of other dict too
-            ),
-        )
-    # TODO try to get rid of this condition?
-    else:
-
-        def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
-            res_dict_for_u = u_constraint.evaluate(
-                params_dict,
-                append_obs_batch(batch, obs_batch_u),
-            )[1]
-            res_dict_ponderated = jax.tree_util.tree_map(
-                lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
-            )
-            return res_dict_ponderated
-
-        res_dict = jax.tree_util.tree_map(
-            apply_u_constraint, u_constraints_dict, loss_weights_T, batch.obs_batch_dict
-        )
-
-    # Transpose back so we have mses as outer structures and their values
-    # for each u_dict as inner structures. The tree_leaves transforms the
-    # inner structure into a list so we can catch is as leaf it the
-    # tree_map below
-    res_dict = jax.tree_util.tree_transpose(
-        jax.tree_util.tree_structure(
-            jax.tree_util.tree_leaves(loss_weights["initial_condition"])
-        ),
-        jax.tree_util.tree_structure(loss_weight_struct),
-        res_dict,
-    )
-    # For each mse, sum their values on each u_dict
-    res_dict = jax.tree_util.tree_map(
-        lambda mse: jax.tree_util.tree_reduce(
-            lambda x, y: x + y, jax.tree_util.tree_leaves(mse)
-        ),
-        res_dict,
-        is_leaf=lambda x: isinstance(x, list),
-    )
-    # Total loss
-    total_loss = jax.tree_util.tree_reduce(
-        lambda x, y: x + y, jax.tree_util.tree_leaves(res_dict)
-    )
-    return total_loss, res_dict
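The recurring edit in this file replaces the partially rendered reductions with an explicit sum-then-mean over the residual dimension (`jnp.mean(jnp.sum(residuals**2, axis=-1))`) and moves `u` and `params` to the end of the vmapped dynamic-loss call. A minimal, self-contained sketch of that pattern with plain `jax.vmap` (the toy `u`, `dyn_loss` and shapes below are illustrative assumptions, not jinns classes or the jinns API):

import jax
import jax.numpy as jnp

# Toy stand-ins: a "network" and a "dynamic loss" returning a residual per point.
def u(x, params):
    return jnp.tanh(x @ params["w"])            # (dim,) -> (out_dim,)

def dyn_loss(x, u_fn, params):                  # params placed last, as in the diff
    return u_fn(x, params) - jnp.sin(x.sum())   # residual, shape (out_dim,)

params = {"w": jnp.ones((2, 3))}
batch = jnp.linspace(0.0, 1.0, 10).reshape(5, 2)  # (batch_size, dim)

# vmap over the batch axis only; params are broadcast (in_axes=None).
residuals = jax.vmap(lambda x, p: dyn_loss(x, u, p), in_axes=(0, None))(batch, params)

# Sum over the residual dimension, then average over the batch.
mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
print(mse_dyn_loss.shape)  # () -- a scalar, matching the new Float[Array, " "] annotations

The scalar output is what the updated return annotations in the hunks above describe.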
jinns/loss/_loss_weight_updates.py
ADDED
@@ -0,0 +1,202 @@
+"""
+A collection of specific weight update schemes in jinns
+"""
+
+from __future__ import annotations
+from typing import TYPE_CHECKING
+from jaxtyping import Array, Key
+import jax.numpy as jnp
+import jax
+import equinox as eqx
+
+if TYPE_CHECKING:
+    from jinns.loss._loss_weights import AbstractLossWeights
+    from jinns.utils._types import AnyLossComponents
+
+
+def soft_adapt(
+    loss_weights: AbstractLossWeights,
+    iteration_nb: int,
+    loss_terms: AnyLossComponents,
+    stored_loss_terms: AnyLossComponents,
+) -> Array:
+    r"""
+    Implement the simple strategy given in
+    https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/theory/advanced_schemes.html#softadapt
+
+    $$
+    w_j(i)= \frac{\exp(\frac{L_j(i)}{L_j(i-1)+\epsilon}-\mu(i))}
+    {\sum_{k=1}^{n_{loss}}\exp(\frac{L_k(i)}{L_k(i-1)+\epsilon}-\mu(i))}
+    $$
+
+    Note that since None is not treated as a leaf by jax tree.util functions,
+    we naturally avoid None components from loss_terms, stored_loss_terms etc.!
+    """
+
+    def do_nothing(loss_weights, _, __):
+        return jnp.array(
+            jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+        )
+
+    def soft_adapt_(_, loss_terms, stored_loss_terms):
+        ratio_pytree = jax.tree.map(
+            lambda lt, slt: lt / (slt[iteration_nb - 1] + 1e-6),
+            loss_terms,
+            stored_loss_terms,
+        )
+        mu = jax.tree.reduce(jnp.maximum, ratio_pytree, initializer=jnp.array(-jnp.inf))
+        ratio_pytree = jax.tree.map(lambda r: r - mu, ratio_pytree)
+        ratio_leaves = jax.tree.leaves(ratio_pytree)
+        return jax.nn.softmax(jnp.array(ratio_leaves))
+
+    return jax.lax.cond(
+        iteration_nb == 0,
+        lambda op: do_nothing(*op),
+        lambda op: soft_adapt_(*op),
+        (loss_weights, loss_terms, stored_loss_terms),
+    )
+
+
+def ReLoBRaLo(
+    loss_weights: AbstractLossWeights,
+    iteration_nb: int,
+    loss_terms: AnyLossComponents,
+    stored_loss_terms: AnyLossComponents,
+    key: Key,
+    decay_factor: float = 0.9,
+    tau: float = 1,  ## referred to as temperature in the article
+    p: float = 0.9,
+):
+    r"""
+    Implementing the extension of softadapt: Relative Loss Balancing with random LookBack
+    """
+    n_loss = len(jax.tree.leaves(loss_terms))  # number of loss terms
+    epsilon = 1e-6
+
+    def do_nothing(loss_weights, _):
+        return jnp.array(
+            jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+        )
+
+    def compute_softmax_weights(current, reference):
+        ratio_pytree = jax.tree.map(
+            lambda lt, ref: lt / (ref + epsilon),
+            current,
+            reference,
+        )
+        mu = jax.tree.reduce(jnp.maximum, ratio_pytree, initializer=-jnp.inf)
+        ratio_pytree = jax.tree.map(lambda r: r - mu, ratio_pytree)
+        ratio_leaves = jax.tree.leaves(ratio_pytree)
+        return jax.nn.softmax(jnp.array(ratio_leaves))
+
+    def soft_adapt_prev(stored_loss_terms):
+        # ω_j(i-1)
+        prev_terms = jax.tree.map(lambda slt: slt[iteration_nb - 1], stored_loss_terms)
+        prev_prev_terms = jax.tree.map(
+            lambda slt: slt[iteration_nb - 2], stored_loss_terms
+        )
+        return compute_softmax_weights(prev_terms, prev_prev_terms)
+
+    def look_back(loss_terms, stored_loss_terms):
+        # ω̂_j^(i,0)
+        initial_terms = jax.tree.map(lambda slt: tau * slt[0], stored_loss_terms)
+        weights = compute_softmax_weights(loss_terms, initial_terms)
+        return n_loss * weights
+
+    def soft_adapt_current(loss_terms, stored_loss_terms):
+        # ω_j(i)
+        prev_terms = jax.tree.map(lambda slt: slt[iteration_nb - 1], stored_loss_terms)
+        return compute_softmax_weights(loss_terms, prev_terms)
+
+    # Bernoulli variable for random lookback
+    rho = jax.random.bernoulli(key, p).astype(float)
+
+    # Base case for first iteration
+    def first_iter_case(_):
+        return do_nothing(loss_weights, None)
+
+    # Case for iteration >= 1
+    def subsequent_iter_case(_):
+        # Compute historical weights
+        def hist_weights_case1(_):
+            return soft_adapt_current(loss_terms, stored_loss_terms)
+
+        def hist_weights_case2(_):
+            return rho * soft_adapt_prev(stored_loss_terms) + (1 - rho) * look_back(
+                loss_terms, stored_loss_terms
+            )
+
+        loss_weights_hist = jax.lax.cond(
+            iteration_nb < 2,
+            hist_weights_case1,
+            hist_weights_case2,
+            None,
+        )
+
+        # Compute and return final weights
+        adaptive_weights = soft_adapt_current(loss_terms, stored_loss_terms)
+        return decay_factor * loss_weights_hist + (1 - decay_factor) * adaptive_weights
+
+    return jax.lax.cond(
+        iteration_nb == 0,
+        first_iter_case,
+        subsequent_iter_case,
+        None,
+    )
+
+
+def lr_annealing(
+    loss_weights: AbstractLossWeights,
+    grad_terms: AnyLossComponents,
+    decay_factor: float = 0.9,  # 0.9 is the recommended value from the article
+    eps: float = 1e-6,
+) -> Array:
+    r"""
+    Implementation of the Learning rate annealing
+    Algorithm 1 in the paper UNDERSTANDING AND MITIGATING GRADIENT PATHOLOGIES IN PHYSICS-INFORMED NEURAL NETWORKS
+
+    (a) Compute $\hat{\lambda}_i$ by
+    $$
+    \hat{\lambda}_i = \frac{\max_{\theta}\{|\nabla_\theta \mathcal{L}_r (\theta_n)|\}}{mean(|\nabla_\theta \mathcal{L}_i (\theta_n)|)}, \quad i=1,\dots, M,
+    $$
+
+    (b) Update the weights $\lambda_i$ using a moving average of the form
+    $$
+    \lambda_i = (1-\alpha) \lambda_{i-1} + \alpha \hat{\lambda}_i, \quad i=1, \dots, M.
+    $$
+
+    Note that since None is not treated as a leaf by jax tree.util functions,
+    we naturally avoid None components from loss_terms, stored_loss_terms etc.!
+
+    """
+    assert hasattr(grad_terms, "dyn_loss")
+    dyn_loss_grads = getattr(grad_terms, "dyn_loss")
+    data_fit_grads = [
+        getattr(grad_terms, att) if hasattr(grad_terms, att) else None
+        for att in ["norm_loss", "boundary_loss", "observations", "initial_condition"]
+    ]
+
+    dyn_loss_grads_leaves = jax.tree.leaves(
+        dyn_loss_grads,
+        is_leaf=eqx.is_inexact_array,
+    )
+
+    max_dyn_loss_grads = jnp.max(
+        jnp.stack([jnp.max(jnp.abs(g)) for g in dyn_loss_grads_leaves])
+    )
+
+    mean_gradients = [
+        jnp.mean(jnp.stack([jnp.abs(jnp.mean(g)) for g in jax.tree.leaves(t)]))
+        for t in data_fit_grads
+        if t is not None and jax.tree.leaves(t)
+    ]
+
+    lambda_hat = max_dyn_loss_grads / (jnp.array(mean_gradients) + eps)
+    old_weights = jnp.array(
+        jax.tree.leaves(
+            loss_weights,
+        )
+    )
+
+    new_weights = (1 - decay_factor) * old_weights[1:] + decay_factor * lambda_hat
+    return jnp.hstack([old_weights[0], new_weights])