jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +116 -189
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +176 -513
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +22 -21
- jinns/loss/_loss_utils.py +98 -173
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -76
- jinns/nn/__init__.py +22 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +434 -0
- jinns/nn/_mlp.py +217 -0
- jinns/nn/_pinn.py +204 -0
- jinns/nn/_ppinn.py +239 -0
- jinns/{utils → nn}/_save_load.py +39 -53
- jinns/nn/_spinn.py +123 -0
- jinns/nn/_spinn_mlp.py +202 -0
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +38 -37
- jinns/solver/_rar.py +82 -65
- jinns/solver/_solve.py +111 -71
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -5
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -8,55 +8,71 @@ from __future__ import (
|
|
|
8
8
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
9
9
|
|
|
10
10
|
import time
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
11
|
+
from typing import TYPE_CHECKING, Any, TypeAlias, Callable
|
|
12
12
|
from functools import partial
|
|
13
13
|
import optax
|
|
14
14
|
import jax
|
|
15
15
|
from jax import jit
|
|
16
16
|
import jax.numpy as jnp
|
|
17
|
-
from jaxtyping import
|
|
17
|
+
from jaxtyping import Float, Array
|
|
18
18
|
from jinns.solver._rar import init_rar, trigger_rar
|
|
19
19
|
from jinns.utils._utils import _check_nan_in_pytree
|
|
20
20
|
from jinns.solver._utils import _check_batch_size
|
|
21
|
-
from jinns.utils._containers import
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
append_param_batch,
|
|
21
|
+
from jinns.utils._containers import (
|
|
22
|
+
DataGeneratorContainer,
|
|
23
|
+
OptimizationContainer,
|
|
24
|
+
OptimizationExtraContainer,
|
|
25
|
+
LossContainer,
|
|
26
|
+
StoredObjectContainer,
|
|
28
27
|
)
|
|
28
|
+
from jinns.data._utils import append_param_batch, append_obs_batch
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
|
-
from jinns.
|
|
31
|
+
from jinns.parameters._params import Params
|
|
32
|
+
from jinns.utils._types import AnyLoss, AnyBatch
|
|
33
|
+
from jinns.validation._validation import AbstractValidationModule
|
|
34
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
35
|
+
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
36
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
37
|
+
|
|
38
|
+
main_carry: TypeAlias = tuple[
|
|
39
|
+
int,
|
|
40
|
+
AnyLoss,
|
|
41
|
+
OptimizationContainer,
|
|
42
|
+
OptimizationExtraContainer,
|
|
43
|
+
DataGeneratorContainer,
|
|
44
|
+
AbstractValidationModule | None,
|
|
45
|
+
LossContainer,
|
|
46
|
+
StoredObjectContainer,
|
|
47
|
+
Float[Array, " n_iter"] | None,
|
|
48
|
+
]
|
|
32
49
|
|
|
33
50
|
|
|
34
51
|
def solve(
|
|
35
|
-
n_iter:
|
|
36
|
-
init_params:
|
|
37
|
-
data:
|
|
52
|
+
n_iter: int,
|
|
53
|
+
init_params: Params[Array],
|
|
54
|
+
data: AbstractDataGenerator,
|
|
38
55
|
loss: AnyLoss,
|
|
39
56
|
optimizer: optax.GradientTransformation,
|
|
40
|
-
print_loss_every:
|
|
41
|
-
opt_state:
|
|
42
|
-
tracked_params: Params |
|
|
57
|
+
print_loss_every: int = 1000,
|
|
58
|
+
opt_state: optax.OptState | None = None,
|
|
59
|
+
tracked_params: Params[Any | None] | None = None,
|
|
43
60
|
param_data: DataGeneratorParameter | None = None,
|
|
44
|
-
obs_data:
|
|
45
|
-
DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
|
|
46
|
-
) = None,
|
|
61
|
+
obs_data: DataGeneratorObservations | None = None,
|
|
47
62
|
validation: AbstractValidationModule | None = None,
|
|
48
63
|
obs_batch_sharding: jax.sharding.Sharding | None = None,
|
|
49
|
-
verbose:
|
|
64
|
+
verbose: bool = True,
|
|
65
|
+
ahead_of_time: bool = True,
|
|
50
66
|
) -> tuple[
|
|
51
|
-
Params
|
|
52
|
-
Float[Array, "n_iter"],
|
|
53
|
-
|
|
54
|
-
|
|
67
|
+
Params[Array],
|
|
68
|
+
Float[Array, " n_iter"],
|
|
69
|
+
dict[str, Float[Array, " n_iter"]],
|
|
70
|
+
AbstractDataGenerator,
|
|
55
71
|
AnyLoss,
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
Float[Array, "n_iter"],
|
|
59
|
-
|
|
72
|
+
optax.OptState,
|
|
73
|
+
Params[Array | None],
|
|
74
|
+
Float[Array, " n_iter"] | None,
|
|
75
|
+
Params[Array],
|
|
60
76
|
]:
|
|
61
77
|
"""
|
|
62
78
|
Performs the optimization process via stochastic gradient descent
|
|
@@ -93,8 +109,7 @@ def solve(
|
|
|
93
109
|
Default None. A DataGeneratorParameter object which can be used to
|
|
94
110
|
sample equation parameters.
|
|
95
111
|
obs_data
|
|
96
|
-
Default None. A DataGeneratorObservations
|
|
97
|
-
DataGeneratorObservationsMultiPINNs
|
|
112
|
+
Default None. A DataGeneratorObservations
|
|
98
113
|
object which can be used to sample minibatches of observations.
|
|
99
114
|
validation
|
|
100
115
|
Default None. Otherwise, a callable ``eqx.Module`` which implements a
|
|
@@ -118,6 +133,14 @@ def solve(
|
|
|
118
133
|
verbose
|
|
119
134
|
Default True. If False, no std output (loss or cause of
|
|
120
135
|
exiting the optimization loop) will be produced.
|
|
136
|
+
ahead_of_time
|
|
137
|
+
Default True. Separate the compilation of the main training loop from
|
|
138
|
+
the execution to get both timings. You might need to avoid this
|
|
139
|
+
behaviour if you need to perform JAX transforms over chunks of code
|
|
140
|
+
containing `jinns.solve()` since AOT-compiled functions cannot be JAX
|
|
141
|
+
transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
|
|
142
|
+
When False, jinns does not provide any timing information (which would
|
|
143
|
+
be nonsense in a JIT transformed `solve()` function).
|
|
121
144
|
|
|
122
145
|
Returns
|
|
123
146
|
-------
|
|
@@ -162,11 +185,21 @@ def solve(
|
|
|
162
185
|
_check_batch_size(obs_data, param_data, "n")
|
|
163
186
|
|
|
164
187
|
if opt_state is None:
|
|
165
|
-
opt_state = optimizer.init(init_params)
|
|
188
|
+
opt_state = optimizer.init(init_params) # type: ignore
|
|
189
|
+
# our Params are eqx.Module (dataclass + PyTree), PyTree is
|
|
190
|
+
# compatible with optax transform but not dataclass, this leads to a
|
|
191
|
+
# type hint error: we could prevent this by ensuring with the eqx.filter that
|
|
192
|
+
# we have only floating points optimizable params given to optax
|
|
193
|
+
# see https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
|
|
194
|
+
# opt_state = optimizer.init(
|
|
195
|
+
# eqx.filter(init_params, eqx.is_inexact_array)
|
|
196
|
+
# )
|
|
197
|
+
# but this seems like a hack and there is no better way
|
|
198
|
+
# https://github.com/google-deepmind/optax/issues/384
|
|
166
199
|
|
|
167
200
|
# RAR sampling init (outside scanned function to avoid dynamic slice error)
|
|
168
201
|
# If RAR is not used, the _rar_step_*() are just None and data is unchanged
|
|
169
|
-
data, _rar_step_true, _rar_step_false = init_rar(data)
|
|
202
|
+
data, _rar_step_true, _rar_step_false = init_rar(data) # type: ignore
|
|
170
203
|
|
|
171
204
|
# Seq2seq
|
|
172
205
|
curr_seq = 0
|
|
@@ -283,7 +316,7 @@ def solve(
|
|
|
283
316
|
if verbose:
|
|
284
317
|
_print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
|
|
285
318
|
|
|
286
|
-
if validation is not None:
|
|
319
|
+
if validation is not None and validation_crit_values is not None:
|
|
287
320
|
# there is a jax.lax.cond because we do not necessarily call the
|
|
288
321
|
# validation step every iteration
|
|
289
322
|
(
|
|
@@ -297,7 +330,7 @@ def solve(
|
|
|
297
330
|
lambda operands: (
|
|
298
331
|
operands[0],
|
|
299
332
|
False,
|
|
300
|
-
validation_crit_values[i - 1],
|
|
333
|
+
validation_crit_values[i - 1], # type: ignore don't know why it can still be None
|
|
301
334
|
False,
|
|
302
335
|
),
|
|
303
336
|
(
|
|
@@ -384,16 +417,21 @@ def solve(
|
|
|
384
417
|
def train_fun(carry):
|
|
385
418
|
return jax.lax.while_loop(break_fun, _one_iteration, carry)
|
|
386
419
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
420
|
+
if ahead_of_time:
|
|
421
|
+
start = time.time()
|
|
422
|
+
compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
|
|
423
|
+
end = time.time()
|
|
424
|
+
if verbose:
|
|
425
|
+
print("\nCompilation took\n", end - start, "\n")
|
|
391
426
|
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
427
|
+
start = time.time()
|
|
428
|
+
carry = compiled_train_fun(carry)
|
|
429
|
+
jax.block_until_ready(carry)
|
|
430
|
+
end = time.time()
|
|
431
|
+
if verbose:
|
|
432
|
+
print("\nTraining took\n", end - start, "\n")
|
|
433
|
+
else:
|
|
434
|
+
carry = train_fun(carry)
|
|
397
435
|
|
|
398
436
|
(
|
|
399
437
|
i,
|
|
@@ -417,7 +455,7 @@ def solve(
|
|
|
417
455
|
# get ready to return the parameters at last iteration...
|
|
418
456
|
# (by default arbitrary choice, this could be None)
|
|
419
457
|
validation_parameters = optimization.last_non_nan_params
|
|
420
|
-
if validation is not None:
|
|
458
|
+
if validation is not None and validation_crit_values is not None:
|
|
421
459
|
jax.debug.print(
|
|
422
460
|
"validation loss value = {validation_loss_val}",
|
|
423
461
|
validation_loss_val=validation_crit_values[i - 1],
|
|
@@ -452,24 +490,28 @@ def _gradient_step(
|
|
|
452
490
|
loss: AnyLoss,
|
|
453
491
|
optimizer: optax.GradientTransformation,
|
|
454
492
|
batch: AnyBatch,
|
|
455
|
-
params:
|
|
456
|
-
opt_state:
|
|
457
|
-
last_non_nan_params:
|
|
493
|
+
params: Params[Array],
|
|
494
|
+
opt_state: optax.OptState,
|
|
495
|
+
last_non_nan_params: Params[Array],
|
|
458
496
|
) -> tuple[
|
|
459
497
|
AnyLoss,
|
|
460
498
|
float,
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
499
|
+
dict[str, float],
|
|
500
|
+
Params[Array],
|
|
501
|
+
optax.OptState,
|
|
502
|
+
Params[Array],
|
|
465
503
|
]:
|
|
466
504
|
"""
|
|
467
505
|
optimizer cannot be jit-ted.
|
|
468
506
|
"""
|
|
469
507
|
value_grad_loss = jax.value_and_grad(loss, has_aux=True)
|
|
470
508
|
(loss_val, loss_terms), grads = value_grad_loss(params, batch)
|
|
471
|
-
updates, opt_state = optimizer.update(
|
|
472
|
-
|
|
509
|
+
updates, opt_state = optimizer.update(
|
|
510
|
+
grads,
|
|
511
|
+
opt_state,
|
|
512
|
+
params, # type: ignore
|
|
513
|
+
) # see optimizer.init for explanation
|
|
514
|
+
params = optax.apply_updates(params, updates) # type: ignore
|
|
473
515
|
|
|
474
516
|
# check if any of the parameters is NaN
|
|
475
517
|
last_non_nan_params = jax.lax.cond(
|
|
@@ -490,7 +532,7 @@ def _gradient_step(
|
|
|
490
532
|
|
|
491
533
|
|
|
492
534
|
@partial(jit, static_argnames=["prefix"])
|
|
493
|
-
def _print_fn(i:
|
|
535
|
+
def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
|
|
494
536
|
# note that if the following is not jitted in the main for loop, it is
|
|
495
537
|
# super slow
|
|
496
538
|
_ = jax.lax.cond(
|
|
@@ -507,17 +549,15 @@ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
|
|
|
507
549
|
|
|
508
550
|
@jit
|
|
509
551
|
def _store_loss_and_params(
|
|
510
|
-
i:
|
|
511
|
-
params:
|
|
512
|
-
stored_params:
|
|
513
|
-
stored_loss_terms:
|
|
514
|
-
train_loss_values: Float[Array, "n_iter"],
|
|
552
|
+
i: int,
|
|
553
|
+
params: Params[Array],
|
|
554
|
+
stored_params: Params[Array],
|
|
555
|
+
stored_loss_terms: dict[str, Float[Array, " n_iter"]],
|
|
556
|
+
train_loss_values: Float[Array, " n_iter"],
|
|
515
557
|
train_loss_val: float,
|
|
516
|
-
loss_terms:
|
|
517
|
-
tracked_params:
|
|
518
|
-
) -> tuple[
|
|
519
|
-
Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
|
|
520
|
-
]:
|
|
558
|
+
loss_terms: dict[str, float],
|
|
559
|
+
tracked_params: Params,
|
|
560
|
+
) -> tuple[Params, dict[str, Float[Array, " n_iter"]], Float[Array, " n_iter"]]:
|
|
521
561
|
stored_params = jax.tree_util.tree_map(
|
|
522
562
|
lambda stored_value, param, tracked_param: (
|
|
523
563
|
None
|
|
@@ -544,7 +584,7 @@ def _store_loss_and_params(
|
|
|
544
584
|
return (stored_params, stored_loss_terms, train_loss_values)
|
|
545
585
|
|
|
546
586
|
|
|
547
|
-
def _get_break_fun(n_iter:
|
|
587
|
+
def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
|
|
548
588
|
"""
|
|
549
589
|
Wrapper to get the break_fun with appropriate `n_iter`.
|
|
550
590
|
The verbose argument is here to control printing (or not) when exiting
|
|
@@ -585,7 +625,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
|
|
|
585
625
|
bool_nan_in_params = jax.lax.cond(
|
|
586
626
|
_check_nan_in_pytree(optimization.params),
|
|
587
627
|
lambda _: stop_while_loop(
|
|
588
|
-
"NaN values in parameters
|
|
628
|
+
"NaN values in parameters (returning last non NaN values)"
|
|
589
629
|
),
|
|
590
630
|
continue_while_loop,
|
|
591
631
|
None,
|
|
@@ -608,18 +648,18 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
|
|
|
608
648
|
|
|
609
649
|
|
|
610
650
|
def _get_get_batch(
|
|
611
|
-
obs_batch_sharding: jax.sharding.Sharding,
|
|
651
|
+
obs_batch_sharding: jax.sharding.Sharding | None,
|
|
612
652
|
) -> Callable[
|
|
613
653
|
[
|
|
614
|
-
|
|
654
|
+
AbstractDataGenerator,
|
|
615
655
|
DataGeneratorParameter | None,
|
|
616
|
-
DataGeneratorObservations |
|
|
656
|
+
DataGeneratorObservations | None,
|
|
617
657
|
],
|
|
618
658
|
tuple[
|
|
619
659
|
AnyBatch,
|
|
620
|
-
|
|
660
|
+
AbstractDataGenerator,
|
|
621
661
|
DataGeneratorParameter | None,
|
|
622
|
-
DataGeneratorObservations |
|
|
662
|
+
DataGeneratorObservations | None,
|
|
623
663
|
],
|
|
624
664
|
]:
|
|
625
665
|
"""
|
jinns/solver/_utils.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
|
1
|
-
from jinns.data.
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
DataGeneratorParameter,
|
|
6
|
-
)
|
|
1
|
+
from jinns.data._DataGeneratorODE import DataGeneratorODE
|
|
2
|
+
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
|
|
3
|
+
from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
|
|
4
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
def _check_batch_size(other_data, main_data, attr_name):
|
jinns/utils/__init__.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
|
1
|
-
from ._pinn import create_PINN, PINN
|
|
2
|
-
from ._ppinn import create_PPINN, PPINN
|
|
3
|
-
from ._spinn import create_SPINN, SPINN
|
|
4
|
-
from ._hyperpinn import create_HYPERPINN, HYPERPINN
|
|
5
|
-
from ._save_load import save_pinn, load_pinn
|
|
6
1
|
from ._utils import get_grid
|
|
2
|
+
|
|
3
|
+
__all__ = ["get_grid"]
|
jinns/utils/_containers.py
CHANGED
|
@@ -11,23 +11,26 @@ from jaxtyping import PyTree, Array, Float, Bool
|
|
|
11
11
|
from optax import OptState
|
|
12
12
|
import equinox as eqx
|
|
13
13
|
|
|
14
|
+
from jinns.parameters._params import Params
|
|
15
|
+
|
|
14
16
|
if TYPE_CHECKING:
|
|
15
|
-
from jinns.
|
|
17
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
18
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
19
|
+
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
20
|
+
from jinns.utils._types import AnyLoss
|
|
16
21
|
|
|
17
22
|
|
|
18
23
|
class DataGeneratorContainer(eqx.Module):
|
|
19
|
-
data:
|
|
24
|
+
data: AbstractDataGenerator
|
|
20
25
|
param_data: DataGeneratorParameter | None = None
|
|
21
|
-
obs_data: DataGeneratorObservations |
|
|
22
|
-
None
|
|
23
|
-
)
|
|
26
|
+
obs_data: DataGeneratorObservations | None = None
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
class ValidationContainer(eqx.Module):
|
|
27
30
|
loss: AnyLoss | None
|
|
28
31
|
data: DataGeneratorContainer
|
|
29
32
|
hyperparams: PyTree = None
|
|
30
|
-
loss_values: Float[Array, "n_iter"] | None = None
|
|
33
|
+
loss_values: Float[Array, " n_iter"] | None = None
|
|
31
34
|
|
|
32
35
|
|
|
33
36
|
class OptimizationContainer(eqx.Module):
|
|
@@ -45,9 +48,9 @@ class OptimizationExtraContainer(eqx.Module):
|
|
|
45
48
|
|
|
46
49
|
|
|
47
50
|
class LossContainer(eqx.Module):
|
|
48
|
-
stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
|
|
49
|
-
train_loss_values: Float[Array, "n_iter"]
|
|
51
|
+
stored_loss_terms: Dict[str, Float[Array, " n_iter"]]
|
|
52
|
+
train_loss_values: Float[Array, " n_iter"]
|
|
50
53
|
|
|
51
54
|
|
|
52
55
|
class StoredObjectContainer(eqx.Module):
|
|
53
|
-
stored_params:
|
|
56
|
+
stored_params: Params[Array | None]
|
jinns/utils/_types.py
CHANGED
|
@@ -1,65 +1,19 @@
|
|
|
1
|
-
# pragma: exclude file
|
|
2
1
|
from __future__ import (
|
|
3
2
|
annotations,
|
|
4
3
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
5
4
|
|
|
6
|
-
from typing import TypeAlias, TYPE_CHECKING,
|
|
7
|
-
from jaxtyping import
|
|
5
|
+
from typing import TypeAlias, TYPE_CHECKING, Callable
|
|
6
|
+
from jaxtyping import Float, Array
|
|
8
7
|
|
|
9
8
|
if TYPE_CHECKING:
|
|
10
|
-
from jinns.
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
SystemLossPDE,
|
|
14
|
-
)
|
|
9
|
+
from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
|
|
10
|
+
from jinns.loss._LossODE import LossODE
|
|
11
|
+
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio
|
|
15
12
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
CubicMeshPDEStatio,
|
|
21
|
-
CubicMeshPDENonStatio,
|
|
22
|
-
DataGeneratorObservations,
|
|
23
|
-
DataGeneratorParameter,
|
|
24
|
-
DataGeneratorObservationsMultiPINNs,
|
|
25
|
-
)
|
|
13
|
+
# Here we define types available for the whole package
|
|
14
|
+
BoundaryConditionFun: TypeAlias = Callable[
|
|
15
|
+
[Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
|
|
16
|
+
]
|
|
26
17
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
from jinns.utils._pinn import PINN
|
|
30
|
-
from jinns.utils._hyperpinn import HYPERPINN
|
|
31
|
-
from jinns.utils._spinn import SPINN
|
|
32
|
-
from jinns.utils._containers import *
|
|
33
|
-
from jinns.validation._validation import AbstractValidationModule
|
|
34
|
-
|
|
35
|
-
AnyLoss: TypeAlias = (
|
|
36
|
-
LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
AnyParams: TypeAlias = Params | ParamsDict
|
|
40
|
-
|
|
41
|
-
AnyDataGenerator: TypeAlias = (
|
|
42
|
-
DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
|
|
46
|
-
|
|
47
|
-
AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
|
|
48
|
-
rar_operands = NewType(
|
|
49
|
-
"rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
|
|
50
|
-
)
|
|
51
|
-
|
|
52
|
-
main_carry = NewType(
|
|
53
|
-
"main_carry",
|
|
54
|
-
tuple[
|
|
55
|
-
Int,
|
|
56
|
-
AnyLoss,
|
|
57
|
-
OptimizationContainer,
|
|
58
|
-
OptimizationExtraContainer,
|
|
59
|
-
DataGeneratorContainer,
|
|
60
|
-
AbstractValidationModule,
|
|
61
|
-
LossContainer,
|
|
62
|
-
StoredObjectContainer,
|
|
63
|
-
Float[Array, "n_iter"],
|
|
64
|
-
],
|
|
65
|
-
)
|
|
18
|
+
AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch
|
|
19
|
+
AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
|
jinns/utils/_utils.py
CHANGED
|
@@ -2,20 +2,13 @@
|
|
|
2
2
|
Implements various utility functions
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from math import prod
|
|
6
5
|
import warnings
|
|
7
6
|
import jax
|
|
8
7
|
import jax.numpy as jnp
|
|
9
|
-
from jaxtyping import PyTree, Array
|
|
8
|
+
from jaxtyping import PyTree, Array, Bool
|
|
10
9
|
|
|
11
|
-
from jinns.data._DataGenerators import (
|
|
12
|
-
DataGeneratorODE,
|
|
13
|
-
CubicMeshPDEStatio,
|
|
14
|
-
CubicMeshPDENonStatio,
|
|
15
|
-
)
|
|
16
10
|
|
|
17
|
-
|
|
18
|
-
def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
11
|
+
def _check_nan_in_pytree(pytree: PyTree) -> Bool[Array, " "]:
|
|
19
12
|
"""
|
|
20
13
|
Check if there is a NaN value anywhere in the pytree
|
|
21
14
|
|
|
@@ -55,7 +48,7 @@ def get_grid(in_array: Array) -> Array:
|
|
|
55
48
|
|
|
56
49
|
|
|
57
50
|
def _check_shape_and_type(
|
|
58
|
-
r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
|
|
51
|
+
r: Array | int | float, expected_shape: tuple, cause: str = "", binop: str = ""
|
|
59
52
|
) -> Array | float:
|
|
60
53
|
"""
|
|
61
54
|
Ensures float type and correct shapes for broadcasting when performing a
|
|
@@ -90,7 +83,7 @@ def _check_shape_and_type(
|
|
|
90
83
|
|
|
91
84
|
|
|
92
85
|
def _subtract_with_check(
|
|
93
|
-
a: Array | int, b: Array
|
|
86
|
+
a: Array | int | float, b: Array, cause: str = ""
|
|
94
87
|
) -> Array | float:
|
|
95
88
|
a = _check_shape_and_type(a, b.shape, cause=cause, binop="-")
|
|
96
89
|
return a - b
|
jinns/validation/__init__.py
CHANGED
jinns/validation/_validation.py
CHANGED
|
@@ -7,19 +7,23 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
import abc
|
|
10
|
-
from typing import TYPE_CHECKING
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
11
|
import equinox as eqx
|
|
12
12
|
import jax
|
|
13
13
|
import jax.numpy as jnp
|
|
14
|
-
from jaxtyping import Array
|
|
14
|
+
from jaxtyping import Array, Float
|
|
15
15
|
|
|
16
|
-
from jinns.data.
|
|
16
|
+
from jinns.data._utils import (
|
|
17
17
|
append_obs_batch,
|
|
18
18
|
append_param_batch,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
22
|
-
from jinns.
|
|
22
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
23
|
+
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
24
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
25
|
+
from jinns.parameters._params import Params
|
|
26
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
23
27
|
|
|
24
28
|
# Using eqx Module for the DataClass + Pytree inheritance
|
|
25
29
|
# Abstract class and abstract/final pattern is used
|
|
@@ -40,8 +44,8 @@ class AbstractValidationModule(eqx.Module):
|
|
|
40
44
|
|
|
41
45
|
@abc.abstractmethod
|
|
42
46
|
def __call__(
|
|
43
|
-
self, params: Params
|
|
44
|
-
) -> tuple[
|
|
47
|
+
self, params: Params[Array]
|
|
48
|
+
) -> tuple[AbstractValidationModule, bool, Array, Params[Array]]:
|
|
45
49
|
raise NotImplementedError
|
|
46
50
|
|
|
47
51
|
|
|
@@ -52,24 +56,20 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
52
56
|
for more complicated validation strategy.
|
|
53
57
|
"""
|
|
54
58
|
|
|
55
|
-
loss:
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
validation_param_data: Union[DataGeneratorParameter, None] = eqx.field(
|
|
59
|
+
loss: AbstractLoss = eqx.field(kw_only=True)
|
|
60
|
+
validation_data: AbstractDataGenerator = eqx.field(kw_only=True)
|
|
61
|
+
validation_param_data: DataGeneratorParameter = eqx.field(
|
|
62
|
+
kw_only=True, default=None
|
|
63
|
+
)
|
|
64
|
+
validation_obs_data: DataGeneratorObservations | None = eqx.field(
|
|
62
65
|
kw_only=True, default=None
|
|
63
66
|
)
|
|
64
|
-
validation_obs_data: Union[
|
|
65
|
-
DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
|
|
66
|
-
] = eqx.field(kw_only=True, default=None)
|
|
67
67
|
call_every: int = eqx.field(kw_only=True, default=250) # concrete typing
|
|
68
68
|
early_stopping: bool = eqx.field(
|
|
69
69
|
kw_only=True, default=True
|
|
70
70
|
) # globally control if early stopping happens
|
|
71
71
|
|
|
72
|
-
patience:
|
|
72
|
+
patience: int = eqx.field(kw_only=True, default=10)
|
|
73
73
|
best_val_loss: Array = eqx.field(
|
|
74
74
|
converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf), kw_only=True
|
|
75
75
|
)
|
|
@@ -79,10 +79,11 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
79
79
|
)
|
|
80
80
|
|
|
81
81
|
def __call__(
|
|
82
|
-
self, params:
|
|
83
|
-
) -> tuple[
|
|
82
|
+
self, params: Params[Array]
|
|
83
|
+
) -> tuple[ValidationLoss, bool, Float[Array, " "], Params[Array]]:
|
|
84
84
|
# do in-place mutation
|
|
85
85
|
|
|
86
|
+
# pylint / pyright complains below when using the self attributes see: https://github.com/patrick-kidger/equinox/issues/1013
|
|
86
87
|
validation_data, val_batch = self.validation_data.get_batch()
|
|
87
88
|
if self.validation_param_data is not None:
|
|
88
89
|
validation_param_data, param_batch = self.validation_param_data.get_batch()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.0
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -25,6 +25,7 @@ Requires-Dist: matplotlib
|
|
|
25
25
|
Provides-Extra: notebook
|
|
26
26
|
Requires-Dist: jupyter; extra == "notebook"
|
|
27
27
|
Requires-Dist: seaborn; extra == "notebook"
|
|
28
|
+
Dynamic: license-file
|
|
28
29
|
|
|
29
30
|
jinns
|
|
30
31
|
=====
|
|
@@ -99,7 +100,7 @@ Here are the contributors guidelines:
|
|
|
99
100
|
pip install -e .
|
|
100
101
|
```
|
|
101
102
|
|
|
102
|
-
3. Install pre-commit and run it.
|
|
103
|
+
3. Install pre-commit and run it. Our pre-commit hooks consist of `ruff format` and `ruff check`. You can install `ruff` simply with `pip install ruff`. We highly recommend checking the code's type hints with `pyright`, even though we currently have no rule concerning type checking in the pipeline.
|
|
103
104
|
|
|
104
105
|
```bash
|
|
105
106
|
pip install pre-commit
|
|
@@ -112,16 +113,16 @@ pre-commit install
|
|
|
112
113
|
|
|
113
114
|
Don't hesitate to contribute and get your name on the list here!
|
|
114
115
|
|
|
115
|
-
**List of contributors:** Hugo Gangloff, Nicolas Jouvin
|
|
116
|
+
**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh
|
|
116
117
|
|
|
117
118
|
# Cite us
|
|
118
119
|
|
|
119
|
-
Please consider citing our work if you found it useful to yours, using
|
|
120
|
+
Please consider citing our work if you found it useful for your own, using this [arXiv preprint](https://arxiv.org/abs/2412.14132):
|
|
120
121
|
```
|
|
121
|
-
@
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
122
|
+
@article{gangloff_jouvin2024jinns,
|
|
123
|
+
title={jinns: a JAX Library for Physics-Informed Neural Networks},
|
|
124
|
+
author={Gangloff, Hugo and Jouvin, Nicolas},
|
|
125
|
+
journal={arXiv preprint arXiv:2412.14132},
|
|
126
|
+
year={2024}
|
|
126
127
|
}
|
|
127
128
|
```
|