jinns-1.1.0-py3-none-any.whl → jinns-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +3 -3
- jinns/loss/_LossPDE.py +27 -36
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +78 -56
- jinns/loss/_operators.py +441 -184
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_utils.py
CHANGED
(Removed "-" lines below are reproduced as shown in the source diff view, which truncates some of them.)

@@ -16,10 +16,9 @@ from jaxtyping import Float, Array, PyTree
 from jinns.loss._boundary_conditions import (
     _compute_boundary_loss,
 )
-from jinns.utils._utils import
-from jinns.data._DataGenerators import
-
-)
+from jinns.utils._utils import _subtract_with_check, get_grid
+from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
+from jinns.parameters._params import _get_vmap_in_axes_params
 from jinns.utils._pinn import PINN
 from jinns.utils._spinn import SPINN
 from jinns.utils._hyperpinn import HYPERPINN
@@ -33,7 +32,11 @@ if TYPE_CHECKING:
 def dynamic_loss_apply(
     dyn_loss: DynamicLoss,
     u: eqx.Module,
-
+    batch: (
+        Float[Array, "batch_size 1"]
+        | Float[Array, "batch_size dim"]
+        | Float[Array, "batch_size 1+dim"]
+    ),
     params: Params | ParamsDict,
     vmap_axes: tuple[int | None, ...],
     loss_weight: float | Float[Array, "dyn_loss_dimension"],
@@ -45,16 +48,16 @@ def dynamic_loss_apply(
     """
     if u_type == PINN or u_type == HYPERPINN or isinstance(u, (PINN, HYPERPINN)):
         v_dyn_loss = vmap(
-            lambda
-
+            lambda batch, params: dyn_loss(
+                batch, u, params  # we must place the params at the end
             ),
             vmap_axes,
             0,
         )
-        residuals = v_dyn_loss(
+        residuals = v_dyn_loss(batch, params)
         mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
     elif u_type == SPINN or isinstance(u, SPINN):
-        residuals = dyn_loss(
+        residuals = dyn_loss(batch, u, params)
         mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
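For context on the new calling convention in the hunk above (the residual closure now takes the collocation point first and the parameters last, and is applied as v_dyn_loss(batch, params)), here is a minimal self-contained sketch of the same vmap pattern. The names u, toy_residual and theta are toy stand-ins for illustration only, not jinns classes or APIs.

import jax
import jax.numpy as jnp
from jax import vmap

# Toy stand-ins: u(t, params) plays the role of the PINN, toy_residual the DynamicLoss.
def u(t, params):
    return jnp.sin(params["theta"] * t)

def toy_residual(t, u, params):
    # ODE residual du/dt - theta * cos(theta * t) for a single collocation point t
    du_dt = jax.grad(lambda s: u(s, params))(t)
    return jnp.atleast_1d(du_dt - params["theta"] * jnp.cos(params["theta"] * t))

params = {"theta": jnp.array(2.0)}
batch = jnp.linspace(0.0, 1.0, 8)      # one scalar time per collocation point
vmap_axes = (0, None)                  # map over the batch axis, broadcast params

# Same structure as the new code: params placed last in the closure,
# then residuals = v_dyn_loss(batch, params)
v_dyn_loss = vmap(lambda b, p: toy_residual(b, u, p), vmap_axes, 0)
residuals = v_dyn_loss(batch, params)  # shape (8, 1)
mse_dyn_loss = jnp.mean(jnp.sum(residuals**2, axis=-1))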
@@ -64,35 +67,49 @@ def dynamic_loss_apply(
 
 def normalization_loss_apply(
     u: eqx.Module,
-    batches:
+    batches: (
+        tuple[Float[Array, "nb_norm_samples dim"]]
+        | tuple[
+            Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+        ]
+    ),
     params: Params | ParamsDict,
-
+    vmap_axes_params: tuple[int | None, ...],
     int_length: int,
     loss_weight: float,
 ) -> float:
-
+    """
+    Note the squeezing on each result. We expect unidimensional *PINN since
+    they represent probability distributions
+    """
     if isinstance(u, (PINN, HYPERPINN)):
         if len(batches) == 1:
             v_u = vmap(
-                lambda
-
+                lambda b: u(b)[u.slice_solution],
+                (0,) + vmap_axes_params,
                 0,
             )
-
-
+            res = v_u(*batches, params)
+            mse_norm_loss = loss_weight * (
+                jnp.abs(jnp.mean(res.squeeze()) * int_length - 1) ** 2
             )
         else:
+            # NOTE this cartesian product is costly
+            batches = make_cartesian_product(
+                batches[0],
+                batches[1],
+            ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
             v_u = vmap(
                 vmap(
-                    lambda
-                    in_axes=(
+                    lambda t_x, params_: u(t_x, params_),
+                    in_axes=(0,) + vmap_axes_params,
                 ),
-                in_axes=(0,
+                in_axes=(0,) + vmap_axes_params,
             )
-            res = v_u(
-            #
+            res = v_u(batches, params)
+            # Over all the times t, we perform a integration
             mse_norm_loss = loss_weight * jnp.mean(
-                jnp.abs(jnp.mean(res, axis
+                jnp.abs(jnp.mean(res.squeeze(), axis=-1) * int_length - 1) ** 2
             )
     elif isinstance(u, SPINN):
         if len(batches) == 1:
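make_cartesian_product is a jinns helper whose implementation is not part of this diff; the only contract visible here is the reshape to (nb_norm_time_slices, nb_norm_samples, 1 + dim) that follows it, i.e. every time slice gets paired with every spatial sample. A shape-level sketch of that pairing with plain jnp operations (toy sizes; the repeat/tile layout is an assumption, not the jinns implementation):

import jax.numpy as jnp

nt, ns, dim = 3, 4, 2
times = jnp.linspace(0.0, 1.0, nt).reshape(nt, 1)   # (nt, 1) time slices
omega = jnp.ones((ns, dim))                         # (ns, dim) spatial samples

# Pair every time slice with every spatial sample, then reshape to
# (nt, ns, 1 + dim), matching the reshape applied after make_cartesian_product.
t_rep = jnp.repeat(times, ns, axis=0)               # (nt * ns, 1)
x_tile = jnp.tile(omega, (nt, 1))                   # (nt * ns, dim)
t_x = jnp.concatenate([t_rep, x_tile], axis=-1)     # (nt * ns, 1 + dim)
t_x = t_x.reshape(nt, ns, 1 + dim)

# The doubly nested vmap in the hunk then maps over the nt axis (outer)
# and the ns axis (inner) of this (nt, ns, 1 + dim) array.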
@@ -101,8 +118,7 @@ def normalization_loss_apply(
                 loss_weight
                 * jnp.abs(
                     jnp.mean(
-
-                        axis=tuple(range(res.ndim - 1)),
+                        res.squeeze(),
                     )
                     * int_length
                     - 1
@@ -112,12 +128,17 @@ def normalization_loss_apply(
         else:
             assert batches[1].shape[0] % batches[0].shape[0] == 0
             rep_t = batches[1].shape[0] // batches[0].shape[0]
-            res = u(
+            res = u(
+                jnp.concatenate(
+                    [jnp.repeat(batches[0], rep_t, axis=0), batches[1]], axis=-1
+                ),
+                params,
+            )
             # the outer mean() below is for the times stamps
             mse_norm_loss = loss_weight * jnp.mean(
                 jnp.abs(
                     jnp.mean(
-
+                        res.squeeze(),
                         axis=(d + 1 for d in range(res.ndim - 2)),
                     )
                     * int_length
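The new SPINN branch above feeds the network a single array built by repeating each time slice rep_t times and concatenating it with the spatial batch. A toy shape check of that construction (plain jnp, illustrative sizes only):

import jax.numpy as jnp

times = jnp.array([[0.0], [0.5]])          # batches[0], shape (2, 1)
omega = jnp.arange(8.0).reshape(4, 2)      # batches[1], shape (4, 2)

assert omega.shape[0] % times.shape[0] == 0
rep_t = omega.shape[0] // times.shape[0]   # 2, as in the hunk

t_x = jnp.concatenate(
    [jnp.repeat(times, rep_t, axis=0), omega], axis=-1
)                                          # shape (4, 3): each row is (t, x1, x2)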
@@ -140,23 +161,16 @@ def boundary_condition_apply(
     omega_boundary_dim: int,
     loss_weight: float | Float[Array, "boundary_cond_dim"],
 ) -> float:
+
+    vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
+
     if isinstance(omega_boundary_fun, dict):
         # We must create the facet tree dictionary as we do not have the
         # enumerate from the for loop to pass the id integer
-        if
-            isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 2
-        ) or (
-            isinstance(batch, PDENonStatioBatch)
-            and batch.times_x_border_batch.shape[-1] == 2
-        ):
+        if batch.border_batch.shape[-1] == 2:
             # 1D
             facet_tree = {"xmin": 0, "xmax": 1}
-        elif
-            isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 4
-        ) or (
-            isinstance(batch, PDENonStatioBatch)
-            and batch.times_x_border_batch.shape[-1] == 4
-        ):
+        elif batch.border_batch.shape[-1] == 4:
             # 2D
             facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
         else:
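The simplified test above infers the domain dimension from batch.border_batch.shape[-1] alone: two border columns mean a 1D segment (facets xmin/xmax), four mean a 2D rectangle (xmin/xmax/ymin/ymax). A hypothetical standalone mirror of that mapping, shown only to make the correspondence explicit (facet_tree_from_border is not a jinns function, and the error path is an assumption):

# Hypothetical mirror of the facet-tree selection in the hunk above:
# 2 border columns -> 1D segment, 4 -> 2D rectangle.
def facet_tree_from_border(border_batch_last_dim: int) -> dict[str, int]:
    if border_batch_last_dim == 2:
        return {"xmin": 0, "xmax": 1}
    if border_batch_last_dim == 4:
        return {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
    raise ValueError("Only 1D and 2D domains are handled here")

assert facet_tree_from_border(4) == {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}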
@@ -166,7 +180,10 @@ def boundary_condition_apply(
                 None
                 if c is None
                 else jnp.mean(
-                    loss_weight
+                    loss_weight
+                    * _compute_boundary_loss(
+                        c, f, batch, u, params, fa, d, vmap_in_axes
+                    )
                 )
             ),
             omega_boundary_condition,
@@ -180,10 +197,7 @@ def boundary_condition_apply(
         # Note that to keep the behaviour given in the comment above we neede
         # to specify is_leaf according to the note in the release of 0.4.29
     else:
-
-            facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
-        else:
-            facet_tuple = tuple(f for f in range(batch.times_x_border_batch.shape[-1]))
+        facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
         b_losses_by_facet = jax.tree_util.tree_map(
             lambda fa: jnp.mean(
                 loss_weight
@@ -195,6 +209,7 @@ def boundary_condition_apply(
                     params,
                     fa,
                     omega_boundary_dim,
+                    vmap_in_axes,
                 )
             ),
             facet_tuple,
@@ -225,8 +240,10 @@ def observations_loss_apply(
         mse_observation_loss = jnp.mean(
             jnp.sum(
                 loss_weight
-                * (
-
+                * _subtract_with_check(
+                    observed_values, val, cause="user defined observed_values"
+                )
+                ** 2,
                 axis=-1,
             )
         )
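_subtract_with_check is imported from jinns.utils._utils at the top of this file; its body is not shown in this diff. Judging from its call sites (a cause= message and a plain subtraction of observed values and predictions), it presumably validates shapes before subtracting. The following is only a hypothetical sketch of such a guard, not the jinns implementation:

import jax.numpy as jnp

# Hypothetical sketch only: the real jinns helper may behave differently.
def subtract_with_check_sketch(a, b, cause: str = ""):
    if a.shape != b.shape:
        raise ValueError(
            f"Shape mismatch ({cause}): {a.shape} vs {b.shape}; "
            "refusing to rely on implicit broadcasting"
        )
    return a - b

observed = jnp.ones((16, 1))
predicted = jnp.zeros((16, 1))
diff = subtract_with_check_sketch(
    observed, predicted, cause="user defined observed_values"
)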
@@ -243,33 +260,38 @@ def initial_condition_apply(
     params: Params | ParamsDict,
     vmap_axes: tuple[int | None, ...],
     initial_condition_fun: Callable,
-    n: int,
     loss_weight: float | Float[Array, "initial_condition_dimension"],
 ) -> float:
+    n = omega_batch.shape[0]
+    t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
     if isinstance(u, (PINN, HYPERPINN)):
         v_u_t0 = vmap(
-            lambda
+            lambda t0_x, params: _subtract_with_check(
+                initial_condition_fun(t0_x[1:]),
+                u(t0_x, params),
+                cause="Output of initial_condition_fun",
+            ),
             vmap_axes,
             0,
         )
-        res = v_u_t0(
+        res = v_u_t0(t0_omega_batch, params)  # NOTE take the tiled
         # omega_batch (ie omega_batch_) to have the same batch
         # dimension as params to be able to vmap.
         # Recall that by convention:
         # param_batch_dict = times_batch_size * omega_batch_size
         mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
     elif isinstance(u, SPINN):
-        values = lambda
-
-            x,
+        values = lambda t_x: u(
+            t_x,
             params,
         )[0]
-        omega_batch_grid =
-        v_ini = values(
-
-            initial_condition_fun(omega_batch_grid),
+        omega_batch_grid = get_grid(omega_batch)
+        v_ini = values(t0_omega_batch)
+        res = _subtract_with_check(
+            initial_condition_fun(omega_batch_grid),
+            v_ini,
+            cause="Output of initial_condition_fun",
        )
-        res = ini - v_ini
         mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
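The new initial_condition_apply derives n from omega_batch and prepends a zero time column, so the initial condition is enforced at t = 0 on every spatial point. A shape-level illustration of that t0_omega_batch construction (toy sizes):

import jax.numpy as jnp

omega_batch = jnp.arange(10.0).reshape(5, 2)   # (n, dim) spatial points
n = omega_batch.shape[0]

# Prepend t = 0 to every spatial point, as in the new initial_condition_apply.
t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
assert t0_omega_batch.shape == (5, 3)          # each row is (0.0, x1, x2)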