jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +534 -343
- jinns/loss/_DynamicLoss.py +152 -175
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +4 -4
- jinns/loss/_LossPDE.py +102 -74
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +150 -281
- jinns/loss/_loss_utils.py +95 -67
- jinns/loss/_operators.py +441 -186
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +47 -31
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +113 -100
- jinns/solver/_rar.py +104 -409
- jinns/solver/_solve.py +87 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +1 -4
- jinns/utils/_containers.py +3 -1
- jinns/utils/_types.py +5 -4
- jinns/utils/_utils.py +40 -12
- jinns-1.3.0.dist-info/METADATA +127 -0
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -410
- jinns/utils/_pinn.py +0 -334
- jinns/utils/_spinn.py +0 -268
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -7,6 +7,7 @@ from __future__ import (
|
|
|
7
7
|
annotations,
|
|
8
8
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
9
9
|
|
|
10
|
+
import time
|
|
10
11
|
from typing import TYPE_CHECKING, NamedTuple, Dict, Union
|
|
11
12
|
from functools import partial
|
|
12
13
|
import optax
|
|
@@ -16,6 +17,7 @@ import jax.numpy as jnp
|
|
|
16
17
|
from jaxtyping import Int, Bool, Float, Array
|
|
17
18
|
from jinns.solver._rar import init_rar, trigger_rar
|
|
18
19
|
from jinns.utils._utils import _check_nan_in_pytree
|
|
20
|
+
from jinns.solver._utils import _check_batch_size
|
|
19
21
|
from jinns.utils._containers import *
|
|
20
22
|
from jinns.data._DataGenerators import (
|
|
21
23
|
DataGeneratorODE,
|
|
@@ -29,31 +31,6 @@ if TYPE_CHECKING:
|
|
|
29
31
|
from jinns.utils._types import *
|
|
30
32
|
|
|
31
33
|
|
|
32
|
-
def _check_batch_size(other_data, main_data, attr_name):
|
|
33
|
-
if (
|
|
34
|
-
(
|
|
35
|
-
isinstance(main_data, DataGeneratorODE)
|
|
36
|
-
and getattr(other_data, attr_name) != main_data.temporal_batch_size
|
|
37
|
-
)
|
|
38
|
-
or (
|
|
39
|
-
isinstance(main_data, CubicMeshPDEStatio)
|
|
40
|
-
and not isinstance(main_data, CubicMeshPDENonStatio)
|
|
41
|
-
and getattr(other_data, attr_name) != main_data.omega_batch_size
|
|
42
|
-
)
|
|
43
|
-
or (
|
|
44
|
-
isinstance(main_data, CubicMeshPDENonStatio)
|
|
45
|
-
and getattr(other_data, attr_name)
|
|
46
|
-
!= main_data.omega_batch_size * main_data.temporal_batch_size
|
|
47
|
-
)
|
|
48
|
-
):
|
|
49
|
-
raise ValueError(
|
|
50
|
-
"Optional other_data.param_batch_size must be"
|
|
51
|
-
" equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
|
|
52
|
-
" the product of both dependeing on the type of the main"
|
|
53
|
-
" datagenerator"
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
|
|
57
34
|
def solve(
|
|
58
35
|
n_iter: Int,
|
|
59
36
|
init_params: AnyParams,
|
|
@@ -70,6 +47,7 @@ def solve(
|
|
|
70
47
|
validation: AbstractValidationModule | None = None,
|
|
71
48
|
obs_batch_sharding: jax.sharding.Sharding | None = None,
|
|
72
49
|
verbose: Bool = True,
|
|
50
|
+
ahead_of_time: Bool = True,
|
|
73
51
|
) -> tuple[
|
|
74
52
|
Params | ParamsDict,
|
|
75
53
|
Float[Array, "n_iter"],
|
|
@@ -141,6 +119,14 @@ def solve(
|
|
|
141
119
|
verbose
|
|
142
120
|
Default True. If False, no std output (loss or cause of
|
|
143
121
|
exiting the optimization loop) will be produced.
|
|
122
|
+
ahead_of_time
|
|
123
|
+
Default True. Separate the compilation of the main training loop from
|
|
124
|
+
the execution to get both timings. You might need to avoid this
|
|
125
|
+
behaviour if you need to perform JAX transforms over chunks of code
|
|
126
|
+
containing `jinns.solve()` since AOT-compiled functions cannot be JAX
|
|
127
|
+
transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
|
|
128
|
+
When False, jinns does not provide any timing information (which would
|
|
129
|
+
be nonsense in a JIT transformed `solve()` function).
|
|
144
130
|
|
|
145
131
|
Returns
|
|
146
132
|
-------
|
|
@@ -167,10 +153,22 @@ def solve(
|
|
|
167
153
|
The best parameters according to the validation criterion
|
|
168
154
|
"""
|
|
169
155
|
if param_data is not None:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
156
|
+
if param_data.param_batch_size is not None:
|
|
157
|
+
# We need to check that batch sizes will all be compliant for
|
|
158
|
+
# correct vectorization
|
|
159
|
+
_check_batch_size(param_data, data, "param_batch_size")
|
|
160
|
+
else:
|
|
161
|
+
# If DataGeneratorParameter does not have a batch size we will
|
|
162
|
+
# vectorization using `n`, and the same checks must be done
|
|
163
|
+
_check_batch_size(param_data, data, "n")
|
|
164
|
+
|
|
165
|
+
if obs_data is not None and param_data is not None:
|
|
166
|
+
# obs_data batch dimensions need only to be aligned with param_data
|
|
167
|
+
# batch dimensions if the latter exist
|
|
168
|
+
if obs_data.obs_batch_size is not None:
|
|
169
|
+
_check_batch_size(obs_data, param_data, "obs_batch_size")
|
|
170
|
+
else:
|
|
171
|
+
_check_batch_size(obs_data, param_data, "n")
|
|
174
172
|
|
|
175
173
|
if opt_state is None:
|
|
176
174
|
opt_state = optimizer.init(init_params)
|
|
@@ -224,6 +222,8 @@ def solve(
|
|
|
224
222
|
)
|
|
225
223
|
optimization_extra = OptimizationExtraContainer(
|
|
226
224
|
curr_seq=curr_seq,
|
|
225
|
+
best_iter_id=0,
|
|
226
|
+
best_val_criterion=jnp.nan,
|
|
227
227
|
best_val_params=init_params,
|
|
228
228
|
)
|
|
229
229
|
loss_container = LossContainer(
|
|
@@ -323,16 +323,26 @@ def solve(
|
|
|
323
323
|
validation_criterion
|
|
324
324
|
)
|
|
325
325
|
|
|
326
|
-
# update best_val_params w.r.t val_loss if needed
|
|
327
|
-
best_val_params = jax.lax.cond(
|
|
326
|
+
# update best_val_params and best_val_criterion w.r.t val_loss if needed
|
|
327
|
+
(best_val_params, best_val_criterion, best_iter_id) = jax.lax.cond(
|
|
328
328
|
update_best_params,
|
|
329
|
-
lambda
|
|
330
|
-
|
|
329
|
+
lambda operands: (
|
|
330
|
+
params,
|
|
331
|
+
validation_criterion,
|
|
332
|
+
i,
|
|
333
|
+
), # update with current value
|
|
334
|
+
lambda operands: (
|
|
335
|
+
operands[0].best_val_params,
|
|
336
|
+
operands[0].best_val_criterion,
|
|
337
|
+
operands[0].best_iter_id,
|
|
338
|
+
), # unchanged
|
|
331
339
|
(optimization_extra,),
|
|
332
340
|
)
|
|
333
341
|
else:
|
|
334
342
|
early_stopping = False
|
|
343
|
+
best_iter_id = 0
|
|
335
344
|
best_val_params = params
|
|
345
|
+
best_val_criterion = jnp.nan
|
|
336
346
|
|
|
337
347
|
# Trigger RAR
|
|
338
348
|
loss, params, data = trigger_rar(
|
|
@@ -358,7 +368,13 @@ def solve(
|
|
|
358
368
|
i,
|
|
359
369
|
loss,
|
|
360
370
|
OptimizationContainer(params, last_non_nan_params, opt_state),
|
|
361
|
-
OptimizationExtraContainer(
|
|
371
|
+
OptimizationExtraContainer(
|
|
372
|
+
curr_seq,
|
|
373
|
+
best_iter_id,
|
|
374
|
+
best_val_criterion,
|
|
375
|
+
best_val_params,
|
|
376
|
+
early_stopping,
|
|
377
|
+
),
|
|
362
378
|
DataGeneratorContainer(data, param_data, obs_data),
|
|
363
379
|
validation,
|
|
364
380
|
LossContainer(stored_loss_terms, train_loss_values),
|
|
@@ -373,7 +389,25 @@ def solve(
|
|
|
373
389
|
while break_fun(carry):
|
|
374
390
|
carry = _one_iteration(carry)
|
|
375
391
|
else:
|
|
376
|
-
|
|
392
|
+
|
|
393
|
+
def train_fun(carry):
|
|
394
|
+
return jax.lax.while_loop(break_fun, _one_iteration, carry)
|
|
395
|
+
|
|
396
|
+
if ahead_of_time:
|
|
397
|
+
start = time.time()
|
|
398
|
+
compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
|
|
399
|
+
end = time.time()
|
|
400
|
+
if verbose:
|
|
401
|
+
print("\nCompilation took\n", end - start, "\n")
|
|
402
|
+
|
|
403
|
+
start = time.time()
|
|
404
|
+
carry = compiled_train_fun(carry)
|
|
405
|
+
jax.block_until_ready(carry)
|
|
406
|
+
end = time.time()
|
|
407
|
+
if verbose:
|
|
408
|
+
print("\nTraining took\n", end - start, "\n")
|
|
409
|
+
else:
|
|
410
|
+
carry = train_fun(carry)
|
|
377
411
|
|
|
378
412
|
(
|
|
379
413
|
i,
|
|
@@ -389,15 +423,30 @@ def solve(
|
|
|
389
423
|
|
|
390
424
|
if verbose:
|
|
391
425
|
jax.debug.print(
|
|
392
|
-
"
|
|
426
|
+
"\nFinal iteration {i}: train loss value = {train_loss_val}",
|
|
393
427
|
i=i,
|
|
394
428
|
train_loss_val=loss_container.train_loss_values[i - 1],
|
|
395
429
|
)
|
|
430
|
+
|
|
431
|
+
# get ready to return the parameters at last iteration...
|
|
432
|
+
# (by default arbitrary choice, this could be None)
|
|
433
|
+
validation_parameters = optimization.last_non_nan_params
|
|
396
434
|
if validation is not None:
|
|
397
435
|
jax.debug.print(
|
|
398
436
|
"validation loss value = {validation_loss_val}",
|
|
399
437
|
validation_loss_val=validation_crit_values[i - 1],
|
|
400
438
|
)
|
|
439
|
+
if optimization_extra.early_stopping:
|
|
440
|
+
jax.debug.print(
|
|
441
|
+
"\n Returning a set of best parameters from early stopping"
|
|
442
|
+
" as last argument!\n"
|
|
443
|
+
" Best parameters from iteration {best_iter_id}"
|
|
444
|
+
" with validation loss criterion = {best_val_criterion}",
|
|
445
|
+
best_iter_id=optimization_extra.best_iter_id,
|
|
446
|
+
best_val_criterion=optimization_extra.best_val_criterion,
|
|
447
|
+
)
|
|
448
|
+
# ...but if early stopping, return the parameters at the best_iter_id
|
|
449
|
+
validation_parameters = optimization_extra.best_val_params
|
|
401
450
|
|
|
402
451
|
return (
|
|
403
452
|
optimization.last_non_nan_params,
|
|
@@ -408,7 +457,7 @@ def solve(
|
|
|
408
457
|
optimization.opt_state,
|
|
409
458
|
stored_objects.stored_params,
|
|
410
459
|
validation_crit_values if validation is not None else None,
|
|
411
|
-
|
|
460
|
+
validation_parameters,
|
|
412
461
|
)
|
|
413
462
|
|
|
414
463
|
|
|
@@ -531,7 +580,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
|
|
|
531
580
|
string is not a valid JAX type that can be fed into the operands
|
|
532
581
|
"""
|
|
533
582
|
if verbose:
|
|
534
|
-
jax.debug.print(f"
|
|
583
|
+
jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
|
|
535
584
|
return False
|
|
536
585
|
|
|
537
586
|
def continue_while_loop(_):
|
jinns/solver/_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from jinns.data._DataGenerators import (
|
|
2
|
+
DataGeneratorODE,
|
|
3
|
+
CubicMeshPDEStatio,
|
|
4
|
+
CubicMeshPDENonStatio,
|
|
5
|
+
DataGeneratorParameter,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _check_batch_size(other_data, main_data, attr_name):
    """
    Check that ``other_data`` will be vectorized consistently with ``main_data``.

    The size ``getattr(other_data, attr_name)`` is compared against the
    sampling sizes of ``main_data``. For every (batch size, number of points)
    pair of ``main_data``:

    - the batch size is used when it is not ``None``;
    - otherwise the number of points is used;
    - when both are ``None``, the pair is skipped.

    The pairs checked depend on the type of ``main_data``:

    - ``DataGeneratorODE``: (``temporal_batch_size``, ``nt``)
    - ``CubicMeshPDEStatio`` (excluding ``CubicMeshPDENonStatio``):
      (``omega_batch_size``, ``n``) and (``omega_border_batch_size``, ``nb``)
    - ``CubicMeshPDENonStatio``:
      - (``domain_batch_size``, ``n``);
      - (``border_batch_size``, ``nb // 2**dim``), where the fallback is only
        checked when ``dim > 1``;
      - (``initial_batch_size``, ``ni``).
    - ``DataGeneratorParameter``: (``param_batch_size``, ``n``)

    Parameters
    ----------
    other_data
        The data generator whose size is checked. In ``solve`` this is the
        parameter data generator or the observation data generator.
    main_data
        The data generator that ``other_data`` must be aligned with.
    attr_name
        Name of the size attribute read on ``other_data``, e.g.
        ``"param_batch_size"``, ``"obs_batch_size"`` or ``"n"``.

    Raises
    ------
    ValueError
        On the first size mismatch encountered.
    """

    def _mismatch(reference):
        # `reference` names the size of `main_data` that was not matched
        return ValueError(
            f"{other_data.__class__}.{attr_name} must be equal"
            f" to {reference} for correct vectorization"
        )

    def _check_pair(batch_attr, total_attr):
        # The batch size takes precedence; the total number of points is only
        # consulted when no batch size is set on `main_data`.
        for ref_attr in (batch_attr, total_attr):
            ref_size = getattr(main_data, ref_attr)
            if ref_size is None:
                continue
            if getattr(other_data, attr_name) != ref_size:
                raise _mismatch(f"{main_data.__class__}.{ref_attr}")
            return

    is_non_statio = isinstance(main_data, CubicMeshPDENonStatio)

    if isinstance(main_data, DataGeneratorODE):
        _check_pair("temporal_batch_size", "nt")

    if isinstance(main_data, CubicMeshPDEStatio) and not is_non_statio:
        _check_pair("omega_batch_size", "n")
        _check_pair("omega_border_batch_size", "nb")

    if is_non_statio:
        _check_pair("domain_batch_size", "n")

        border_batch_size = main_data.border_batch_size
        if border_batch_size is not None:
            if getattr(other_data, attr_name) != border_batch_size:
                raise _mismatch(f"{main_data.__class__}.border_batch_size")
        elif main_data.nb is not None and main_data.dim > 1:
            # Without a border batch size, the comparison is made against
            # nb // 2**dim rather than nb itself (and skipped when dim <= 1).
            if getattr(other_data, attr_name) != main_data.nb // 2**main_data.dim:
                raise _mismatch(
                    f"({main_data.__class__}.nb // 2**{main_data.__class__}.dim)"
                )

        _check_pair("initial_batch_size", "ni")

    if isinstance(main_data, DataGeneratorParameter):
        _check_pair("param_batch_size", "n")
|
jinns/utils/__init__.py
CHANGED
jinns/utils/_containers.py
CHANGED
|
@@ -38,7 +38,9 @@ class OptimizationContainer(eqx.Module):
|
|
|
38
38
|
|
|
39
39
|
class OptimizationExtraContainer(eqx.Module):
    """
    Extra optimization state carried through the main loop of ``solve``,
    next to the parameters and the optimizer state.
    """

    curr_seq: int
    best_iter_id: int  # the iteration number at which best_val_params and best_val_criterion were reached
    best_val_criterion: float  # the best validation criterion at early stopping
    best_val_params: Params  # the best parameter values at early stopping
    early_stopping: Bool = False  # when True, solve() returns best_val_params as its last output
|
|
43
45
|
|
|
44
46
|
|
jinns/utils/_types.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# pragma: exclude file
|
|
1
2
|
from __future__ import (
|
|
2
3
|
annotations,
|
|
3
4
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
@@ -25,9 +26,9 @@ if TYPE_CHECKING:
|
|
|
25
26
|
|
|
26
27
|
from jinns.loss import DynamicLoss
|
|
27
28
|
from jinns.data._Batchs import *
|
|
28
|
-
from jinns.
|
|
29
|
-
from jinns.
|
|
30
|
-
from jinns.
|
|
29
|
+
from jinns.nn._pinn import PINN
|
|
30
|
+
from jinns.nn._hyperpinn import HyperPINN
|
|
31
|
+
from jinns.nn._spinn_mlp import SPINN
|
|
31
32
|
from jinns.utils._containers import *
|
|
32
33
|
from jinns.validation._validation import AbstractValidationModule
|
|
33
34
|
|
|
@@ -41,7 +42,7 @@ if TYPE_CHECKING:
|
|
|
41
42
|
DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
|
|
42
43
|
)
|
|
43
44
|
|
|
44
|
-
AnyPINN: TypeAlias = PINN |
|
|
45
|
+
AnyPINN: TypeAlias = PINN | HyperPINN | SPINN
|
|
45
46
|
|
|
46
47
|
AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
|
|
47
48
|
rar_operands = NewType(
|
jinns/utils/_utils.py
CHANGED
|
@@ -2,13 +2,18 @@
|
|
|
2
2
|
Implements various utility functions
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
5
|
+
from math import prod
|
|
6
|
+
import warnings
|
|
8
7
|
import jax
|
|
9
8
|
import jax.numpy as jnp
|
|
10
9
|
from jaxtyping import PyTree, Array
|
|
11
10
|
|
|
11
|
+
from jinns.data._DataGenerators import (
|
|
12
|
+
DataGeneratorODE,
|
|
13
|
+
CubicMeshPDEStatio,
|
|
14
|
+
CubicMeshPDENonStatio,
|
|
15
|
+
)
|
|
16
|
+
|
|
12
17
|
|
|
13
18
|
def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
14
19
|
"""
|
|
@@ -33,7 +38,7 @@ def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
|
33
38
|
)
|
|
34
39
|
|
|
35
40
|
|
|
36
|
-
def
|
|
41
|
+
def get_grid(in_array: Array) -> Array:
|
|
37
42
|
"""
|
|
38
43
|
From an array of shape (B, D), D > 1, get the grid array, i.e., an array of
|
|
39
44
|
shape (B, B, ...(D times)..., B, D): along the last axis we have the array
|
|
@@ -49,10 +54,14 @@ def _get_grid(in_array: Array) -> Array:
|
|
|
49
54
|
return in_array
|
|
50
55
|
|
|
51
56
|
|
|
52
|
-
def
|
|
57
|
+
def _check_shape_and_type(
|
|
58
|
+
r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
|
|
59
|
+
) -> Array | float:
|
|
53
60
|
"""
|
|
54
|
-
|
|
55
|
-
|
|
61
|
+
Ensures float type and correct shapes for broadcasting when performing a
|
|
62
|
+
binary operation (like -, + or *) between two arrays.
|
|
63
|
+
First array is a custom user (observation data or output of initial/BC
|
|
64
|
+
functions), the expected shape is the same as the PINN's.
|
|
56
65
|
"""
|
|
57
66
|
if isinstance(r, (int, float)):
|
|
58
67
|
# if we have a scalar cast it to float
|
|
@@ -60,9 +69,28 @@ def _check_user_func_return(r: Array | int, shape: tuple) -> Array | int:
|
|
|
60
69
|
if r.shape == ():
|
|
61
70
|
# if we have a scalar inside a ndarray
|
|
62
71
|
return r.astype(float)
|
|
63
|
-
if r.shape[-1] ==
|
|
64
|
-
#
|
|
72
|
+
if r.shape[-1] == expected_shape[-1]:
|
|
73
|
+
# broadcasting will be OK
|
|
65
74
|
return r.astype(float)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
75
|
+
|
|
76
|
+
if r.shape != expected_shape:
|
|
77
|
+
# Usually, the reshape below adds a missing (1,) final axis to ensure # the PINN output and the other function (initial/boundary condition)
|
|
78
|
+
# have the correct shape, depending on how the user has coded the
|
|
79
|
+
# initial/boundary condition.
|
|
80
|
+
warnings.warn(
|
|
81
|
+
f"[{cause}] Performing operation `{binop}` between arrays"
|
|
82
|
+
f" of different shapes: got {r.shape} for the custom array and"
|
|
83
|
+
f" {expected_shape} for the PINN."
|
|
84
|
+
f" This can cause unexpected and wrong broadcasting."
|
|
85
|
+
f" Reshaping {r.shape} into {expected_shape}. Reshape your"
|
|
86
|
+
f" custom array to math the {expected_shape=} to prevent this"
|
|
87
|
+
f" warning."
|
|
88
|
+
)
|
|
89
|
+
return r.reshape(expected_shape)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _subtract_with_check(
    a: Array | int, b: Array | int, cause: str = ""
) -> Array | float:
    """
    Return ``a - b`` once ``a`` has been cast to float and shape-aligned on ``b``.

    ``a`` is the custom (user-provided) operand and ``b`` is the PINN output.
    Type and shape alignment is delegated to ``_check_shape_and_type``, which
    may warn and reshape ``a`` to ``b.shape``. ``cause`` is forwarded to label
    that warning.

    Note: ``b.shape`` is read, so in practice ``b`` must be an array.
    """
    aligned = _check_shape_and_type(a, expected_shape=b.shape, binop="-", cause=cause)
    return aligned - b
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: jinns
|
|
3
|
+
Version: 1.3.0
|
|
4
|
+
Summary: Physics Informed Neural Network with JAX
|
|
5
|
+
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
|
+
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
7
|
+
License: Apache License 2.0
|
|
8
|
+
Project-URL: Repository, https://gitlab.com/mia_jinns/jinns
|
|
9
|
+
Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Programming Language :: Python
|
|
13
|
+
Requires-Python: >=3.10
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
License-File: AUTHORS
|
|
17
|
+
Requires-Dist: numpy
|
|
18
|
+
Requires-Dist: jax
|
|
19
|
+
Requires-Dist: jaxopt
|
|
20
|
+
Requires-Dist: optax
|
|
21
|
+
Requires-Dist: equinox>0.11.3
|
|
22
|
+
Requires-Dist: jax-tqdm
|
|
23
|
+
Requires-Dist: diffrax
|
|
24
|
+
Requires-Dist: matplotlib
|
|
25
|
+
Provides-Extra: notebook
|
|
26
|
+
Requires-Dist: jupyter; extra == "notebook"
|
|
27
|
+
Requires-Dist: seaborn; extra == "notebook"
|
|
28
|
+
|
|
29
|
+
jinns
|
|
30
|
+
=====
|
|
31
|
+
|
|
32
|
+
 
|
|
33
|
+
|
|
34
|
+
Physics Informed Neural Networks with JAX. **jinns** is developed to estimate solutions of ODE and PDE problems using neural networks, with a strong focus on
|
|
35
|
+
|
|
36
|
+
1. inverse problems: find equation parameters given noisy/indirect observations
|
|
37
|
+
2. meta-modeling: solve for a parametric family of differential equations
|
|
38
|
+
|
|
39
|
+
It can also be used for forward problems and hybrid-modeling.
|
|
40
|
+
|
|
41
|
+
**jinns** specific points:
|
|
42
|
+
|
|
43
|
+
- **jinns uses JAX** - It is aimed at JAX users: forward and backward autodiff, vmapping, jitting and more! No reinventing the wheel: it relies on the JAX ecosystem whenever possible, such as [equinox](https://github.com/patrick-kidger/equinox/) for neural networks or [optax](https://optax.readthedocs.io/) for optimization.
|
|
44
|
+
|
|
45
|
+
- **jinns is highly modular** - It gives users maximum control for defining their problems, and extending the package. The maths and computations are visible and not hidden behind layers of code!
|
|
46
|
+
|
|
47
|
+
- **jinns is efficient** - It compares favorably to other existing Python packages for PINNs on the [PINNacle benchmarks](https://github.com/i207M/PINNacle/), as demonstrated in the table below. For more details on the benchmarks, check out the [PINN multi-library benchmark](https://gitlab.com/mia_jinns/pinn-multi-library-benchmark).
|
|
48
|
+
|
|
49
|
+
- Implemented PINN architectures
|
|
50
|
+
- Vanilla Multi-Layer Perceptron, popular across the PINN literature.
|
|
51
|
+
|
|
52
|
+
- [Separable PINNs](https://openreview.net/pdf?id=dEySGIcDnI): leverage forward-mode autodiff for computational speed.
|
|
53
|
+
|
|
54
|
+
- [Hyper PINNs](https://arxiv.org/pdf/2111.01008.pdf): useful for meta-modeling
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
- **Get started**: check out our various notebooks on the [documentation](https://mia_jinns.gitlab.io/jinns/index.html).
|
|
58
|
+
|
|
59
|
+
| | jinns | DeepXDE - JAX | DeepXDE - Pytorch | PINA | Nvidia Modulus |
|
|
60
|
+
|---|:---:|:---:|:---:|:---:|:---:|
|
|
61
|
+
| Burgers1D | **445** | 723 | 671 | 1977 | 646 |
|
|
62
|
+
| NS2d-C | **265** | 278 | 441 | 1600 | 275 |
|
|
63
|
+
| PInv | 149 | 218 | *CC* | 1509 | **135** |
|
|
64
|
+
| Diffusion-Reaction-Inv | **284** | *NI* | 3424 | 4061 | 2541 |
|
|
65
|
+
| Navier-Stokes-Inv | **175** | *NI* | 1511 | 1403 | 498 |
|
|
66
|
+
|
|
67
|
+
*Training time in seconds on an Nvidia T600 GPU. NI means the problem cannot be implemented in the backend; CC means the code crashed.*
|
|
68
|
+
|
|
69
|
+

|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Installation
|
|
73
|
+
|
|
74
|
+
Install the latest version with pip
|
|
75
|
+
|
|
76
|
+
```bash
|
|
77
|
+
pip install jinns
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
# Documentation
|
|
81
|
+
|
|
82
|
+
The project's documentation is hosted on GitLab Pages and available at [https://mia_jinns.gitlab.io/jinns/index.html](https://mia_jinns.gitlab.io/jinns/index.html).
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Found a bug / want a feature?
|
|
86
|
+
|
|
87
|
+
Open an issue on the [Gitlab repo](https://gitlab.com/mia_jinns/jinns/-/issues).
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Contributing
|
|
91
|
+
|
|
92
|
+
Here are the contributor guidelines:
|
|
93
|
+
|
|
94
|
+
1. First fork the library on Gitlab.
|
|
95
|
+
|
|
96
|
+
2. Then clone and install the library in development mode with
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
pip install -e .
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
3. Install pre-commit and run it.
|
|
103
|
+
|
|
104
|
+
```bash
|
|
105
|
+
pip install pre-commit
|
|
106
|
+
pre-commit install
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
4. Open a merge request once you are done with your changes, the review will be done via Gitlab.
|
|
110
|
+
|
|
111
|
+
# Contributors
|
|
112
|
+
|
|
113
|
+
Don't hesitate to contribute and get your name on the list here!
|
|
114
|
+
|
|
115
|
+
**List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh
|
|
116
|
+
|
|
117
|
+
# Cite us
|
|
118
|
+
|
|
119
|
+
Please consider citing our work if you find it useful in your own, using this [ArXiV preprint](https://arxiv.org/abs/2412.14132)
|
|
120
|
+
```
|
|
121
|
+
@article{gangloff_jouvin2024jinns,
|
|
122
|
+
title={jinns: a JAX Library for Physics-Informed Neural Networks},
|
|
123
|
+
author={Gangloff, Hugo and Jouvin, Nicolas},
|
|
124
|
+
journal={arXiv preprint arXiv:2412.14132},
|
|
125
|
+
year={2024}
|
|
126
|
+
}
|
|
127
|
+
```
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
|
|
2
|
+
jinns/data/_Batchs.py,sha256=oc7-N1wEbsEvbe9fjVFKG2OPoZJVEjzPm8uj_icACf4,817
|
|
3
|
+
jinns/data/_DataGenerators.py,sha256=3pyUqzQ12AUBqOV-yqpt4X6K_7CqTFtUKMjg-gJE6KA,65101
|
|
4
|
+
jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
|
|
5
|
+
jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
|
|
6
|
+
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
7
|
+
jinns/loss/_DynamicLoss.py,sha256=lUpFl37_TfwxSREpoVKqUOpQEVqD3hrFXqwP2GZReWw,25817
|
|
8
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=bqmPxyrcvZh_dL74DTpj-TGiFxchvG8qC6KhuGeyOoA,12006
|
|
9
|
+
jinns/loss/_LossODE.py,sha256=QhhSyJpDbcyW4TdShX0HkxbvJQWXvnYg8lik8_wyOg4,23415
|
|
10
|
+
jinns/loss/_LossPDE.py,sha256=DZPinl7KYV2vp_CdjnhaR9M_gE-WOvyi4s8VSDEgti0,51046
|
|
11
|
+
jinns/loss/__init__.py,sha256=PRiJV9fd2GSwaCBVCPyh6pFc6pdA40jfb_T1YvO8ERc,712
|
|
12
|
+
jinns/loss/_boundary_conditions.py,sha256=kxHwNFSMsNzFso6nvAewcAdzW50yTi7IX-5Pthe65XY,12271
|
|
13
|
+
jinns/loss/_loss_utils.py,sha256=IkZAWmBumNWwk3hzeO0dh5RjHKZpt_hL4XnG5-Gpfr8,14690
|
|
14
|
+
jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
|
|
15
|
+
jinns/loss/_operators.py,sha256=qaRxwqgnZzlE_zTyUvafZGnUH5EZY1lpgjT9Vb7QJAQ,21718
|
|
16
|
+
jinns/nn/__init__.py,sha256=k9guJSKmKlHEadAjU-0HlYXJe55Tt783QrkZz6EYyO8,231
|
|
17
|
+
jinns/nn/_hyperpinn.py,sha256=nH8c9DeiiAujprEd7CVKU1chWn-kcSAY-fYLzd8_ikY,18049
|
|
18
|
+
jinns/nn/_mlp.py,sha256=AbbFLF85ayJcQ6kVwfSNdAvjP69UWBP6Z3V-1De-pI4,8028
|
|
19
|
+
jinns/nn/_pinn.py,sha256=45lXgrZQHv-7PQ3EDWWIoo8FlXRnjL1nl7mALTSJ45o,8391
|
|
20
|
+
jinns/nn/_ppinn.py,sha256=vqIH_v1DF3LoHyl3pJ1qhfnGMRMfvbfNK6m9s5LC21k,9212
|
|
21
|
+
jinns/nn/_save_load.py,sha256=VaO9LtR6dajEfo8iP7FgOvyLdQxT2IawazC2sxs97lc,9139
|
|
22
|
+
jinns/nn/_spinn.py,sha256=QmKhDZ0-ToJk3_glQ9BQWgoC0d-EEAWxMrDeHfB2slw,4191
|
|
23
|
+
jinns/nn/_spinn_mlp.py,sha256=9iU_-TIUFMVBcYv0nQmsa07ZwApIKqnXm7v4CY87PTo,7224
|
|
24
|
+
jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
|
|
25
|
+
jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
|
|
26
|
+
jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
|
|
27
|
+
jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
|
|
28
|
+
jinns/plot/_plot.py,sha256=6OqCNvOeqbat3dViOtehILbRfGIS3pnTmNRfbZYaVTA,11433
|
|
29
|
+
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
+
jinns/solver/_rar.py,sha256=JU4FgWt5w3tzgn2mNyftGi8Erxn5N0Za60-lRaL2poI,9724
|
|
31
|
+
jinns/solver/_solve.py,sha256=Bh7uplfcInJEQj1wmMquisN_vvUghARgX_uaYf7NUpw,23423
|
|
32
|
+
jinns/solver/_utils.py,sha256=b2zYvwZY_fU0NMNWvUEMvHez9s7hwcxfpGzQlz5F6HA,5762
|
|
33
|
+
jinns/utils/__init__.py,sha256=uw3I-lWT3wLabo6-H8FbKpSXI2xobzSs2W-Xno280g0,29
|
|
34
|
+
jinns/utils/_containers.py,sha256=a7A-iUApnjc1YVc7bdt9tKUvHHPDOKMB9OfdrDZGWN8,1450
|
|
35
|
+
jinns/utils/_types.py,sha256=4Qgsg6r9UPGpRwmERv4Cx2nU5ZIweehDlZQPo-FuR4Y,1896
|
|
36
|
+
jinns/utils/_utils.py,sha256=hoRcJqcTuQi_Ip40oI4EbxW46E1rp2C01_HfuCpwKRM,2932
|
|
37
|
+
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
38
|
+
jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
|
|
39
|
+
jinns-1.3.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
40
|
+
jinns-1.3.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
41
|
+
jinns-1.3.0.dist-info/METADATA,sha256=PM3iLQFd-vHDU697ECGjD2vQpgxo1vo1GTFl5AdIWoo,4744
|
|
42
|
+
jinns-1.3.0.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
43
|
+
jinns-1.3.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
44
|
+
jinns-1.3.0.dist-info/RECORD,,
|