jinns 1.3.0-py3-none-any.whl → 1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_utils.py CHANGED
@@ -6,42 +6,42 @@ from __future__ import (
  annotations,
  ) # https://docs.python.org/3/library/typing.html#constant

- from typing import TYPE_CHECKING, Callable, Dict
+ from typing import TYPE_CHECKING, Callable, TypeGuard
+ from types import EllipsisType
  import jax
  import jax.numpy as jnp
  from jax import vmap
- import equinox as eqx
- from jaxtyping import Float, Array, PyTree
+ from jaxtyping import Float, Array

  from jinns.loss._boundary_conditions import (
  _compute_boundary_loss,
  )
  from jinns.utils._utils import _subtract_with_check, get_grid
- from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
+ from jinns.data._utils import make_cartesian_product
  from jinns.parameters._params import _get_vmap_in_axes_params
  from jinns.nn._pinn import PINN
  from jinns.nn._spinn import SPINN
  from jinns.nn._hyperpinn import HyperPINN
- from jinns.data._Batchs import *
- from jinns.parameters._params import Params, ParamsDict
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
+ from jinns.parameters._params import Params

  if TYPE_CHECKING:
- from jinns.utils._types import *
+ from jinns.utils._types import BoundaryConditionFun
+ from jinns.nn._abstract_pinn import AbstractPINN


  def dynamic_loss_apply(
- dyn_loss: DynamicLoss,
- u: eqx.Module,
+ dyn_loss: Callable,
+ u: AbstractPINN,
  batch: (
- Float[Array, "batch_size 1"]
- | Float[Array, "batch_size dim"]
- | Float[Array, "batch_size 1+dim"]
+ Float[Array, " batch_size 1"]
+ | Float[Array, " batch_size dim"]
+ | Float[Array, " batch_size 1+dim"]
  ),
- params: Params | ParamsDict,
- vmap_axes: tuple[int | None, ...],
- loss_weight: float | Float[Array, "dyn_loss_dimension"],
+ params: Params[Array],
+ vmap_axes: tuple[int, Params[int | None] | None],
  u_type: PINN | HyperPINN | None = None,
- ) -> float:
+ ) -> Float[Array, " "]:
  """
  Sometimes when u is a lambda function a or dict we do not have access to
  its type here, hence the last argument
@@ -49,16 +49,18 @@ def dynamic_loss_apply(
  if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
  v_dyn_loss = vmap(
  lambda batch, params: dyn_loss(
- batch, u, params # we must place the params at the end
+ batch,
+ u,
+ params, # we must place the params at the end
  ),
  vmap_axes,
  0,
  )
  residuals = v_dyn_loss(batch, params)
- mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
+ mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
  elif u_type == SPINN or isinstance(u, SPINN):
  residuals = dyn_loss(batch, u, params)
- mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
+ mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
  else:
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")

@@ -66,18 +68,17 @@


  def normalization_loss_apply(
- u: eqx.Module,
+ u: AbstractPINN,
  batches: (
- tuple[Float[Array, "nb_norm_samples dim"]]
+ tuple[Float[Array, " nb_norm_samples dim"]]
  | tuple[
- Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+ Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
  ]
  ),
- params: Params | ParamsDict,
- vmap_axes_params: tuple[int | None, ...],
- norm_weights: Float[Array, "nb_norm_samples"],
- loss_weight: float,
- ) -> float:
+ params: Params[Array],
+ vmap_axes_params: tuple[Params[int | None] | None],
+ norm_weights: Float[Array, " nb_norm_samples"],
+ ) -> Float[Array, " "]:
  """
  Note the squeezing on each result. We expect unidimensional *PINN since
  they represent probability distributions
@@ -92,12 +93,10 @@
  res = v_u(*batches, params)
  assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
  # Monte-Carlo integration using importance sampling
- mse_norm_loss = loss_weight * (
- jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
- )
+ mse_norm_loss = jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
  else:
  # NOTE this cartesian product is costly
- batches = make_cartesian_product(
+ batch_cart_prod = make_cartesian_product(
  batches[0],
  batches[1],
  ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
@@ -108,11 +107,11 @@
  ),
  in_axes=(0,) + vmap_axes_params,
  )
- res = v_u(batches, params)
+ res = v_u(batch_cart_prod, params)
  assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
  # For all times t, we perform an integration. Then we average the
  # losses over times.
- mse_norm_loss = loss_weight * jnp.mean(
+ mse_norm_loss = jnp.mean(
  jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
  )
  elif isinstance(u, SPINN):
@@ -120,8 +119,7 @@
  res = u(*batches, params)
  assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
  mse_norm_loss = (
- loss_weight
- * jnp.abs(
+ jnp.abs(
  jnp.mean(
  res.squeeze(),
  )
@@ -141,11 +139,11 @@
  )
  assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
  # the outer mean() below is for the times stamps
- mse_norm_loss = loss_weight * jnp.mean(
+ mse_norm_loss = jnp.mean(
  jnp.abs(
  jnp.mean(
  res.squeeze(),
- axis=(d + 1 for d in range(res.ndim - 2)),
+ axis=list(d + 1 for d in range(res.ndim - 2)),
  )
  * norm_weights
  - 1
@@ -159,18 +157,33 @@


  def boundary_condition_apply(
- u: eqx.Module,
+ u: AbstractPINN,
  batch: PDEStatioBatch | PDENonStatioBatch,
- params: Params | ParamsDict,
- omega_boundary_fun: Callable,
- omega_boundary_condition: str,
- omega_boundary_dim: int,
- loss_weight: float | Float[Array, "boundary_cond_dim"],
- ) -> float:
-
+ params: Params[Array],
+ omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
+ omega_boundary_condition: str | dict[str, str],
+ omega_boundary_dim: slice | dict[str, slice],
+ ) -> Float[Array, " "]:
+ assert batch.border_batch is not None
  vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)

- if isinstance(omega_boundary_fun, dict):
+ def _check_tuple_of_dict(
+ val,
+ ) -> TypeGuard[
+ tuple[
+ dict[str, BoundaryConditionFun],
+ dict[str, BoundaryConditionFun],
+ dict[str, BoundaryConditionFun],
+ ]
+ ]:
+ return all(isinstance(x, dict) for x in val)
+
+ omega_boundary_dicts = (
+ omega_boundary_condition,
+ omega_boundary_fun,
+ omega_boundary_dim,
+ )
+ if _check_tuple_of_dict(omega_boundary_dicts):
  # We must create the facet tree dictionary as we do not have the
  # enumerate from the for loop to pass the id integer
  if batch.border_batch.shape[-1] == 2:
@@ -186,16 +199,13 @@
  None
  if c is None
  else jnp.mean(
- loss_weight
- * _compute_boundary_loss(
- c, f, batch, u, params, fa, d, vmap_in_axes
- )
+ _compute_boundary_loss(c, f, batch, u, params, fa, d, vmap_in_axes)
  )
  ),
- omega_boundary_condition,
- omega_boundary_fun,
+ omega_boundary_dicts[0], # omega_boundary_condition,
+ omega_boundary_dicts[1], # omega_boundary_fun,
  facet_tree,
- omega_boundary_dim,
+ omega_boundary_dicts[2], # omega_boundary_dim,
  is_leaf=lambda x: x is None,
  ) # when exploring leaves with None value (no condition) the returned
  # mse is None and we get rid of the None leaves of b_losses_by_facet
@@ -206,15 +216,14 @@
  facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
  b_losses_by_facet = jax.tree_util.tree_map(
  lambda fa: jnp.mean(
- loss_weight
- * _compute_boundary_loss(
- omega_boundary_condition,
- omega_boundary_fun,
+ _compute_boundary_loss(
+ omega_boundary_dicts[0], # type: ignore -> need TypeIs from 3.13
+ omega_boundary_dicts[1], # type: ignore -> need TypeIs from 3.13
  batch,
  u,
  params,
  fa,
- omega_boundary_dim,
+ omega_boundary_dicts[2], # type: ignore -> need TypeIs from 3.13
  vmap_in_axes,
  )
  ),
@@ -227,26 +236,23 @@


  def observations_loss_apply(
- u: eqx.Module,
- batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
- params: Params | ParamsDict,
- vmap_axes: tuple[int | None, ...],
- observed_values: Float[Array, "batch_size observation_dim"],
- loss_weight: float | Float[Array, "observation_dim"],
- obs_slice: slice,
- ) -> float:
- # TODO implement for SPINN
+ u: AbstractPINN,
+ batch: Float[Array, " obs_batch_size input_dim"],
+ params: Params[Array],
+ vmap_axes: tuple[int, Params[int | None] | None],
+ observed_values: Float[Array, " obs_batch_size observation_dim"],
+ obs_slice: EllipsisType | slice | None,
+ ) -> Float[Array, " "]:
  if isinstance(u, (PINN, HyperPINN)):
  v_u = vmap(
  lambda *args: u(*args)[u.slice_solution],
  vmap_axes,
  0,
  )
- val = v_u(*batches, params)[:, obs_slice]
+ val = v_u(batch, params)[:, obs_slice]
  mse_observation_loss = jnp.mean(
  jnp.sum(
- loss_weight
- * _subtract_with_check(
+ _subtract_with_check(
  observed_values, val, cause="user defined observed_values"
  )
  ** 2,
@@ -261,15 +267,15 @@


  def initial_condition_apply(
- u: eqx.Module,
- omega_batch: Float[Array, "dimension"],
- params: Params | ParamsDict,
- vmap_axes: tuple[int | None, ...],
+ u: AbstractPINN,
+ omega_batch: Float[Array, " dimension"],
+ params: Params[Array],
+ vmap_axes: tuple[int, Params[int | None] | None],
  initial_condition_fun: Callable,
- loss_weight: float | Float[Array, "initial_condition_dimension"],
- ) -> float:
+ t0: Float[Array, " 1"],
+ ) -> Float[Array, " "]:
  n = omega_batch.shape[0]
- t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
+ t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
  if isinstance(u, (PINN, HyperPINN)):
  v_u_t0 = vmap(
  lambda t0_x, params: _subtract_with_check(
@@ -285,7 +291,7 @@
  # dimension as params to be able to vmap.
  # Recall that by convention:
  # param_batch_dict = times_batch_size * omega_batch_size
- mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
+ mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
  elif isinstance(u, SPINN):
  values = lambda t_x: u(
  t_x,
@@ -298,107 +304,7 @@
  v_ini,
  cause="Output of initial_condition_fun",
  )
- mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
+ mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
  else:
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
  return mse_initial_condition
-
-
- def constraints_system_loss_apply(
- u_constraints_dict: Dict,
- batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
- params_dict: ParamsDict,
- loss_weights: Dict[str, float | Array],
- loss_weight_struct: PyTree,
- ):
- """
- Same function for systemlossODE and systemlossPDE!
- """
- # Transpose so we have each u_dict as outer structure and the
- # associated loss_weight as inner structure
- loss_weights_T = jax.tree_util.tree_transpose(
- jax.tree_util.tree_structure(loss_weight_struct),
- jax.tree_util.tree_structure(loss_weights["initial_condition"]),
- loss_weights,
- )
-
- if isinstance(params_dict.nn_params, dict):
-
- def apply_u_constraint(
- u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
- ):
- res_dict_for_u = u_constraint.evaluate(
- Params(
- nn_params=nn_params,
- eq_params=eq_params,
- ),
- append_obs_batch(batch, obs_batch_u),
- )[1]
- res_dict_ponderated = jax.tree_util.tree_map(
- lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
- )
- return res_dict_ponderated
-
- # Note in the case of multiple PINNs, batch.obs_batch_dict is a dict
- # with keys corresponding to the PINN and value correspondinf to an
- # original obs_batch_dict. Hence the tree mapping also interates over
- # batch.obs_batch_dict
- res_dict = jax.tree_util.tree_map(
- apply_u_constraint,
- u_constraints_dict,
- params_dict.nn_params,
- (
- params_dict.eq_params
- if params_dict.eq_params.keys() == params_dict.nn_params.keys()
- else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
- ), # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
- loss_weights_T,
- batch.obs_batch_dict,
- is_leaf=lambda x: (
- not isinstance(x, dict) # to not traverse more than the first
- # outer dict of the pytrees passed to the function. This will
- # work because u_constraints_dict is a dict of Losses, and it
- # thus stops the traversing of other dict too
- ),
- )
- # TODO try to get rid of this condition?
- else:
-
- def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
- res_dict_for_u = u_constraint.evaluate(
- params_dict,
- append_obs_batch(batch, obs_batch_u),
- )[1]
- res_dict_ponderated = jax.tree_util.tree_map(
- lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
- )
- return res_dict_ponderated
-
- res_dict = jax.tree_util.tree_map(
- apply_u_constraint, u_constraints_dict, loss_weights_T, batch.obs_batch_dict
- )
-
- # Transpose back so we have mses as outer structures and their values
- # for each u_dict as inner structures. The tree_leaves transforms the
- # inner structure into a list so we can catch is as leaf it the
- # tree_map below
- res_dict = jax.tree_util.tree_transpose(
- jax.tree_util.tree_structure(
- jax.tree_util.tree_leaves(loss_weights["initial_condition"])
- ),
- jax.tree_util.tree_structure(loss_weight_struct),
- res_dict,
- )
- # For each mse, sum their values on each u_dict
- res_dict = jax.tree_util.tree_map(
- lambda mse: jax.tree_util.tree_reduce(
- lambda x, y: x + y, jax.tree_util.tree_leaves(mse)
- ),
- res_dict,
- is_leaf=lambda x: isinstance(x, list),
- )
- # Total loss
- total_loss = jax.tree_util.tree_reduce(
- lambda x, y: x + y, jax.tree_util.tree_leaves(res_dict)
- )
- return total_loss, res_dict
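The change running through all of these helpers is that the per-term loss_weight arguments and multiplications are gone: each helper now returns a bare scalar MSE, and the weighting appears to move to the loss-weight machinery (see the new jinns/loss/_loss_weight_updates.py next). Below is a minimal standalone sketch, not the jinns API, of the pattern dynamic_loss_apply now follows for PINN/HyperPINN inputs; the residual function and parameter names are hypothetical.

import jax
import jax.numpy as jnp

def toy_residual(t_x, params):
    # hypothetical stand-in for a dynamic-loss residual at one collocation point
    return params["a"] * t_x - 1.0

def mse_dyn_loss(batch, params):
    # vmap the residual over the batch, then reduce to an unweighted scalar MSE,
    # mirroring the new jnp.mean(jnp.sum(residuals**2, axis=-1)) reduction above
    residuals = jax.vmap(lambda b: toy_residual(b, params))(batch)
    return jnp.mean(jnp.sum(residuals**2, axis=-1))

batch = jnp.ones((8, 2))            # (batch_size, 1 + dim) collocation points
params = {"a": jnp.array(3.0)}
print(mse_dyn_loss(batch, params))  # scalar, as the -> Float[Array, " "] annotations suggest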
jinns/loss/_loss_weight_updates.py ADDED
@@ -0,0 +1,202 @@
+ """
+ A collection of specific weight update schemes in jinns
+ """
+
+ from __future__ import annotations
+ from typing import TYPE_CHECKING
+ from jaxtyping import Array, Key
+ import jax.numpy as jnp
+ import jax
+ import equinox as eqx
+
+ if TYPE_CHECKING:
+ from jinns.loss._loss_weights import AbstractLossWeights
+ from jinns.utils._types import AnyLossComponents
+
+
+ def soft_adapt(
+ loss_weights: AbstractLossWeights,
+ iteration_nb: int,
+ loss_terms: AnyLossComponents,
+ stored_loss_terms: AnyLossComponents,
+ ) -> Array:
+ r"""
+ Implement the simple strategy given in
+ https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/theory/advanced_schemes.html#softadapt
+
+ $$
+ w_j(i)= \frac{\exp(\frac{L_j(i)}{L_j(i-1)+\epsilon}-\mu(i))}
+ {\sum_{k=1}^{n_{loss}}\exp(\frac{L_k(i)}{L_k(i-1)+\epsilon}-\mu(i))}
+ $$
+
+ Note that since None is not treated as a leaf by jax tree.util functions,
+ we naturally avoid None components from loss_terms, stored_loss_terms etc.!
+ """
+
+ def do_nothing(loss_weights, _, __):
+ return jnp.array(
+ jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+ )
+
+ def soft_adapt_(_, loss_terms, stored_loss_terms):
+ ratio_pytree = jax.tree.map(
+ lambda lt, slt: lt / (slt[iteration_nb - 1] + 1e-6),
+ loss_terms,
+ stored_loss_terms,
+ )
+ mu = jax.tree.reduce(jnp.maximum, ratio_pytree, initializer=jnp.array(-jnp.inf))
+ ratio_pytree = jax.tree.map(lambda r: r - mu, ratio_pytree)
+ ratio_leaves = jax.tree.leaves(ratio_pytree)
+ return jax.nn.softmax(jnp.array(ratio_leaves))
+
+ return jax.lax.cond(
+ iteration_nb == 0,
+ lambda op: do_nothing(*op),
+ lambda op: soft_adapt_(*op),
+ (loss_weights, loss_terms, stored_loss_terms),
+ )
+
+
+ def ReLoBRaLo(
+ loss_weights: AbstractLossWeights,
+ iteration_nb: int,
+ loss_terms: AnyLossComponents,
+ stored_loss_terms: AnyLossComponents,
+ key: Key,
+ decay_factor: float = 0.9,
+ tau: float = 1, ## referred to as temperature in the article
+ p: float = 0.9,
+ ):
+ r"""
+ Implementing the extension of softadapt: Relative Loss Balancing with random LookBack
+ """
+ n_loss = len(jax.tree.leaves(loss_terms)) # number of loss terms
+ epsilon = 1e-6
+
+ def do_nothing(loss_weights, _):
+ return jnp.array(
+ jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+ )
+
+ def compute_softmax_weights(current, reference):
+ ratio_pytree = jax.tree.map(
+ lambda lt, ref: lt / (ref + epsilon),
+ current,
+ reference,
+ )
+ mu = jax.tree.reduce(jnp.maximum, ratio_pytree, initializer=-jnp.inf)
+ ratio_pytree = jax.tree.map(lambda r: r - mu, ratio_pytree)
+ ratio_leaves = jax.tree.leaves(ratio_pytree)
+ return jax.nn.softmax(jnp.array(ratio_leaves))
+
+ def soft_adapt_prev(stored_loss_terms):
+ # ω_j(i-1)
+ prev_terms = jax.tree.map(lambda slt: slt[iteration_nb - 1], stored_loss_terms)
+ prev_prev_terms = jax.tree.map(
+ lambda slt: slt[iteration_nb - 2], stored_loss_terms
+ )
+ return compute_softmax_weights(prev_terms, prev_prev_terms)
+
+ def look_back(loss_terms, stored_loss_terms):
+ # ω̂_j^(i,0)
+ initial_terms = jax.tree.map(lambda slt: tau * slt[0], stored_loss_terms)
+ weights = compute_softmax_weights(loss_terms, initial_terms)
+ return n_loss * weights
+
+ def soft_adapt_current(loss_terms, stored_loss_terms):
+ # ω_j(i)
+ prev_terms = jax.tree.map(lambda slt: slt[iteration_nb - 1], stored_loss_terms)
+ return compute_softmax_weights(loss_terms, prev_terms)
+
+ # Bernoulli variable for random lookback
+ rho = jax.random.bernoulli(key, p).astype(float)
+
+ # Base case for first iteration
+ def first_iter_case(_):
+ return do_nothing(loss_weights, None)
+
+ # Case for iteration >= 1
+ def subsequent_iter_case(_):
+ # Compute historical weights
+ def hist_weights_case1(_):
+ return soft_adapt_current(loss_terms, stored_loss_terms)
+
+ def hist_weights_case2(_):
+ return rho * soft_adapt_prev(stored_loss_terms) + (1 - rho) * look_back(
+ loss_terms, stored_loss_terms
+ )
+
+ loss_weights_hist = jax.lax.cond(
+ iteration_nb < 2,
+ hist_weights_case1,
+ hist_weights_case2,
+ None,
+ )
+
+ # Compute and return final weights
+ adaptive_weights = soft_adapt_current(loss_terms, stored_loss_terms)
+ return decay_factor * loss_weights_hist + (1 - decay_factor) * adaptive_weights
+
+ return jax.lax.cond(
+ iteration_nb == 0,
+ first_iter_case,
+ subsequent_iter_case,
+ None,
+ )
+
+
+ def lr_annealing(
+ loss_weights: AbstractLossWeights,
+ grad_terms: AnyLossComponents,
+ decay_factor: float = 0.9, # 0.9 is the recommended value from the article
+ eps: float = 1e-6,
+ ) -> Array:
+ r"""
+ Implementation of the Learning rate annealing
+ Algorithm 1 in the paper UNDERSTANDING AND MITIGATING GRADIENT PATHOLOGIES IN PHYSICS-INFORMED NEURAL NETWORKS
+
+ (a) Compute $\hat{\lambda}_i$ by
+ $$
+ \hat{\lambda}_i = \frac{\max_{\theta}\{|\nabla_\theta \mathcal{L}_r (\theta_n)|\}}{mean(|\nabla_\theta \mathcal{L}_i (\theta_n)|)}, \quad i=1,\dots, M,
+ $$
+
+ (b) Update the weights $\lambda_i$ using a moving average of the form
+ $$
+ \lambda_i = (1-\alpha) \lambda_{i-1} + \alpha \hat{\lambda}_i, \quad i=1, \dots, M.
+ $$
+
+ Note that since None is not treated as a leaf by jax tree.util functions,
+ we naturally avoid None components from loss_terms, stored_loss_terms etc.!
+
+ """
+ assert hasattr(grad_terms, "dyn_loss")
+ dyn_loss_grads = getattr(grad_terms, "dyn_loss")
+ data_fit_grads = [
+ getattr(grad_terms, att) if hasattr(grad_terms, att) else None
+ for att in ["norm_loss", "boundary_loss", "observations", "initial_condition"]
+ ]
+
+ dyn_loss_grads_leaves = jax.tree.leaves(
+ dyn_loss_grads,
+ is_leaf=eqx.is_inexact_array,
+ )
+
+ max_dyn_loss_grads = jnp.max(
+ jnp.stack([jnp.max(jnp.abs(g)) for g in dyn_loss_grads_leaves])
+ )
+
+ mean_gradients = [
+ jnp.mean(jnp.stack([jnp.abs(jnp.mean(g)) for g in jax.tree.leaves(t)]))
+ for t in data_fit_grads
+ if t is not None and jax.tree.leaves(t)
+ ]
+
+ lambda_hat = max_dyn_loss_grads / (jnp.array(mean_gradients) + eps)
+ old_weights = jnp.array(
+ jax.tree.leaves(
+ loss_weights,
+ )
+ )
+
+ new_weights = (1 - decay_factor) * old_weights[1:] + decay_factor * lambda_hat
+ return jnp.hstack([old_weights[0], new_weights])
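For orientation, here is a tiny standalone sketch of the SoftAdapt rule that soft_adapt implements above, with toy numbers rather than the jinns loss-weight containers: the updated weights are a softmax over the ratios of each loss term to its previous value, shifted by the largest ratio for numerical stability (the shift leaves the softmax unchanged).

import jax
import jax.numpy as jnp

def softadapt_weights(current_losses, previous_losses, eps=1e-6):
    # ratio L_j(i) / (L_j(i-1) + eps), then a max-shifted softmax as in soft_adapt_
    ratios = jnp.array(current_losses) / (jnp.array(previous_losses) + eps)
    return jax.nn.softmax(ratios - jnp.max(ratios))

# loss terms at iterations i and i-1, e.g. (dyn_loss, boundary_loss, observations)
w = softadapt_weights([0.5, 0.1, 0.2], [0.4, 0.2, 0.2])
print(w)  # sums to 1; the term whose loss decayed least (or grew) gets the largest weight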