jinns 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_DataGeneratorObservations.py +13 -2
- jinns/loss/_LossODE.py +39 -12
- jinns/loss/_LossPDE.py +67 -31
- jinns/loss/_abstract_loss.py +33 -8
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_solve.py +98 -366
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +503 -0
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/METADATA +16 -14
- {jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/RECORD +19 -18
- {jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/WHEEL +0 -0
- {jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/top_level.txt +0 -0
jinns/solver/_utils.py
CHANGED
@@ -1,7 +1,510 @@
+"""
+Common functions for _solve.py and _solve_alternate.py
+"""
+
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Callable
+from functools import partial
+import jax
+from jax import jit
+import jax.numpy as jnp
+import equinox as eqx
+from jaxtyping import PyTree, Float, Array, PRNGKeyArray
+import optax
+
+from jinns.data._utils import append_param_batch, append_obs_batch
+from jinns.utils._utils import _check_nan_in_pytree
 from jinns.data._DataGeneratorODE import DataGeneratorODE
 from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
 from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
 from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+from jinns.parameters._params import Params
+from jinns.utils._containers import (
+    LossContainer,
+    StoredObjectContainer,
+)
+
+if TYPE_CHECKING:
+    from jinns.utils._types import AnyBatch, SolveCarry, SolveAlternateCarry
+    from jinns.loss._abstract_loss import AbstractLoss
+    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+
+def _init_stored_weights_terms(loss, n_iter):
+    return eqx.tree_at(
+        lambda pt: jax.tree.leaves(
+            pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+        ),
+        loss.loss_weights,
+        tuple(
+            jnp.zeros((n_iter))
+            for n in range(
+                len(
+                    jax.tree.leaves(
+                        loss.loss_weights,
+                        is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+                    )
+                )
+            )
+        ),
+    )
+
+
+def _init_stored_params(tracked_params, params, n_iter):
+    return jax.tree_util.tree_map(
+        lambda tracked_param, param: (
+            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
+            if tracked_param is not None
+            else None
+        ),
+        tracked_params,
+        params,
+        is_leaf=lambda x: x is None,  # None values in tracked_params will not
+        # be traversed. Thus the user can provide something like
+        # ```
+        # tracked_params = jinns.parameters.Params(
+        #     nn_params=None,
+        #     eq_params={"nu": True})
+        # ```
+        # even when init_params.nn_params is a complex data structure.
+    )
+
+
+@partial(jit, static_argnames=["optimizer", "params_mask", "with_loss_weight_update"])
+def _loss_evaluate_and_gradient_step(
+    i,
+    batch: AnyBatch,
+    loss: AbstractLoss,
+    params: Params[Array],
+    last_non_nan_params: Params[Array],
+    state: optax.OptState,
+    optimizer: optax.GradientTransformation,
+    loss_container: LossContainer,
+    key: PRNGKeyArray,
+    params_mask: Params[bool] | None = None,
+    opt_state_field_for_acceleration: str | None = None,
+    with_loss_weight_update: bool = True,
+):
+    """
+    The crux of our new approach is partitioning and recombining the
+    parameters and the optimization state according to `params_mask`.
+
+    params_mask:
+        A jinns.parameters.Params object with booleans as leaves, specifying
+        over which parameters optimization is enabled. This usually brings
+        important computational gains. Internally, it is used as the
+        filter_spec of an eqx.partition call on the parameters. Note that this
+        differs from (and complements) DerivativeKeys, as the latter allows
+        for more granularity by freezing some gradients with respect to
+        different loss terms, but does not subset the optimized parameters
+        globally.
+
+    NOTE: in this function body, we change naming conventions for concision:
+    * `state` refers to the general optimizer state
+    * `opt_state` refers to the masked optimizer state, i.e. the part that is
+      really involved in the parameter update as defined by `params_mask`.
+    * `non_opt_state` refers to the optimizer state for non-optimized
+      params.
+    """
+
+    (
+        opt_params,
+        opt_params_accel,
+        non_opt_params,
+        opt_state,
+        non_opt_state,
+    ) = _get_masked_optimization_stuff(
+        params, state, opt_state_field_for_acceleration, params_mask
+    )
+
+    # The following part is the equivalent of a
+    # > train_loss_value, grads = jax.value_and_grad(total_loss.evaluate)(params, ...)
+    # but it is decomposed over individual loss terms so that we can use them
+    # if needed for updating loss weights.
+    # Since the total loss is a weighted sum of individual loss terms, so
+    # are its total gradients.
+
+    # 1. Compute individual losses and individual gradients
+    loss_terms, grad_terms = loss.evaluate_by_terms(
+        opt_params_accel
+        if opt_state_field_for_acceleration is not None
+        else opt_params,
+        batch,
+        non_opt_params=non_opt_params,
+    )
+
+    if loss.update_weight_method is not None and with_loss_weight_update:
+        key, subkey = jax.random.split(key)  # type: ignore because key can
+        # still be None currently
+        # avoid computations of tree_at if no updates
+        loss = loss.update_weights(
+            i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
+        )
+
+    # 2. total grad
+    grads = loss.ponderate_and_sum_gradient(grad_terms)
+
+    # total loss
+    train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
+
+    opt_grads, _ = grads.partition(
+        params_mask
+    )  # because the update cannot be made otherwise
+
+    # Here, we only apply the gradient step of the Optax optimizer to the
+    # parameters specified by params_mask (no dummy state filled with zero
+    # entries: all other entries of the pytrees are None thanks to params_mask)
+    opt_params, opt_state = _gradient_step(
+        opt_grads,
+        optimizer,
+        opt_params,  # NOTE that we never give the accelerated
+        # params here, this would be a wrong procedure
+        opt_state,
+    )
+
+    params, state = _get_unmasked_optimization_stuff(
+        opt_params,
+        non_opt_params,
+        state,
+        opt_state,
+        non_opt_state,
+        params_mask,
+    )
+
+    # check if any of the parameters is NaN
+    last_non_nan_params = jax.lax.cond(
+        _check_nan_in_pytree(params),
+        lambda _: last_non_nan_params,
+        lambda _: params,
+        None,
+    )
+    return train_loss_value, params, last_non_nan_params, state, loss, loss_terms
+
+
+@partial(
+    jit,
+    static_argnames=["optimizer"],
+)
+def _gradient_step(
+    grads: Params[Array],
+    optimizer: optax.GradientTransformation,
+    params: Params[Array],
+    state: optax.OptState,
+) -> tuple[
+    Params[Array],
+    optax.OptState,
+]:
+    """
+    A plain gradient step that is compatible with the new masked update
+    machinery. The optimizer cannot be jit-ted, hence it is a static argument.
+    """
+
+    updates, state = optimizer.update(
+        grads,  # type: ignore
+        state,
+        params,  # type: ignore
+    )  # Also see optimizer.init for explanation of type ignore
+    params = optax.apply_updates(params, updates)  # type: ignore
+
+    return (
+        params,
+        state,
+    )
+
+
+@partial(jit, static_argnames=["params_mask"])
+def _get_masked_optimization_stuff(
+    params, state, state_field_for_acceleration, params_mask
+):
+    """
+    From the parameters `params` and the optimizer state `state`, we use the
+    parameter mask `params_mask` to retrieve the partitioned versions of those
+    two objects: `opt_params` for the parameters that are optimized and
+    `non_opt_params` for those that are not optimized. The same goes for
+    `state`.
+
+    The argument `state_field_for_acceleration` can correspond to a field
+    inside the `state` module. If it is not None, an `opt_params_accel` object
+    is created that differs from `opt_params`. See
+    `opt_state_field_for_acceleration` in the `jinns.solve` docstring for more
+    details.
+
+    The opposite of `eqx.partition`, i.e. `eqx.combine`, is performed in the
+    loss `evaluate_by_terms()` method for the computations and in
+    `_get_unmasked_optimization_stuff` to reconstruct the objects after the
+    gradient step.
+    """
+    opt_params, non_opt_params = params.partition(params_mask)
+    opt_state = jax.tree.map(
+        lambda l: l.partition(params_mask)[0] if isinstance(l, Params) else l,
+        state,
+        is_leaf=lambda x: isinstance(x, Params),
+    )
+    non_opt_state = jax.tree.map(
+        lambda l: l.partition(params_mask)[1] if isinstance(l, Params) else l,
+        state,
+        is_leaf=lambda x: isinstance(x, Params),
+    )
+
+    # NOTE to enable optimization procedures with acceleration
+    if state_field_for_acceleration is not None:
+        opt_params_accel = getattr(opt_state, state_field_for_acceleration)
+    else:
+        opt_params_accel = opt_params
+
+    return (
+        opt_params,
+        opt_params_accel,
+        non_opt_params,
+        opt_state,
+        non_opt_state,
+    )
+
+
+@partial(jit, static_argnames=["params_mask"])
+def _get_unmasked_optimization_stuff(
+    opt_params, non_opt_params, state, opt_state, non_opt_state, params_mask
+):
+    """
+    Reverse operations of `_get_masked_optimization_stuff`
+    """
+    # NOTE the combine which closes the partitioned chunk
+    if params_mask is not None:
+        params = eqx.combine(opt_params, non_opt_params)
+        state = jax.tree.map(
+            lambda a, b, c: eqx.combine(b, c) if isinstance(a, Params) else b,
+            # NOTE else b in order to take all non Params stuff from
+            # opt_state that may have been updated too
+            state,
+            opt_state,
+            non_opt_state,
+            is_leaf=lambda x: isinstance(x, Params),
+        )
+    else:
+        params = opt_params
+        state = opt_state
+
+    return params, state
+
+
+@partial(jit, static_argnames=["prefix"])
+def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
+    # note that if the following is not jitted in the main for loop, it is
+    # super slow
+    _ = jax.lax.cond(
+        i % print_loss_every == 0,
+        lambda _: jax.debug.print(
+            prefix + "Iteration {i}: loss value = {loss_val}",
+            i=i,
+            loss_val=loss_val,
+        ),
+        lambda _: None,
+        (None,),
+    )
+
+
+@jit
+def _store_loss_and_params(
+    i: int,
+    params: Params[Array],
+    stored_params: Params[Array | None],
+    loss_container: LossContainer,
+    train_loss_val: float,
+    loss_terms: PyTree[Array],
+    weight_terms: PyTree[Array],
+    tracked_params: Params,
+) -> tuple[StoredObjectContainer, LossContainer]:
+    stored_params = jax.tree_util.tree_map(
+        lambda stored_value, param, tracked_param: (
+            None
+            if stored_value is None
+            else jax.lax.cond(
+                tracked_param,
+                lambda ope: ope[0].at[i].set(ope[1]),
+                lambda ope: ope[0],
+                (stored_value, param),
+            )
+        ),
+        stored_params,
+        params,
+        tracked_params,
+        is_leaf=lambda x: x is None,
+    )
+    stored_loss_terms = jax.tree_util.tree_map(
+        lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
+        loss_container.stored_loss_terms,
+        loss_terms,
+    )
+
+    if loss_container.stored_weights_terms is not None:
+        stored_weights_terms = jax.tree_util.tree_map(
+            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
+            jax.tree.leaves(
+                loss_container.stored_weights_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+            jax.tree.leaves(
+                weight_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+        )
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss_container.stored_weights_terms,
+            stored_weights_terms,
+        )
+    else:
+        stored_weights_terms = None
+
+    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
+    loss_container = LossContainer(
+        stored_loss_terms, stored_weights_terms, train_loss_values
+    )
+    stored_objects = StoredObjectContainer(stored_params)
+    return stored_objects, loss_container
+
+
+def _get_break_fun(
+    n_iter: int,
+    verbose: bool,
+    conditions_str: tuple[str, ...] = (
+        "bool_max_iter",
+        "bool_nan_in_params",
+        "bool_early_stopping",
+    ),
+) -> Callable[[SolveCarry | SolveAlternateCarry], bool]:
+    """
+    Wrapper to get the break_fun with the appropriate `n_iter`.
+    The verbose argument is here to control printing (or not) when exiting
+    the optimisation loop. It can be convenient if jinns.solve is itself
+    called in a loop and the user wants to avoid std output.
+    """
+
+    @jit
+    def break_fun(carry: tuple):
+        """
+        Function to break from the main optimization loop when one of the
+        following conditions is met: maximum number of iterations reached, NaN
+        appearing in the parameters, or early stopping criterion triggered.
+        """
+
+        def stop_while_loop(msg):
+            """
+            Note that the message is wrapped in the jax.lax.cond because a
+            string is not a valid JAX type that can be fed into the operands
+            """
+            if verbose:
+                jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
+            return False
+
+        def continue_while_loop(_):
+            return True
+
+        i = carry[0]
+        optimization = carry[2]
+        optimization_extra = carry[3]
+
+        conditions_bool = ()
+        if "bool_max_iter" in conditions_str:
+            # Condition 1
+            bool_max_iter = jax.lax.cond(
+                i >= n_iter,
+                lambda _: stop_while_loop("max iteration is reached"),
+                continue_while_loop,
+                None,
+            )
+            conditions_bool += (bool_max_iter,)
+        if "bool_nan_in_params" in conditions_str:
+            # Condition 2
+            bool_nan_in_params = jax.lax.cond(
+                _check_nan_in_pytree(optimization.params),
+                lambda _: stop_while_loop(
+                    "NaN values in parameters (returning last non NaN values)"
+                ),
+                continue_while_loop,
+                None,
+            )
+            conditions_bool += (bool_nan_in_params,)
+        if "bool_early_stopping" in conditions_str:
+            # Condition 3
+            bool_early_stopping = jax.lax.cond(
+                optimization_extra.early_stopping,
+                lambda _: stop_while_loop("early stopping"),
+                continue_while_loop,
+                None,
+            )
+            conditions_bool += (bool_early_stopping,)
+
+        # stop when one of the conditions to continue is False
+        return jax.tree_util.tree_reduce(
+            lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
+            conditions_bool,
+        )
+
+    return break_fun
+
+
+def _build_get_batch(
+    obs_batch_sharding: jax.sharding.Sharding | None,
+) -> Callable[
+    [
+        AbstractDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | None,
+    ],
+    tuple[
+        AnyBatch,
+        AbstractDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | None,
+    ],
+]:
+    """
+    Return the get_batch function that will be used: either the jittable one,
+    or the non-jittable one with sharding using jax.device_put()
+    """
+
+    def get_batch_sharding(data, param_data, obs_data):
+        """
+        This function is used at each loop iteration but it cannot be jitted
+        because of device_put
+        """
+        data, batch = data.get_batch()
+        if param_data is not None:
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
+        if obs_data is not None:
+            # This is the part that motivated the transition from scan to for loop.
+            # Indeed, we need to transfer obs_batch from CPU to GPU when we have
+            # huge observations that cannot fit on GPU. Such a transfer cannot
+            # be jitted, i.e. cannot happen inside a scan loop
+            obs_data, obs_batch = obs_data.get_batch()
+            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
+            batch = append_obs_batch(batch, obs_batch)
+        return batch, data, param_data, obs_data
+
+    @jit
+    def get_batch(data, param_data, obs_data):
+        """
+        Original get_batch with no sharding
+        """
+        data, batch = data.get_batch()
+        if param_data is not None:
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
+        if obs_data is not None:
+            obs_data, obs_batch = obs_data.get_batch()
+            batch = append_obs_batch(batch, obs_batch)
+        return batch, data, param_data, obs_data
+
+    if obs_batch_sharding is not None:
+        return get_batch_sharding
+    return get_batch
 
 
 def _check_batch_size(other_data, main_data, attr_name):
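The masked update introduced in `_loss_evaluate_and_gradient_step` boils down to the `eqx.partition` / `eqx.combine` pattern with `params_mask` as the filter spec (jinns wraps the partition call in `Params.partition`). Below is a minimal, self-contained sketch of that pattern using plain dict PyTrees instead of `jinns.parameters.Params`; it is an illustration of the mechanism, not jinns' actual code.

```python
# Sketch only (assumption: plain dicts stand in for jinns.parameters.Params).
# Only the entries flagged True in the mask are seen by the optimizer; the
# remaining entries are recombined untouched after the update, mirroring
# _loss_evaluate_and_gradient_step above.
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

params = {"nn_params": jnp.ones((3,)), "eq_params": {"nu": jnp.array(0.1)}}
params_mask = {"nn_params": False, "eq_params": {"nu": True}}  # only optimize nu

opt_params, non_opt_params = eqx.partition(params, params_mask)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(opt_params)

def loss_fn(opt_p):
    full = eqx.combine(opt_p, non_opt_params)  # recombine before evaluating
    return jnp.sum((full["nn_params"] * full["eq_params"]["nu"]) ** 2)

loss_val, grads = jax.value_and_grad(loss_fn)(opt_params)
updates, opt_state = optimizer.update(grads, opt_state, opt_params)
opt_params = optax.apply_updates(opt_params, updates)
params = eqx.combine(opt_params, non_opt_params)  # back to the full PyTree
```

Only the `nu` leaf carries optimizer state and receives an update, which is where the computational gain mentioned in the docstring comes from.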
jinns/utils/_DictToModuleMeta.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Any
 import equinox as eqx
 
+from jinns.utils._ItemizableModule import ItemizableModule
+
 
 class DictToModuleMeta(type):
     """
@@ -42,7 +44,7 @@ class DictToModuleMeta(type):
         if self._class is None and class_name is not None:
             self._class = type(
                 class_name,
-                (
+                (ItemizableModule,),
                 {"__annotations__": {k: type(v) for k, v in d.items()}},
             )
             try:
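For context on the one-line change above: `DictToModuleMeta` builds a class at runtime with the three-argument form of `type()`, and the base of that dynamic class is now `ItemizableModule`. A hypothetical sketch of the same mechanism, with a plain `eqx.Module` base standing in for the internal `ItemizableModule`:

```python
# Sketch only: build an equinox Module dynamically from a dict, as the
# metaclass above does. The (eqx.Module,) base is an assumption standing in
# for jinns' (ItemizableModule,).
import equinox as eqx

d = {"nu": 0.5, "rho": 1.0}
DynamicModule = type(
    "DynamicModule",
    (eqx.Module,),
    {"__annotations__": {k: type(v) for k, v in d.items()}},
)
m = DynamicModule(**d)  # eqx.Module derives an __init__ from the annotations
print(m.nu, m.rho)  # 0.5 1.0
```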
jinns/utils/_containers.py
CHANGED
@@ -37,13 +37,17 @@ class OptimizationContainer(eqx.Module):
     params: Params
     last_non_nan_params: Params
     opt_state: OptState
+    # params_mask: Params = eqx.field(static=True)  # to make params_mask a
+    # hashable JAX type. See _gradient_step docstring
 
 
 class OptimizationExtraContainer(eqx.Module):
-    curr_seq: int
-    best_iter_id:
-
-
+    curr_seq: int | None
+    best_iter_id: (
+        int | None
+    )  # the best iteration number (the one that achieves best_val_criterion and best_val_params)
+    best_val_criterion: float | None  # the best validation criterion at early stopping
+    best_val_params: Params | None  # the best parameter values at early stopping
     early_stopping: Bool = False
 
 
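The new fields of `OptimizationExtraContainer` hold the early-stopping bookkeeping: the best iteration id, the best validation criterion, and the parameters that achieved it. Since `eqx.Module` instances are immutable, such a container is typically refreshed functionally; the following is a purely illustrative sketch (hypothetical stand-in class, not jinns' actual update logic):

```python
# Hypothetical sketch: update an eqx.Module container when a better validation
# criterion is observed, using eqx.tree_at to rebuild the immutable module.
import equinox as eqx
import jax.numpy as jnp

class ExtraSketch(eqx.Module):  # stand-in for OptimizationExtraContainer
    best_iter_id: int | None
    best_val_criterion: float | None
    early_stopping: bool = False

extra = ExtraSketch(best_iter_id=None, best_val_criterion=jnp.inf)
i, new_crit = 100, 0.12
if new_crit < extra.best_val_criterion:
    extra = eqx.tree_at(
        lambda e: (e.best_iter_id, e.best_val_criterion),
        extra,
        (i, new_crit),
        is_leaf=lambda x: x is None,  # allows replacing the None placeholder
    )
```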
jinns/utils/_types.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import (
 )  # https://docs.python.org/3/library/typing.html#constant
 
 from typing import TypeAlias, TYPE_CHECKING, Callable, TypeVar
-from jaxtyping import Float, Array
+from jaxtyping import Float, Array, PRNGKeyArray
 
 from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch, ObsBatchDict
 from jinns.loss._loss_weights import (
@@ -11,6 +11,11 @@ from jinns.loss._loss_weights import (
     LossWeightsPDEStatio,
     LossWeightsPDENonStatio,
 )
+from jinns.parameters._derivative_keys import (
+    DerivativeKeysODE,
+    DerivativeKeysPDENonStatio,
+    DerivativeKeysPDEStatio,
+)
 from jinns.loss._loss_components import (
     ODEComponents,
     PDEStatioComponents,
@@ -19,6 +24,9 @@ from jinns.loss._loss_components import (
 
 AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch | ObsBatchDict
 
+AnyDerivativeKeys: TypeAlias = (
+    DerivativeKeysODE | DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
+)
 AnyLossWeights: TypeAlias = (
     LossWeightsODE | LossWeightsPDEStatio | LossWeightsPDENonStatio
 )
@@ -30,6 +38,15 @@ AnyLossComponents: TypeAlias = (
 )
 
 if TYPE_CHECKING:
+    from jinns.utils._containers import (
+        DataGeneratorContainer,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        LossContainer,
+        StoredObjectContainer,
+    )
+    from jinns.validation._validation import AbstractValidationModule
+    from jinns.loss._abstract_loss import AbstractLoss
     from jinns.loss._LossODE import LossODE
     from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio
 
@@ -39,3 +56,27 @@ if TYPE_CHECKING:
     ]
 
     AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
+
+    SolveCarry: TypeAlias = tuple[
+        int,
+        AbstractLoss,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        DataGeneratorContainer,
+        AbstractValidationModule | None,
+        LossContainer,
+        StoredObjectContainer,
+        Float[Array, " n_iter"] | None,
+        PRNGKeyArray | None,
+    ]
+
+    SolveAlternateCarry: TypeAlias = tuple[
+        int,
+        AbstractLoss,
+        OptimizationContainer,
+        OptimizationExtraContainer,
+        DataGeneratorContainer,
+        LossContainer,
+        StoredObjectContainer,
+        PRNGKeyArray | None,
+    ]
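The `SolveCarry` and `SolveAlternateCarry` aliases document the positional layout of the while-loop carry used by the solvers; the indexing in `break_fun` in `_utils.py` (`carry[0]`, `carry[2]`, `carry[3]`) follows this order. A small illustrative helper (not part of jinns) makes the convention explicit:

```python
# Illustration only: unpack the head of a SolveCarry tuple following the
# positional layout declared above.
def _unpack_solve_carry_head(carry):
    i = carry[0]                   # current iteration number
    loss = carry[1]                # AbstractLoss
    optimization = carry[2]        # OptimizationContainer (params, last_non_nan_params, opt_state)
    optimization_extra = carry[3]  # OptimizationExtraContainer (early-stopping bookkeeping)
    return i, loss, optimization, optimization_extra
```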
{jinns-1.6.0.dist-info → jinns-1.7.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: jinns
-Version: 1.6.0
+Version: 1.7.0
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -8,23 +8,26 @@ License: Apache License 2.0
 Project-URL: Repository, https://gitlab.com/mia_jinns/jinns
 Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
 Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Development Status ::
+Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: AUTHORS
-Requires-Dist: numpy
-Requires-Dist: jax
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: equinox>0.11.3
-Requires-Dist: jax-tqdm
-Requires-Dist: diffrax
+Requires-Dist: numpy>=2.0.0
+Requires-Dist: jax>=0.8.1
+Requires-Dist: optax>=0.2.6
+Requires-Dist: equinox>=0.13.2
 Requires-Dist: matplotlib
+Requires-Dist: jaxtyping
 Provides-Extra: notebook
 Requires-Dist: jupyter; extra == "notebook"
 Requires-Dist: seaborn; extra == "notebook"
+Requires-Dist: pandas; extra == "notebook"
+Requires-Dist: pytest; extra == "notebook"
+Requires-Dist: pre-commit; extra == "notebook"
+Requires-Dist: pyright; extra == "notebook"
+Requires-Dist: diffrax; extra == "notebook"
 Dynamic: license-file
 
 jinns
@@ -32,12 +35,11 @@ jinns
 
 
 
-Physics Informed Neural Networks with JAX. **jinns** is
+Physics Informed Neural Networks with JAX. **jinns** is a Python package for physics-informed neural networks (PINNs) in the [JAX](https://jax.readthedocs.io/en/latest/) ecosystem. It provides an intuitive and flexible interface for
 
-
-
-
-It can also be used for forward problems and hybrid-modeling.
+* forward problems: learning a PDE solution.
+* inverse problems: learning the parameters of a PDE. **New in jinns v1.7.0:** `jinns.solve_alternate()` for fine-grained and efficient inverse problems.
+* meta-modeling: learning a family of PDEs indexed by their parameters.
 
 **jinns** specific points:
 