jinns 1.4.0-py3-none-any.whl → 1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_loss_utils.py CHANGED
@@ -40,7 +40,6 @@ def dynamic_loss_apply(
  ),
  params: Params[Array],
  vmap_axes: tuple[int, Params[int | None] | None],
- loss_weight: float | Float[Array, " dyn_loss_dimension"],
  u_type: PINN | HyperPINN | None = None,
  ) -> Float[Array, " "]:
  """
@@ -58,10 +57,10 @@ def dynamic_loss_apply(
  0,
  )
  residuals = v_dyn_loss(batch, params)
- mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
+ mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
  elif u_type == SPINN or isinstance(u, SPINN):
  residuals = dyn_loss(batch, u, params)
- mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
+ mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
  else:
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
 
@@ -79,7 +78,6 @@ def normalization_loss_apply(
  params: Params[Array],
  vmap_axes_params: tuple[Params[int | None] | None],
  norm_weights: Float[Array, " nb_norm_samples"],
- loss_weight: float,
  ) -> Float[Array, " "]:
  """
  Note the squeezing on each result. We expect unidimensional *PINN since
@@ -95,9 +93,7 @@ def normalization_loss_apply(
  res = v_u(*batches, params)
  assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
  # Monte-Carlo integration using importance sampling
- mse_norm_loss = loss_weight * (
- jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
- )
+ mse_norm_loss = jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
  else:
  # NOTE this cartesian product is costly
  batch_cart_prod = make_cartesian_product(
@@ -115,7 +111,7 @@ def normalization_loss_apply(
  assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
  # For all times t, we perform an integration. Then we average the
  # losses over times.
- mse_norm_loss = loss_weight * jnp.mean(
+ mse_norm_loss = jnp.mean(
  jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
  )
  elif isinstance(u, SPINN):
@@ -123,8 +119,7 @@ def normalization_loss_apply(
  res = u(*batches, params)
  assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
  mse_norm_loss = (
- loss_weight
- * jnp.abs(
+ jnp.abs(
  jnp.mean(
  res.squeeze(),
  )
@@ -144,7 +139,7 @@ def normalization_loss_apply(
  )
  assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
  # the outer mean() below is for the time stamps
- mse_norm_loss = loss_weight * jnp.mean(
+ mse_norm_loss = jnp.mean(
  jnp.abs(
  jnp.mean(
  res.squeeze(),
@@ -168,7 +163,6 @@ def boundary_condition_apply(
  omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
  omega_boundary_condition: str | dict[str, str],
  omega_boundary_dim: slice | dict[str, slice],
- loss_weight: float | Float[Array, " boundary_cond_dim"],
  ) -> Float[Array, " "]:
  assert batch.border_batch is not None
  vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
@@ -205,10 +199,7 @@ def boundary_condition_apply(
  None
  if c is None
  else jnp.mean(
- loss_weight
- * _compute_boundary_loss(
- c, f, batch, u, params, fa, d, vmap_in_axes
- )
+ _compute_boundary_loss(c, f, batch, u, params, fa, d, vmap_in_axes)
  )
  ),
  omega_boundary_dicts[0], # omega_boundary_condition,
@@ -225,8 +216,7 @@ def boundary_condition_apply(
  facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
  b_losses_by_facet = jax.tree_util.tree_map(
  lambda fa: jnp.mean(
- loss_weight
- * _compute_boundary_loss(
+ _compute_boundary_loss(
  omega_boundary_dicts[0], # type: ignore -> need TypeIs from 3.13
  omega_boundary_dicts[1], # type: ignore -> need TypeIs from 3.13
  batch,
@@ -251,7 +241,6 @@ def observations_loss_apply(
  params: Params[Array],
  vmap_axes: tuple[int, Params[int | None] | None],
  observed_values: Float[Array, " obs_batch_size observation_dim"],
- loss_weight: float | Float[Array, " observation_dim"],
  obs_slice: EllipsisType | slice | None,
  ) -> Float[Array, " "]:
  if isinstance(u, (PINN, HyperPINN)):
@@ -263,8 +252,7 @@ def observations_loss_apply(
  val = v_u(batch, params)[:, obs_slice]
  mse_observation_loss = jnp.mean(
  jnp.sum(
- loss_weight
- * _subtract_with_check(
+ _subtract_with_check(
  observed_values, val, cause="user defined observed_values"
  )
  ** 2,
@@ -285,7 +273,6 @@ def initial_condition_apply(
  vmap_axes: tuple[int, Params[int | None] | None],
  initial_condition_fun: Callable,
  t0: Float[Array, " 1"],
- loss_weight: float | Float[Array, " initial_condition_dimension"],
  ) -> Float[Array, " "]:
  n = omega_batch.shape[0]
  t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
@@ -304,7 +291,7 @@ def initial_condition_apply(
  # dimension as params to be able to vmap.
  # Recall that by convention:
  # param_batch_dict = times_batch_size * omega_batch_size
- mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
+ mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
  elif isinstance(u, SPINN):
  values = lambda t_x: u(
  t_x,
@@ -317,7 +304,30 @@ def initial_condition_apply(
  v_ini,
  cause="Output of initial_condition_fun",
  )
- mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
+ mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
  else:
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
  return mse_initial_condition
+
+
+ def initial_condition_check(x, dim_size=None):
+ """
+ Make a (dim_size,) jnp array from an int, a float or a 0-D jnp array.
+ """
+ if isinstance(x, Array):
+ if not x.shape: # e.g. user input: jnp.array(0.)
+ x = jnp.array([x])
+ if dim_size is not None: # we check for the required dim_size
+ if x.shape != (dim_size,):
+ raise ValueError(
+ f"Wrong dim_size. It should be ({dim_size},). Got shape: {x.shape}"
+ )
+
+ elif isinstance(x, float): # e.g. user input: 0.
+ x = jnp.array([x])
+ elif isinstance(x, int): # e.g. user input: 0
+ x = jnp.array([float(x)])
+ else:
+ raise ValueError(f"Wrong value, expected Array, float or int, got {type(x)}")
+ return x
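
A minimal usage sketch of the new initial_condition_check helper (illustrative calls, not part of the package content): scalars and 0-D arrays are promoted to shape-(1,) arrays, and already well-shaped arrays are validated against dim_size.

import jax.numpy as jnp
from jinns.loss._loss_utils import initial_condition_check

initial_condition_check(0)                # int -> Array([0.])
initial_condition_check(0.5)              # float -> Array([0.5])
initial_condition_check(jnp.array(2.0))   # 0-D array -> Array([2.])
initial_condition_check(jnp.array([1.0, 2.0]), dim_size=2)  # shape check passes, returned unchanged
initial_condition_check(jnp.array([1.0, 2.0]), dim_size=3)  # raises ValueError (wrong dim_size)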
@@ -0,0 +1,202 @@
+ """
+ A collection of specific weight update schemes in jinns
+ """
+
+ from __future__ import annotations
+ from typing import TYPE_CHECKING
+ from jaxtyping import Array, Key
+ import jax.numpy as jnp
+ import jax
+ import equinox as eqx
+
+ if TYPE_CHECKING:
+ from jinns.loss._loss_weights import AbstractLossWeights
+ from jinns.utils._types import AnyLossComponents
+
+
+ def soft_adapt(
+ loss_weights: AbstractLossWeights,
+ iteration_nb: int,
+ loss_terms: AnyLossComponents,
+ stored_loss_terms: AnyLossComponents,
+ ) -> Array:
+ r"""
+ Implement the simple strategy given in
+ https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/theory/advanced_schemes.html#softadapt
+
+ $$
+ w_j(i) = \frac{\exp\left(\frac{L_j(i)}{L_j(i-1)+\epsilon}-\mu(i)\right)}
+ {\sum_{k=1}^{n_{loss}}\exp\left(\frac{L_k(i)}{L_k(i-1)+\epsilon}-\mu(i)\right)}
+ $$
+
+ Note that since None is not treated as a leaf by the jax.tree utilities,
+ None components of loss_terms, stored_loss_terms, etc. are naturally skipped.
+ """
+
+ def do_nothing(loss_weights, _, __):
+ return jnp.array(
+ jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+ )
+
+ def soft_adapt_(_, loss_terms, stored_loss_terms):
+ ratio_pytree = jax.tree.map(
+ lambda lt, slt: lt / (slt[iteration_nb - 1] + 1e-6),
+ loss_terms,
+ stored_loss_terms,
+ )
+ mu = jax.tree.reduce(jnp.maximum, ratio_pytree, initializer=jnp.array(-jnp.inf))
+ ratio_pytree = jax.tree.map(lambda r: r - mu, ratio_pytree)
+ ratio_leaves = jax.tree.leaves(ratio_pytree)
+ return jax.nn.softmax(jnp.array(ratio_leaves))
+
+ return jax.lax.cond(
+ iteration_nb == 0,
+ lambda op: do_nothing(*op),
+ lambda op: soft_adapt_(*op),
+ (loss_weights, loss_terms, stored_loss_terms),
+ )
+
+
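A small numerical sketch of the SoftAdapt update implemented above (illustrative loss values; the pytree plumbing of soft_adapt is omitted):

import jax
import jax.numpy as jnp

current = jnp.array([0.8, 0.1])    # L_j(i): current value of each loss term
previous = jnp.array([1.0, 0.4])   # L_j(i-1): value at the previous iteration

ratios = current / (previous + 1e-6)   # relative progress of each term
ratios = ratios - jnp.max(ratios)      # subtract mu(i) for numerical stability
weights = jax.nn.softmax(ratios)       # new loss weights, approx. [0.63, 0.37]
# the term that decreases the least (here the first one) receives the largest weight
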
+ def ReLoBRaLo(
+ loss_weights: AbstractLossWeights,
+ iteration_nb: int,
+ loss_terms: AnyLossComponents,
+ stored_loss_terms: AnyLossComponents,
+ key: Key,
+ decay_factor: float = 0.9,
+ tau: float = 1, # referred to as the temperature in the article
+ p: float = 0.9,
+ ):
+ r"""
+ Implement the extension of SoftAdapt: Relative Loss Balancing with Random Lookback (ReLoBRaLo)
+ """
+ n_loss = len(jax.tree.leaves(loss_terms)) # number of loss terms
+ epsilon = 1e-6
+
+ def do_nothing(loss_weights, _):
+ return jnp.array(
+ jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
+ )
+
+ def compute_softmax_weights(current, reference):
+ ratio_pytree = jax.tree.map(
+ lambda lt, ref: lt / (ref + epsilon),
+ current,
+ reference,
+ )
+ mu = jax.tree.reduce(jnp.maximum, ratio_pytree, initializer=-jnp.inf)
+ ratio_pytree = jax.tree.map(lambda r: r - mu, ratio_pytree)
+ ratio_leaves = jax.tree.leaves(ratio_pytree)
+ return jax.nn.softmax(jnp.array(ratio_leaves))
+
+ def soft_adapt_prev(stored_loss_terms):
+ # ω_j(i-1)
+ prev_terms = jax.tree.map(lambda slt: slt[iteration_nb - 1], stored_loss_terms)
+ prev_prev_terms = jax.tree.map(
+ lambda slt: slt[iteration_nb - 2], stored_loss_terms
+ )
+ return compute_softmax_weights(prev_terms, prev_prev_terms)
+
+ def look_back(loss_terms, stored_loss_terms):
+ # ω̂_j^(i,0)
+ initial_terms = jax.tree.map(lambda slt: tau * slt[0], stored_loss_terms)
+ weights = compute_softmax_weights(loss_terms, initial_terms)
+ return n_loss * weights
+
+ def soft_adapt_current(loss_terms, stored_loss_terms):
+ # ω_j(i)
+ prev_terms = jax.tree.map(lambda slt: slt[iteration_nb - 1], stored_loss_terms)
+ return compute_softmax_weights(loss_terms, prev_terms)
+
+ # Bernoulli variable for the random lookback
+ rho = jax.random.bernoulli(key, p).astype(float)
+
+ # Base case for the first iteration
+ def first_iter_case(_):
+ return do_nothing(loss_weights, None)
+
+ # Case for iteration >= 1
+ def subsequent_iter_case(_):
+ # Compute the historical weights
+ def hist_weights_case1(_):
+ return soft_adapt_current(loss_terms, stored_loss_terms)
+
+ def hist_weights_case2(_):
+ return rho * soft_adapt_prev(stored_loss_terms) + (1 - rho) * look_back(
+ loss_terms, stored_loss_terms
+ )
+
+ loss_weights_hist = jax.lax.cond(
+ iteration_nb < 2,
+ hist_weights_case1,
+ hist_weights_case2,
+ None,
+ )
+
+ # Compute and return the final weights
+ adaptive_weights = soft_adapt_current(loss_terms, stored_loss_terms)
+ return decay_factor * loss_weights_hist + (1 - decay_factor) * adaptive_weights
+
+ return jax.lax.cond(
+ iteration_nb == 0,
+ first_iter_case,
+ subsequent_iter_case,
+ None,
+ )
+
+
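A sketch of the final ReLoBRaLo blending performed above, once the SoftAdapt-style weight vectors have been computed (illustrative values only):

import jax
import jax.numpy as jnp

decay_factor, p, n_loss = 0.9, 0.9, 2
key = jax.random.PRNGKey(0)

w_current = jnp.array([0.6, 0.4])   # softmax of L_j(i) / L_j(i-1)
w_prev = jnp.array([0.7, 0.3])      # softmax of L_j(i-1) / L_j(i-2)
w_lookback = jnp.array([1.1, 0.9])  # n_loss * softmax of L_j(i) / (tau * L_j(0))

rho = jax.random.bernoulli(key, p).astype(float)   # random lookback indicator
w_hist = rho * w_prev + (1 - rho) * w_lookback     # historical weights
new_weights = decay_factor * w_hist + (1 - decay_factor) * w_current
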
+ def lr_annealing(
+ loss_weights: AbstractLossWeights,
+ grad_terms: AnyLossComponents,
+ decay_factor: float = 0.9, # 0.9 is the recommended value from the article
+ eps: float = 1e-6,
+ ) -> Array:
+ r"""
+ Implementation of the learning rate annealing scheme (Algorithm 1) from the
+ paper "Understanding and mitigating gradient pathologies in physics-informed
+ neural networks".
+
+ (a) Compute $\hat{\lambda}_i$ by
+ $$
+ \hat{\lambda}_i = \frac{\max_{\theta}\{|\nabla_\theta \mathcal{L}_r(\theta_n)|\}}{\operatorname{mean}(|\nabla_\theta \mathcal{L}_i(\theta_n)|)}, \quad i=1,\dots,M,
+ $$
+
+ (b) Update the weights $\lambda_i$ using a moving average of the form
+ $$
+ \lambda_i = (1-\alpha) \lambda_{i-1} + \alpha \hat{\lambda}_i, \quad i=1,\dots,M.
+ $$
+
+ Note that since None is not treated as a leaf by the jax.tree utilities,
+ None components of the input pytrees are naturally skipped.
+ """
+ assert hasattr(grad_terms, "dyn_loss")
+ dyn_loss_grads = getattr(grad_terms, "dyn_loss")
+ data_fit_grads = [
+ getattr(grad_terms, att) if hasattr(grad_terms, att) else None
+ for att in ["norm_loss", "boundary_loss", "observations", "initial_condition"]
+ ]
+
+ dyn_loss_grads_leaves = jax.tree.leaves(
+ dyn_loss_grads,
+ is_leaf=eqx.is_inexact_array,
+ )
+
+ max_dyn_loss_grads = jnp.max(
+ jnp.stack([jnp.max(jnp.abs(g)) for g in dyn_loss_grads_leaves])
+ )
+
+ mean_gradients = [
+ jnp.mean(jnp.stack([jnp.abs(jnp.mean(g)) for g in jax.tree.leaves(t)]))
+ for t in data_fit_grads
+ if t is not None and jax.tree.leaves(t)
+ ]
+
+ lambda_hat = max_dyn_loss_grads / (jnp.array(mean_gradients) + eps)
+ old_weights = jnp.array(
+ jax.tree.leaves(
+ loss_weights,
+ )
+ )
+
+ new_weights = (1 - decay_factor) * old_weights[1:] + decay_factor * lambda_hat
+ return jnp.hstack([old_weights[0], new_weights])
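
A numerical sketch of the lr_annealing update above, for one residual (dyn_loss) term and two data-fit terms (illustrative gradient statistics):

import jax.numpy as jnp

decay_factor, eps = 0.9, 1e-6

max_dyn_grad = jnp.array(4.0)             # max gradient magnitude of dyn_loss (cf. max_dyn_loss_grads)
mean_data_grads = jnp.array([0.5, 0.1])   # aggregated gradient magnitude per data-fit term (cf. mean_gradients)

lambda_hat = max_dyn_grad / (mean_data_grads + eps)   # approx. [8., 40.]
old_weights = jnp.array([1.0, 1.0, 1.0])              # [dyn_loss, data-fit terms...]

new_data_weights = (1 - decay_factor) * old_weights[1:] + decay_factor * lambda_hat
new_weights = jnp.hstack([old_weights[0], new_data_weights])
# the dyn_loss weight is left untouched; the data-fit weights move toward lambda_hat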
@@ -2,26 +2,82 @@
  Formalize the loss weights data structure
  """

- from jaxtyping import Array, Float
+ from __future__ import annotations
+ from dataclasses import fields
+
+ from jaxtyping import Array
+ import jax.numpy as jnp
  import equinox as eqx


- class LossWeightsODE(eqx.Module):
- dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
- observations: Array | Float = eqx.field(kw_only=True, default=0.0)
+ def lw_converter(x):
+ if x is None:
+ return x
+ else:
+ return jnp.asarray(x)
+
+
+ class AbstractLossWeights(eqx.Module):
+ """
+ An abstract class, currently only useful for type hints.
+
+ TODO: in the future, maybe loss weights could be subclasses of
+ XDEComponentsAbstract?
+ """
+
+ def items(self):
+ """
+ Let the dataclass be iterated over like a dictionary.
+ Practical and backward compatible with old code in which loss
+ components were dictionaries.
+ """
+ return {
+ field.name: getattr(self, field.name)
+ for field in fields(self)
+ if getattr(self, field.name) is not None
+ }.items()
+
+
+ class LossWeightsODE(AbstractLossWeights):
+ dyn_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ initial_condition: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ observations: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )


- class LossWeightsPDEStatio(eqx.Module):
- dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- observations: Array | Float = eqx.field(kw_only=True, default=0.0)
+ class LossWeightsPDEStatio(AbstractLossWeights):
+ dyn_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ norm_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ boundary_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ observations: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )


- class LossWeightsPDENonStatio(eqx.Module):
- dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
- observations: Array | Float = eqx.field(kw_only=True, default=0.0)
- initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
+ class LossWeightsPDENonStatio(AbstractLossWeights):
+ dyn_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ norm_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ boundary_loss: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ observations: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
+ initial_condition: Array | float | None = eqx.field(
+ kw_only=True, default=None, converter=lw_converter
+ )
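
A usage sketch of the reworked loss-weight containers (class and field names taken from this hunk; the values are illustrative, and the import path assumes the hunk belongs to jinns.loss._loss_weights, as the TYPE_CHECKING import in the weight-update module above suggests). Unset components now default to None and are skipped by items(), while numeric inputs are converted to JAX arrays by lw_converter:

import jax.numpy as jnp
from jinns.loss._loss_weights import LossWeightsODE

weights = LossWeightsODE(dyn_loss=1.0, observations=jnp.array(10.0))

for name, value in weights.items():
    print(name, value)
# dyn_loss 1.0
# observations 10.0
# initial_condition was left to None and therefore does not appear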
@@ -10,6 +10,14 @@ from jaxtyping import Array, PyTree, Float
  T = TypeVar("T") # the generic type for what is in the Params PyTree because we
  # have possibly Params of Arrays, boolean, ...

+ ### NOTE
+ ### We are taking derivatives with respect to Params eqx.Modules.
+ ### This has been shown to behave weirdly if some fields of an eqx.Module have
+ ### been set with `field(init=False)`, so we should never create such fields in
+ ### jinns' Params modules.
+ ### We currently silence the warning related to this (see jinns.__init__ and
+ ### https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13)
+

  class Params(eqx.Module, Generic[T]):
  """