jinns 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_CubicMeshPDENonStatio.py +156 -28
- jinns/data/_CubicMeshPDEStatio.py +132 -24
- jinns/loss/_DynamicLossAbstract.py +30 -2
- jinns/loss/_LossODE.py +177 -64
- jinns/loss/_LossPDE.py +146 -68
- jinns/loss/__init__.py +4 -0
- jinns/loss/_abstract_loss.py +116 -3
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +34 -24
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +72 -16
- jinns/parameters/_params.py +8 -0
- jinns/solver/_solve.py +141 -46
- jinns/utils/_containers.py +5 -2
- jinns/utils/_types.py +12 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/METADATA +5 -2
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/RECORD +22 -20
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/WHEEL +1 -1
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.4.0.dist-info → jinns-1.5.1.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_utils.py
CHANGED
|
@@ -40,7 +40,6 @@ def dynamic_loss_apply(
|
|
|
40
40
|
),
|
|
41
41
|
params: Params[Array],
|
|
42
42
|
vmap_axes: tuple[int, Params[int | None] | None],
|
|
43
|
-
loss_weight: float | Float[Array, " dyn_loss_dimension"],
|
|
44
43
|
u_type: PINN | HyperPINN | None = None,
|
|
45
44
|
) -> Float[Array, " "]:
|
|
46
45
|
"""
|
|
@@ -58,10 +57,10 @@ def dynamic_loss_apply(
|
|
|
58
57
|
0,
|
|
59
58
|
)
|
|
60
59
|
residuals = v_dyn_loss(batch, params)
|
|
61
|
-
mse_dyn_loss = jnp.mean(jnp.sum(
|
|
60
|
+
mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
|
|
62
61
|
elif u_type == SPINN or isinstance(u, SPINN):
|
|
63
62
|
residuals = dyn_loss(batch, u, params)
|
|
64
|
-
mse_dyn_loss = jnp.mean(jnp.sum(
|
|
63
|
+
mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))
|
|
65
64
|
else:
|
|
66
65
|
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
67
66
|
|
|
@@ -79,7 +78,6 @@ def normalization_loss_apply(
|
|
|
79
78
|
params: Params[Array],
|
|
80
79
|
vmap_axes_params: tuple[Params[int | None] | None],
|
|
81
80
|
norm_weights: Float[Array, " nb_norm_samples"],
|
|
82
|
-
loss_weight: float,
|
|
83
81
|
) -> Float[Array, " "]:
|
|
84
82
|
"""
|
|
85
83
|
Note the squeezing on each result. We expect unidimensional *PINN since
|
|
@@ -95,9 +93,7 @@ def normalization_loss_apply(
|
|
|
95
93
|
res = v_u(*batches, params)
|
|
96
94
|
assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
|
|
97
95
|
# Monte-Carlo integration using importance sampling
|
|
98
|
-
mse_norm_loss =
|
|
99
|
-
jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
|
|
100
|
-
)
|
|
96
|
+
mse_norm_loss = jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
|
|
101
97
|
else:
|
|
102
98
|
# NOTE this cartesian product is costly
|
|
103
99
|
batch_cart_prod = make_cartesian_product(
|
|
@@ -115,7 +111,7 @@ def normalization_loss_apply(
|
|
|
115
111
|
assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
|
|
116
112
|
# For all times t, we perform an integration. Then we average the
|
|
117
113
|
# losses over times.
|
|
118
|
-
mse_norm_loss =
|
|
114
|
+
mse_norm_loss = jnp.mean(
|
|
119
115
|
jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
|
|
120
116
|
)
|
|
121
117
|
elif isinstance(u, SPINN):
|
|
@@ -123,8 +119,7 @@ def normalization_loss_apply(
|
|
|
123
119
|
res = u(*batches, params)
|
|
124
120
|
assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
|
|
125
121
|
mse_norm_loss = (
|
|
126
|
-
|
|
127
|
-
* jnp.abs(
|
|
122
|
+
jnp.abs(
|
|
128
123
|
jnp.mean(
|
|
129
124
|
res.squeeze(),
|
|
130
125
|
)
|
|
@@ -144,7 +139,7 @@ def normalization_loss_apply(
|
|
|
144
139
|
)
|
|
145
140
|
assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
|
|
146
141
|
# the outer mean() below is for the times stamps
|
|
147
|
-
mse_norm_loss =
|
|
142
|
+
mse_norm_loss = jnp.mean(
|
|
148
143
|
jnp.abs(
|
|
149
144
|
jnp.mean(
|
|
150
145
|
res.squeeze(),
|
|
@@ -168,7 +163,6 @@ def boundary_condition_apply(
|
|
|
168
163
|
omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
|
|
169
164
|
omega_boundary_condition: str | dict[str, str],
|
|
170
165
|
omega_boundary_dim: slice | dict[str, slice],
|
|
171
|
-
loss_weight: float | Float[Array, " boundary_cond_dim"],
|
|
172
166
|
) -> Float[Array, " "]:
|
|
173
167
|
assert batch.border_batch is not None
|
|
174
168
|
vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
|
|
@@ -205,10 +199,7 @@ def boundary_condition_apply(
|
|
|
205
199
|
None
|
|
206
200
|
if c is None
|
|
207
201
|
else jnp.mean(
|
|
208
|
-
|
|
209
|
-
* _compute_boundary_loss(
|
|
210
|
-
c, f, batch, u, params, fa, d, vmap_in_axes
|
|
211
|
-
)
|
|
202
|
+
_compute_boundary_loss(c, f, batch, u, params, fa, d, vmap_in_axes)
|
|
212
203
|
)
|
|
213
204
|
),
|
|
214
205
|
omega_boundary_dicts[0], # omega_boundary_condition,
|
|
@@ -225,8 +216,7 @@ def boundary_condition_apply(
|
|
|
225
216
|
facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
|
|
226
217
|
b_losses_by_facet = jax.tree_util.tree_map(
|
|
227
218
|
lambda fa: jnp.mean(
|
|
228
|
-
|
|
229
|
-
* _compute_boundary_loss(
|
|
219
|
+
_compute_boundary_loss(
|
|
230
220
|
omega_boundary_dicts[0], # type: ignore -> need TypeIs from 3.13
|
|
231
221
|
omega_boundary_dicts[1], # type: ignore -> need TypeIs from 3.13
|
|
232
222
|
batch,
|
|
@@ -251,7 +241,6 @@ def observations_loss_apply(
|
|
|
251
241
|
params: Params[Array],
|
|
252
242
|
vmap_axes: tuple[int, Params[int | None] | None],
|
|
253
243
|
observed_values: Float[Array, " obs_batch_size observation_dim"],
|
|
254
|
-
loss_weight: float | Float[Array, " observation_dim"],
|
|
255
244
|
obs_slice: EllipsisType | slice | None,
|
|
256
245
|
) -> Float[Array, " "]:
|
|
257
246
|
if isinstance(u, (PINN, HyperPINN)):
|
|
@@ -263,8 +252,7 @@ def observations_loss_apply(
|
|
|
263
252
|
val = v_u(batch, params)[:, obs_slice]
|
|
264
253
|
mse_observation_loss = jnp.mean(
|
|
265
254
|
jnp.sum(
|
|
266
|
-
|
|
267
|
-
* _subtract_with_check(
|
|
255
|
+
_subtract_with_check(
|
|
268
256
|
observed_values, val, cause="user defined observed_values"
|
|
269
257
|
)
|
|
270
258
|
** 2,
|
|
@@ -285,7 +273,6 @@ def initial_condition_apply(
|
|
|
285
273
|
vmap_axes: tuple[int, Params[int | None] | None],
|
|
286
274
|
initial_condition_fun: Callable,
|
|
287
275
|
t0: Float[Array, " 1"],
|
|
288
|
-
loss_weight: float | Float[Array, " initial_condition_dimension"],
|
|
289
276
|
) -> Float[Array, " "]:
|
|
290
277
|
n = omega_batch.shape[0]
|
|
291
278
|
t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
|
|
@@ -304,7 +291,7 @@ def initial_condition_apply(
|
|
|
304
291
|
# dimension as params to be able to vmap.
|
|
305
292
|
# Recall that by convention:
|
|
306
293
|
# param_batch_dict = times_batch_size * omega_batch_size
|
|
307
|
-
mse_initial_condition = jnp.mean(jnp.sum(
|
|
294
|
+
mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
|
|
308
295
|
elif isinstance(u, SPINN):
|
|
309
296
|
values = lambda t_x: u(
|
|
310
297
|
t_x,
|
|
@@ -317,7 +304,30 @@ def initial_condition_apply(
|
|
|
317
304
|
v_ini,
|
|
318
305
|
cause="Output of initial_condition_fun",
|
|
319
306
|
)
|
|
320
|
-
mse_initial_condition = jnp.mean(jnp.sum(
|
|
307
|
+
mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))
|
|
321
308
|
else:
|
|
322
309
|
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
323
310
|
return mse_initial_condition
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def initial_condition_check(x, dim_size=None):
    """
    Turn a scalar-like input into a 1D jnp array and optionally validate
    its shape.

    Parameters
    ----------
    x
        An Array, a float or an int. Python scalars and 0D arrays are
        wrapped into an array of shape `(1,)`; arrays which already have at
        least one dimension are left untouched.
    dim_size
        If not None, the resulting array must have shape `(dim_size,)`.

    Returns
    -------
    The converted array.

    Raises
    ------
    ValueError
        If `x` is neither an Array, a float nor an int, or if `dim_size` is
        given and the resulting shape is not `(dim_size,)`.
    """
    if isinstance(x, Array):
        if not x.shape:  # e.g. user input: jnp.array(0.)
            x = jnp.array([x])
    elif isinstance(x, float):  # e.g. user input: 0.
        x = jnp.array([x])
    elif isinstance(x, int):  # e.g. user input: 0
        x = jnp.array([float(x)])
    else:
        raise ValueError(f"Wrong value, expected Array, float or int, got {type(x)}")

    # The shape check is applied after the conversion so that Python scalars
    # are validated exactly like 0D arrays (it used to be skipped for them).
    if dim_size is not None and x.shape != (dim_size,):
        raise ValueError(
            f"Wrong dim_size. It should be({dim_size},). Got shape: {x.shape}"
        )
    return x
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A collection of specific weight update schemes in jinns
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
from jaxtyping import Array, Key
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import jax
|
|
10
|
+
import equinox as eqx
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from jinns.loss._loss_weights import AbstractLossWeights
|
|
14
|
+
from jinns.utils._types import AnyLossComponents
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def soft_adapt(
    loss_weights: AbstractLossWeights,
    iteration_nb: int,
    loss_terms: AnyLossComponents,
    stored_loss_terms: AnyLossComponents,
) -> Array:
    r"""
    SoftAdapt loss weighting, following the simple strategy described in
    https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/theory/advanced_schemes.html#softadapt

    $$
    w_j(i)= \frac{\exp\left(\frac{L_j(i)}{L_j(i-1)+\epsilon}-\mu(i)\right)}
    {\sum_{k=1}^{n_{loss}}\exp\left(\frac{L_k(i)}{L_k(i-1)+\epsilon}-\mu(i)\right)}
    $$

    where $\mu(i)$ is the largest of the ratios and $\epsilon = 10^{-6}$.

    At iteration 0 there is no previous loss value: the leaves of
    `loss_weights` are returned unchanged, as a flat float array.

    Since jax tree utilities do not treat None as a leaf, None components of
    `loss_terms` and `stored_loss_terms` are naturally ignored.

    Parameters
    ----------
    loss_weights
        The current loss weights.
    iteration_nb
        The current iteration number.
    loss_terms
        The current value of each loss term.
    stored_loss_terms
        Same structure as `loss_terms`; each leaf is indexed by iteration
        number to retrieve the value at iteration `iteration_nb - 1`.

    Returns
    -------
    A flat array with one weight per leaf.
    """

    def _keep_current_weights(operands):
        weights, _, _ = operands
        flat_weights = jax.tree.leaves(weights, is_leaf=eqx.is_inexact_array)
        return jnp.array(flat_weights, dtype=float)

    def _softmax_of_ratios(operands):
        _, current, history = operands
        # L_j(i) / (L_j(i-1) + eps) for each loss term j
        ratios = jax.tree.map(
            lambda cur, hist: cur / (hist[iteration_nb - 1] + 1e-6),
            current,
            history,
        )
        # subtracting the largest ratio keeps the exponentials bounded
        largest = jax.tree.reduce(
            jnp.maximum, ratios, initializer=jnp.array(-jnp.inf)
        )
        shifted = [ratio - largest for ratio in jax.tree.leaves(ratios)]
        return jax.nn.softmax(jnp.array(shifted))

    return jax.lax.cond(
        iteration_nb == 0,
        _keep_current_weights,
        _softmax_of_ratios,
        (loss_weights, loss_terms, stored_loss_terms),
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def ReLoBRaLo(
    loss_weights: AbstractLossWeights,
    iteration_nb: int,
    loss_terms: AnyLossComponents,
    stored_loss_terms: AnyLossComponents,
    key: Key,
    decay_factor: float = 0.9,
    tau: float = 1,  ## referred to as temperature in the article
    p: float = 0.9,
):
    r"""
    Extension of SoftAdapt: Relative Loss Balancing with Random Lookback.

    Writing $s(A, B)$ for the softmax over the loss terms of the ratios
    $A_j / (B_j + \epsilon)$ (with $\epsilon = 10^{-6}$), the returned
    weights are, for iteration $i \geq 1$,

    $$
    w(i) = \alpha\, h(i) + (1 - \alpha)\, s(L(i), L(i-1))
    $$

    with $\alpha$ = `decay_factor` and the historical term

    - $h(i) = s(L(i), L(i-1))$ if $i < 2$,
    - $h(i) = \rho\, s(L(i-1), L(i-2)) + (1-\rho)\, n_{loss}\, s(L(i), \tau L(0))$
      otherwise, where $\rho \sim \mathrm{Bernoulli}(p)$ is drawn with `key`.

    At iteration 0 the leaves of `loss_weights` are returned unchanged, as a
    flat float array.

    Parameters
    ----------
    loss_weights
        The current loss weights.
    iteration_nb
        The current iteration number.
    loss_terms
        The current value of each loss term.
    stored_loss_terms
        Same structure as `loss_terms`; each leaf is indexed by iteration
        number (indices 0, `iteration_nb - 1` and `iteration_nb - 2` are
        read).
    key
        Random key for the Bernoulli draw.
    decay_factor
        Weight given to the historical term.
    tau
        Temperature, multiplies the loss values of iteration 0 in the
        look-back term.
    p
        Parameter of the Bernoulli draw.

    Returns
    -------
    A flat array with one weight per leaf.
    """
    eps = 1e-6
    nb_terms = len(jax.tree.leaves(loss_terms))  # number of loss terms

    def _softmax_of_ratios(numerators, denominators):
        ratios = jax.tree.map(
            lambda num, den: num / (den + eps),
            numerators,
            denominators,
        )
        # subtracting the largest ratio keeps the exponentials bounded
        largest = jax.tree.reduce(jnp.maximum, ratios, initializer=-jnp.inf)
        shifted = [ratio - largest for ratio in jax.tree.leaves(ratios)]
        return jax.nn.softmax(jnp.array(shifted))

    def _stored_at(index):
        # loss values recorded at the given iteration
        return jax.tree.map(lambda history: history[index], stored_loss_terms)

    def _current_vs_previous():
        # ω_j(i)
        return _softmax_of_ratios(loss_terms, _stored_at(iteration_nb - 1))

    def _previous_vs_before():
        # ω_j(i-1)
        return _softmax_of_ratios(
            _stored_at(iteration_nb - 1), _stored_at(iteration_nb - 2)
        )

    def _current_vs_first():
        # ω̂_j^(i,0)
        reference = jax.tree.map(
            lambda history: tau * history[0], stored_loss_terms
        )
        return nb_terms * _softmax_of_ratios(loss_terms, reference)

    # Bernoulli variable for random lookback
    lookback_draw = jax.random.bernoulli(key, p).astype(float)

    def _first_iteration(_):
        flat_weights = jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array)
        return jnp.array(flat_weights, dtype=float)

    def _later_iterations(_):
        def _short_history(_):
            return _current_vs_previous()

        def _long_history(_):
            return (
                lookback_draw * _previous_vs_before()
                + (1 - lookback_draw) * _current_vs_first()
            )

        # two stored iterations are needed before looking further back
        historical = jax.lax.cond(
            iteration_nb < 2,
            _short_history,
            _long_history,
            None,
        )
        adaptive = _current_vs_previous()
        return decay_factor * historical + (1 - decay_factor) * adaptive

    return jax.lax.cond(
        iteration_nb == 0,
        _first_iteration,
        _later_iterations,
        None,
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def lr_annealing(
    loss_weights: AbstractLossWeights,
    grad_terms: AnyLossComponents,
    decay_factor: float = 0.9,  # 0.9 is the recommended value from the article
    eps: float = 1e-6,
) -> Array:
    r"""
    Implementation of the Learning rate annealing
    Algorithm 1 in the paper UNDERSTANDING AND MITIGATING GRADIENT PATHOLOGIES IN PHYSICS-INFORMED NEURAL NETWORKS

    (a) Compute $\hat{\lambda}_i$ by
    $$
    \hat{\lambda}_i = \frac{\max_{\theta}\{|\nabla_\theta \mathcal{L}_r (\theta_n)|\}}{mean(|\nabla_\theta \mathcal{L}_i (\theta_n)|) + \epsilon}, \quad i=1,\dots, M,
    $$
    where $\mathcal{L}_r$ is the dynamic loss and $\mathcal{L}_i$ are the
    other loss terms.

    (b) Update the weights $\lambda_i$ using a moving average of the form
    $$
    \lambda_i \leftarrow (1-\alpha) \lambda_i + \alpha \hat{\lambda}_i, \quad i=1, \dots, M,
    $$
    with $\alpha$ = `decay_factor`.

    Note that since None is not treated as a leaf by jax tree.util functions,
    we naturally avoid None components from grad_terms and loss_weights!

    Parameters
    ----------
    loss_weights
        The current loss weights. Its first leaf is taken to be the dynamic
        loss weight and is returned unchanged.
    grad_terms
        The gradients of each loss term. It must have a `dyn_loss`
        attribute; the attributes `norm_loss`, `boundary_loss`,
        `observations` and `initial_condition` are used when present and
        not None.
    decay_factor
        The moving average coefficient $\alpha$.
    eps
        Small constant added to the denominator of $\hat{\lambda}_i$.

    Returns
    -------
    A flat array: the unchanged dynamic loss weight followed by the updated
    weights of the other loss terms.
    """
    # local import: only needed to read the declaration order of the weights
    from dataclasses import fields, is_dataclass

    assert hasattr(grad_terms, "dyn_loss")
    dyn_loss_grads = grad_terms.dyn_loss

    # lambda_hat below is combined elementwise with the leaves of
    # loss_weights, hence the gradient terms must be visited in the order in
    # which loss_weights declares its fields (e.g. LossWeightsODE declares
    # initial_condition before observations).
    data_fit_names = ["norm_loss", "boundary_loss", "observations", "initial_condition"]
    if is_dataclass(loss_weights):
        declared = [f.name for f in fields(loss_weights) if f.name in data_fit_names]
        data_fit_names = declared + [n for n in data_fit_names if n not in declared]
    data_fit_grads = [getattr(grad_terms, name, None) for name in data_fit_names]

    max_dyn_loss_grads = jnp.max(
        jnp.stack([jnp.max(jnp.abs(g)) for g in jax.tree.leaves(dyn_loss_grads)])
    )

    # mean of the absolute values (not absolute value of the mean, in which
    # gradients of opposite signs cancel out).
    # NOTE(review): this averages the per-leaf means, i.e. leaves are not
    # weighted by their number of entries -- confirm this is intended.
    mean_gradients = [
        jnp.mean(jnp.stack([jnp.mean(jnp.abs(g)) for g in jax.tree.leaves(t)]))
        for t in data_fit_grads
        if t is not None and jax.tree.leaves(t)
    ]

    lambda_hat = max_dyn_loss_grads / (jnp.array(mean_gradients) + eps)
    old_weights = jnp.array(jax.tree.leaves(loss_weights))

    new_weights = (1 - decay_factor) * old_weights[1:] + decay_factor * lambda_hat
    return jnp.hstack([old_weights[0], new_weights])
|
jinns/loss/_loss_weights.py
CHANGED
|
@@ -2,26 +2,82 @@
|
|
|
2
2
|
Formalize the loss weights data structure
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from dataclasses import fields
|
|
7
|
+
|
|
8
|
+
from jaxtyping import Array
|
|
9
|
+
import jax.numpy as jnp
|
|
6
10
|
import equinox as eqx
|
|
7
11
|
|
|
8
12
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
+
def lw_converter(x):
    """
    Converter for the loss weight fields: None is passed through, any other
    value is converted with `jnp.asarray`.
    """
    return None if x is None else jnp.asarray(x)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AbstractLossWeights(eqx.Module):
    """
    An abstract class, currently only useful for type hints

    TODO in the future maybe loss weights could be subclasses of
    XDEComponentsAbstract?
    """

    def items(self):
        """
        Iterate over the dataclass like over a dictionary: yields the
        `(field name, value)` pairs of the fields which are not None.

        Practical and retrocompatible with old code when loss components were
        dictionaries
        """
        not_none = {}
        for spec in fields(self):
            value = getattr(self, spec.name)
            if value is not None:
                not_none[spec.name] = value
        return not_none.items()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LossWeightsODE(AbstractLossWeights):
    """
    Loss weights with one optional entry per ODE loss term.

    Every field is keyword-only, defaults to None and goes through
    `lw_converter` (a value which is not None is converted to an array).
    """

    # NOTE the declaration order of the fields is the order of the pytree
    # leaves, do not reorder them
    dyn_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    initial_condition: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    observations: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
|
|
13
51
|
|
|
14
52
|
|
|
15
|
-
class LossWeightsPDEStatio(
|
|
16
|
-
dyn_loss: Array |
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
53
|
+
class LossWeightsPDEStatio(AbstractLossWeights):
    """
    Loss weights with one optional entry per stationary PDE loss term.

    Every field is keyword-only, defaults to None and goes through
    `lw_converter` (a value which is not None is converted to an array).
    """

    # NOTE the declaration order of the fields is the order of the pytree
    # leaves, do not reorder them
    dyn_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    norm_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    boundary_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    observations: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
|
|
20
66
|
|
|
21
67
|
|
|
22
|
-
class LossWeightsPDENonStatio(
|
|
23
|
-
dyn_loss: Array |
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
68
|
+
class LossWeightsPDENonStatio(AbstractLossWeights):
    """
    Loss weights with one optional entry per non-stationary PDE loss term.

    Every field is keyword-only, defaults to None and goes through
    `lw_converter` (a value which is not None is converted to an array).
    """

    # NOTE the declaration order of the fields is the order of the pytree
    # leaves, do not reorder them
    dyn_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    norm_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    boundary_loss: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    observations: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
    initial_condition: Array | float | None = eqx.field(
        default=None, converter=lw_converter, kw_only=True
    )
|
jinns/parameters/_params.py
CHANGED
|
@@ -10,6 +10,14 @@ from jaxtyping import Array, PyTree, Float
|
|
|
10
10
|
T = TypeVar("T") # the generic type for what is in the Params PyTree because we
|
|
11
11
|
# have possibly Params of Arrays, boolean, ...
|
|
12
12
|
|
|
13
|
+
### NOTE
|
|
14
|
+
### We are taking derivatives with respect to Params eqx.Modules.
|
|
15
|
+
### This has been shown to behave weirdly if some fields of eqx.Modules have
|
|
16
|
+
### been set as `field(init=False)`, we then should never create such fields in
|
|
17
|
+
### jinns' Params modules.
|
|
18
|
+
### We currently have silenced the warning related to this (see jinns.__init__);
|
|
19
|
+
### see https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13
|
|
20
|
+
|
|
13
21
|
|
|
14
22
|
class Params(eqx.Module, Generic[T]):
|
|
15
23
|
"""
|