jinns-1.1.0-py3-none-any.whl → jinns-1.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +534 -343
- jinns/loss/_DynamicLoss.py +152 -175
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +4 -4
- jinns/loss/_LossPDE.py +102 -74
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +150 -281
- jinns/loss/_loss_utils.py +95 -67
- jinns/loss/_operators.py +441 -186
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +47 -31
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +113 -100
- jinns/solver/_rar.py +104 -409
- jinns/solver/_solve.py +87 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +1 -4
- jinns/utils/_containers.py +3 -1
- jinns/utils/_types.py +5 -4
- jinns/utils/_utils.py +40 -12
- jinns-1.3.0.dist-info/METADATA +127 -0
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -410
- jinns/utils/_pinn.py +0 -334
- jinns/utils/_spinn.py +0 -268
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/loss/_loss_utils.py
CHANGED
@@ -16,13 +16,12 @@ from jaxtyping import Float, Array, PyTree
 from jinns.loss._boundary_conditions import (
     _compute_boundary_loss,
 )
-from jinns.utils._utils import …
-from jinns.data._DataGenerators import …
-…
-…
-from jinns.…
-from jinns.…
-from jinns.utils._hyperpinn import HYPERPINN
+from jinns.utils._utils import _subtract_with_check, get_grid
+from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
+from jinns.parameters._params import _get_vmap_in_axes_params
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN
+from jinns.nn._hyperpinn import HyperPINN
 from jinns.data._Batchs import *
 from jinns.parameters._params import Params, ParamsDict
 
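The hunk above mirrors the package restructuring visible in the file list: the PINN, SPINN and HyperPINN classes move from jinns/utils to the new jinns/nn package, and HYPERPINN is renamed HyperPINN. A sketch of the corresponding change for code that imported these classes by their internal module paths; only the paths visible in this diff are used, and whether jinns.nn also re-exports them at package level is not shown here:

    # Before (jinns 1.1.0): classes lived under jinns.utils
    # from jinns.utils._pinn import PINN
    # from jinns.utils._spinn import SPINN
    # from jinns.utils._hyperpinn import HYPERPINN

    # After (jinns 1.3.0): same classes under the new jinns.nn package,
    # with HYPERPINN renamed to HyperPINN
    from jinns.nn._pinn import PINN
    from jinns.nn._spinn import SPINN
    from jinns.nn._hyperpinn import HyperPINN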
@@ -33,28 +32,32 @@ if TYPE_CHECKING:
 def dynamic_loss_apply(
     dyn_loss: DynamicLoss,
     u: eqx.Module,
-    …
+    batch: (
+        Float[Array, "batch_size 1"]
+        | Float[Array, "batch_size dim"]
+        | Float[Array, "batch_size 1+dim"]
+    ),
     params: Params | ParamsDict,
     vmap_axes: tuple[int | None, ...],
     loss_weight: float | Float[Array, "dyn_loss_dimension"],
-    u_type: PINN | …
+    u_type: PINN | HyperPINN | None = None,
 ) -> float:
     """
     Sometimes when u is a lambda function a or dict we do not have access to
     its type here, hence the last argument
     """
-    if u_type == PINN or u_type == …
+    if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
         v_dyn_loss = vmap(
-            lambda …
-            …
+            lambda batch, params: dyn_loss(
+                batch, u, params  # we must place the params at the end
             ),
             vmap_axes,
             0,
         )
-        residuals = v_dyn_loss(…
+        residuals = v_dyn_loss(batch, params)
         mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
     elif u_type == SPINN or isinstance(u, SPINN):
-        residuals = dyn_loss(…
+        residuals = dyn_loss(batch, u, params)
         mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
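The PINN/HyperPINN branch now receives the raw batch array and vmaps the dynamic-loss residual over its leading axis before taking the weighted MSE. A minimal sketch of that pattern, assuming a stand-in residual function; residual_fn, batch and params below are placeholders, not jinns API:

    import jax.numpy as jnp
    from jax import vmap

    def residual_fn(t_x, params):
        # stand-in for dyn_loss(t_x, u, params): one PDE residual per sample
        return t_x @ params["w"] - 1.0

    batch = jnp.ones((128, 3))           # (batch_size, 1 + dim) collocation points
    params = {"w": jnp.zeros((3, 1))}

    # map over the batch axis only; params are broadcast (in_axes=None)
    residuals = vmap(residual_fn, in_axes=(0, None), out_axes=0)(batch, params)
    mse_dyn_loss = jnp.mean(jnp.sum(1.0 * residuals**2, axis=-1))  # loss_weight = 1.0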
@@ -64,47 +67,65 @@ def dynamic_loss_apply
 
 def normalization_loss_apply(
     u: eqx.Module,
-    batches: …
+    batches: (
+        tuple[Float[Array, "nb_norm_samples dim"]]
+        | tuple[
+            Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
+        ]
+    ),
     params: Params | ParamsDict,
-    …
-    …
+    vmap_axes_params: tuple[int | None, ...],
+    norm_weights: Float[Array, "nb_norm_samples"],
     loss_weight: float,
 ) -> float:
-    …
-    …
+    """
+    Note the squeezing on each result. We expect unidimensional *PINN since
+    they represent probability distributions
+    """
+    if isinstance(u, (PINN, HyperPINN)):
         if len(batches) == 1:
             v_u = vmap(
-                lambda *…
-                …
+                lambda *b: u(*b)[u.slice_solution],
+                (0,) + vmap_axes_params,
                 0,
             )
-            …
-            …
+            res = v_u(*batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
+            # Monte-Carlo integration using importance sampling
+            mse_norm_loss = loss_weight * (
+                jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
             )
         else:
+            # NOTE this cartesian product is costly
+            batches = make_cartesian_product(
+                batches[0],
+                batches[1],
+            ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
             v_u = vmap(
                 vmap(
-                    lambda …
-                    in_axes=(…
+                    lambda t_x, params_: u(t_x, params_),
+                    in_axes=(0,) + vmap_axes_params,
                 ),
-                in_axes=(0, …
+                in_axes=(0,) + vmap_axes_params,
             )
-            res = v_u(…
-            …
+            res = v_u(batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
+            # For all times t, we perform an integration. Then we average the
+            # losses over times.
             mse_norm_loss = loss_weight * jnp.mean(
-                jnp.abs(jnp.mean(res, axis…
+                jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
             )
     elif isinstance(u, SPINN):
         if len(batches) == 1:
             res = u(*batches, params)
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             mse_norm_loss = (
                 loss_weight
                 * jnp.abs(
                     jnp.mean(
-                        …
-                        axis=tuple(range(res.ndim - 1)),
+                        res.squeeze(),
                     )
-                    * …
+                    * norm_weights
                     - 1
                 )
                 ** 2
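The stationary branch now enforces the constraint ∫ u(x) dx ≈ 1 by Monte Carlo with importance sampling: the samples come with norm_weights so that mean(u(x_i) * w_i) estimates the integral, and the loss penalises its squared distance from 1. A small self-contained illustration of that estimator with a uniform proposal on [0, L] (so w_i = L); nothing here is jinns API:

    import jax.numpy as jnp
    from jax import random

    key = random.PRNGKey(0)
    L = 4.0                                    # integration domain [0, L]
    x = random.uniform(key, (10_000, 1)) * L   # samples from the uniform proposal

    u = lambda x: jnp.exp(-x) / (1 - jnp.exp(-L))  # a density normalised on [0, L]

    norm_weights = L                           # 1 / proposal density
    estimate = jnp.mean(u(x).squeeze() * norm_weights)   # ≈ ∫ u(x) dx = 1
    mse_norm_loss = jnp.abs(estimate - 1) ** 2            # quantity being penalised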
@@ -112,15 +133,21 @@ def normalization_loss_apply(
         else:
             assert batches[1].shape[0] % batches[0].shape[0] == 0
             rep_t = batches[1].shape[0] // batches[0].shape[0]
-            res = u(…
+            res = u(
+                jnp.concatenate(
+                    [jnp.repeat(batches[0], rep_t, axis=0), batches[1]], axis=-1
+                ),
+                params,
+            )
+            assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
             # the outer mean() below is for the times stamps
             mse_norm_loss = loss_weight * jnp.mean(
                 jnp.abs(
                     jnp.mean(
-                        …
+                        res.squeeze(),
                         axis=(d + 1 for d in range(res.ndim - 2)),
                     )
-                    * …
+                    * norm_weights
                     - 1
                 )
                 ** 2
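In the non-stationary SPINN branch the time batch is repeated so it can be concatenated column-wise with the space batch, since the network takes a single (t, x) array. A toy illustration of that repeat-and-concatenate step, with made-up shapes:

    import jax.numpy as jnp

    times = jnp.arange(3.0).reshape(3, 1)       # (nb_times, 1)
    omega = jnp.arange(12.0).reshape(6, 2)      # (nb_space, dim), nb_space % nb_times == 0

    rep_t = omega.shape[0] // times.shape[0]    # 2
    t_x = jnp.concatenate([jnp.repeat(times, rep_t, axis=0), omega], axis=-1)
    print(t_x.shape)                            # (6, 3): each row is (t, x1, x2)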
@@ -140,23 +167,16 @@ def boundary_condition_apply(
     omega_boundary_dim: int,
     loss_weight: float | Float[Array, "boundary_cond_dim"],
 ) -> float:
+
+    vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
+
     if isinstance(omega_boundary_fun, dict):
         # We must create the facet tree dictionary as we do not have the
         # enumerate from the for loop to pass the id integer
-        if (
-            isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 2
-        ) or (
-            isinstance(batch, PDENonStatioBatch)
-            and batch.times_x_border_batch.shape[-1] == 2
-        ):
+        if batch.border_batch.shape[-1] == 2:
             # 1D
             facet_tree = {"xmin": 0, "xmax": 1}
-        elif (
-            isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 4
-        ) or (
-            isinstance(batch, PDENonStatioBatch)
-            and batch.times_x_border_batch.shape[-1] == 4
-        ):
+        elif batch.border_batch.shape[-1] == 4:
             # 2D
             facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
         else:
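boundary_condition_apply now computes vmap_in_axes once at the top: axis 0 for the batch of boundary points, plus whatever axes _get_vmap_in_axes_params reports for possibly batched parameters. A throwaway illustration of the underlying JAX mechanism (per-leaf in_axes over a params pytree), not the jinns helper itself:

    import jax.numpy as jnp
    from jax import vmap

    def f(x, params):
        return params["a"] * x + params["b"]

    x = jnp.ones((8, 2))                                             # batched boundary points
    params = {"a": jnp.linspace(1.0, 2.0, 8), "b": jnp.array(1.0)}   # "a" batched, "b" shared

    # in_axes mirrors the argument structure: axis 0 for the batch,
    # and a per-leaf spec (0 or None) for the params pytree.
    vmap_in_axes = (0, {"a": 0, "b": None})
    out = vmap(f, in_axes=vmap_in_axes)(x, params)
    print(out.shape)  # (8, 2)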
@@ -166,7 +186,10 @@ def boundary_condition_apply(
                 None
                 if c is None
                 else jnp.mean(
-                    loss_weight …
+                    loss_weight
+                    * _compute_boundary_loss(
+                        c, f, batch, u, params, fa, d, vmap_in_axes
+                    )
                 )
             ),
             omega_boundary_condition,
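The facet-wise losses are built with jax.tree_util.tree_map over a dict of boundary conditions in which some entries may be None, and None must be treated as a leaf rather than an empty subtree; this is what the is_leaf comment in the next hunk refers to. A standalone illustration of that pattern, unrelated to jinns internals:

    import jax

    conditions = {"xmin": "dirichlet", "xmax": None, "ymin": "neumann", "ymax": None}

    # Without is_leaf, None entries are treated as empty subtrees and skipped;
    # with is_leaf, the mapped function sees them and can return None explicitly.
    per_facet = jax.tree_util.tree_map(
        lambda c: None if c is None else f"loss({c})",
        conditions,
        is_leaf=lambda x: x is None,
    )
    # None facets stay None; the other facets get a per-facet loss entry.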
@@ -180,10 +203,7 @@ def boundary_condition_apply(
     # Note that to keep the behaviour given in the comment above we neede
     # to specify is_leaf according to the note in the release of 0.4.29
     else:
-        …
-            facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
-        else:
-            facet_tuple = tuple(f for f in range(batch.times_x_border_batch.shape[-1]))
+        facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
         b_losses_by_facet = jax.tree_util.tree_map(
             lambda fa: jnp.mean(
                 loss_weight
@@ -195,6 +215,7 @@ def boundary_condition_apply(
                     params,
                     fa,
                     omega_boundary_dim,
+                    vmap_in_axes,
                 )
             ),
             facet_tuple,
@@ -215,7 +236,7 @@ def observations_loss_apply(
     obs_slice: slice,
 ) -> float:
     # TODO implement for SPINN
-    if isinstance(u, (PINN, …
+    if isinstance(u, (PINN, HyperPINN)):
         v_u = vmap(
             lambda *args: u(*args)[u.slice_solution],
             vmap_axes,
@@ -225,8 +246,10 @@ def observations_loss_apply(
     mse_observation_loss = jnp.mean(
         jnp.sum(
             loss_weight
-            * (…
-            …
+            * _subtract_with_check(
+                observed_values, val, cause="user defined observed_values"
+            )
+            ** 2,
             axis=-1,
         )
     )
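_subtract_with_check, imported from jinns.utils._utils in the first hunk, replaces the bare subtraction in the observation loss; judging by its call sites it subtracts the network output from a reference array and uses cause to name the offending array in an error message. A hypothetical re-implementation for illustration only; the real helper's behaviour may differ:

    import jax

    def _subtract_with_check(expected: jax.Array, got: jax.Array, cause: str = "") -> jax.Array:
        # Hypothetical sketch: fail loudly on a shape mismatch instead of
        # letting silent broadcasting produce a wrong loss value.
        if expected.shape != got.shape:
            raise ValueError(
                f"Shape mismatch: {cause} has shape {expected.shape}, "
                f"network output has shape {got.shape}"
            )
        return expected - got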
@@ -243,33 +266,38 @@ def initial_condition_apply(
     params: Params | ParamsDict,
     vmap_axes: tuple[int | None, ...],
     initial_condition_fun: Callable,
-    n: int,
     loss_weight: float | Float[Array, "initial_condition_dimension"],
 ) -> float:
-    …
+    n = omega_batch.shape[0]
+    t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
+    if isinstance(u, (PINN, HyperPINN)):
         v_u_t0 = vmap(
-            lambda …
+            lambda t0_x, params: _subtract_with_check(
+                initial_condition_fun(t0_x[1:]),
+                u(t0_x, params),
+                cause="Output of initial_condition_fun",
+            ),
             vmap_axes,
             0,
         )
-        res = v_u_t0(…
+        res = v_u_t0(t0_omega_batch, params)  # NOTE take the tiled
         # omega_batch (ie omega_batch_) to have the same batch
         # dimension as params to be able to vmap.
         # Recall that by convention:
         # param_batch_dict = times_batch_size * omega_batch_size
         mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
     elif isinstance(u, SPINN):
-        values = lambda …
-        …
-            x,
+        values = lambda t_x: u(
+            t_x,
             params,
         )[0]
-        omega_batch_grid = …
-        v_ini = values(…
-        …
-            initial_condition_fun(omega_batch_grid),
+        omega_batch_grid = get_grid(omega_batch)
+        v_ini = values(t0_omega_batch)
+        res = _subtract_with_check(
+            initial_condition_fun(omega_batch_grid),
+            v_ini,
+            cause="Output of initial_condition_fun",
         )
-        res = ini - v_ini
         mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
     else:
         raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
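initial_condition_apply no longer takes n as an argument; it derives it from omega_batch and builds the t = 0 inputs by prepending a zero time column. A toy version of that construction and of the residual it feeds into the MSE; the network u below is a placeholder, not a jinns PINN:

    import jax.numpy as jnp
    from jax import vmap

    omega_batch = jnp.linspace(0.0, 1.0, 5).reshape(5, 1)     # spatial points, (n, dim)
    n = omega_batch.shape[0]
    t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)  # (n, 1 + dim)

    initial_condition_fun = lambda x: jnp.sin(jnp.pi * x)     # target u(0, x)
    u = lambda t_x, params: params * t_x[1:]                  # placeholder network

    params = jnp.array(0.5)
    res = vmap(lambda t_x, p: initial_condition_fun(t_x[1:]) - u(t_x, p), (0, None))(
        t0_omega_batch, params
    )
    mse_initial_condition = jnp.mean(jnp.sum(res**2, axis=-1))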