jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -7,6 +7,7 @@ from __future__ import (
|
|
|
7
7
|
annotations,
|
|
8
8
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
9
9
|
|
|
10
|
+
import time
|
|
10
11
|
from typing import TYPE_CHECKING, NamedTuple, Dict, Union
|
|
11
12
|
from functools import partial
|
|
12
13
|
import optax
|
|
@@ -16,6 +17,7 @@ import jax.numpy as jnp
|
|
|
16
17
|
from jaxtyping import Int, Bool, Float, Array
|
|
17
18
|
from jinns.solver._rar import init_rar, trigger_rar
|
|
18
19
|
from jinns.utils._utils import _check_nan_in_pytree
|
|
20
|
+
from jinns.solver._utils import _check_batch_size
|
|
19
21
|
from jinns.utils._containers import *
|
|
20
22
|
from jinns.data._DataGenerators import (
|
|
21
23
|
DataGeneratorODE,
|
|
@@ -29,31 +31,6 @@ if TYPE_CHECKING:
|
|
|
29
31
|
from jinns.utils._types import *
|
|
30
32
|
|
|
31
33
|
|
|
32
|
-
def _check_batch_size(other_data, main_data, attr_name):
|
|
33
|
-
if (
|
|
34
|
-
(
|
|
35
|
-
isinstance(main_data, DataGeneratorODE)
|
|
36
|
-
and getattr(other_data, attr_name) != main_data.temporal_batch_size
|
|
37
|
-
)
|
|
38
|
-
or (
|
|
39
|
-
isinstance(main_data, CubicMeshPDEStatio)
|
|
40
|
-
and not isinstance(main_data, CubicMeshPDENonStatio)
|
|
41
|
-
and getattr(other_data, attr_name) != main_data.omega_batch_size
|
|
42
|
-
)
|
|
43
|
-
or (
|
|
44
|
-
isinstance(main_data, CubicMeshPDENonStatio)
|
|
45
|
-
and getattr(other_data, attr_name)
|
|
46
|
-
!= main_data.omega_batch_size * main_data.temporal_batch_size
|
|
47
|
-
)
|
|
48
|
-
):
|
|
49
|
-
raise ValueError(
|
|
50
|
-
"Optional other_data.param_batch_size must be"
|
|
51
|
-
" equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
|
|
52
|
-
" the product of both dependeing on the type of the main"
|
|
53
|
-
" datagenerator"
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
|
|
57
34
|
def solve(
|
|
58
35
|
n_iter: Int,
|
|
59
36
|
init_params: AnyParams,
|
|
@@ -167,10 +144,22 @@ def solve(
|
|
|
167
144
|
The best parameters according to the validation criterion
|
|
168
145
|
"""
|
|
169
146
|
if param_data is not None:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
147
|
+
if param_data.param_batch_size is not None:
|
|
148
|
+
# We need to check that batch sizes will all be compliant for
|
|
149
|
+
# correct vectorization
|
|
150
|
+
_check_batch_size(param_data, data, "param_batch_size")
|
|
151
|
+
else:
|
|
152
|
+
# If DataGeneratorParameter does not have a batch size we will
|
|
153
|
+
# vectorization using `n`, and the same checks must be done
|
|
154
|
+
_check_batch_size(param_data, data, "n")
|
|
155
|
+
|
|
156
|
+
if obs_data is not None and param_data is not None:
|
|
157
|
+
# obs_data batch dimensions need only to be aligned with param_data
|
|
158
|
+
# batch dimensions if the latter exist
|
|
159
|
+
if obs_data.obs_batch_size is not None:
|
|
160
|
+
_check_batch_size(obs_data, param_data, "obs_batch_size")
|
|
161
|
+
else:
|
|
162
|
+
_check_batch_size(obs_data, param_data, "n")
|
|
174
163
|
|
|
175
164
|
if opt_state is None:
|
|
176
165
|
opt_state = optimizer.init(init_params)
|
|
@@ -224,6 +213,8 @@ def solve(
|
|
|
224
213
|
)
|
|
225
214
|
optimization_extra = OptimizationExtraContainer(
|
|
226
215
|
curr_seq=curr_seq,
|
|
216
|
+
best_iter_id=0,
|
|
217
|
+
best_val_criterion=jnp.nan,
|
|
227
218
|
best_val_params=init_params,
|
|
228
219
|
)
|
|
229
220
|
loss_container = LossContainer(
|
|
@@ -323,16 +314,26 @@ def solve(
|
|
|
323
314
|
validation_criterion
|
|
324
315
|
)
|
|
325
316
|
|
|
326
|
-
# update best_val_params w.r.t val_loss if needed
|
|
327
|
-
best_val_params = jax.lax.cond(
|
|
317
|
+
# update best_val_params and best_val_criterion w.r.t val_loss if needed
|
|
318
|
+
(best_val_params, best_val_criterion, best_iter_id) = jax.lax.cond(
|
|
328
319
|
update_best_params,
|
|
329
|
-
lambda
|
|
330
|
-
|
|
320
|
+
lambda operands: (
|
|
321
|
+
params,
|
|
322
|
+
validation_criterion,
|
|
323
|
+
i,
|
|
324
|
+
), # update with current value
|
|
325
|
+
lambda operands: (
|
|
326
|
+
operands[0].best_val_params,
|
|
327
|
+
operands[0].best_val_criterion,
|
|
328
|
+
operands[0].best_iter_id,
|
|
329
|
+
), # unchanged
|
|
331
330
|
(optimization_extra,),
|
|
332
331
|
)
|
|
333
332
|
else:
|
|
334
333
|
early_stopping = False
|
|
334
|
+
best_iter_id = 0
|
|
335
335
|
best_val_params = params
|
|
336
|
+
best_val_criterion = jnp.nan
|
|
336
337
|
|
|
337
338
|
# Trigger RAR
|
|
338
339
|
loss, params, data = trigger_rar(
|
|
@@ -358,7 +359,13 @@ def solve(
|
|
|
358
359
|
i,
|
|
359
360
|
loss,
|
|
360
361
|
OptimizationContainer(params, last_non_nan_params, opt_state),
|
|
361
|
-
OptimizationExtraContainer(
|
|
362
|
+
OptimizationExtraContainer(
|
|
363
|
+
curr_seq,
|
|
364
|
+
best_iter_id,
|
|
365
|
+
best_val_criterion,
|
|
366
|
+
best_val_params,
|
|
367
|
+
early_stopping,
|
|
368
|
+
),
|
|
362
369
|
DataGeneratorContainer(data, param_data, obs_data),
|
|
363
370
|
validation,
|
|
364
371
|
LossContainer(stored_loss_terms, train_loss_values),
|
|
@@ -373,7 +380,20 @@ def solve(
|
|
|
373
380
|
while break_fun(carry):
|
|
374
381
|
carry = _one_iteration(carry)
|
|
375
382
|
else:
|
|
376
|
-
|
|
383
|
+
|
|
384
|
+
def train_fun(carry):
|
|
385
|
+
return jax.lax.while_loop(break_fun, _one_iteration, carry)
|
|
386
|
+
|
|
387
|
+
start = time.time()
|
|
388
|
+
compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
|
|
389
|
+
end = time.time()
|
|
390
|
+
print("\nCompilation took\n", end - start, "\n")
|
|
391
|
+
|
|
392
|
+
start = time.time()
|
|
393
|
+
carry = compiled_train_fun(carry)
|
|
394
|
+
jax.block_until_ready(carry)
|
|
395
|
+
end = time.time()
|
|
396
|
+
print("\nTraining took\n", end - start, "\n")
|
|
377
397
|
|
|
378
398
|
(
|
|
379
399
|
i,
|
|
@@ -389,15 +409,30 @@ def solve(
|
|
|
389
409
|
|
|
390
410
|
if verbose:
|
|
391
411
|
jax.debug.print(
|
|
392
|
-
"
|
|
412
|
+
"\nFinal iteration {i}: train loss value = {train_loss_val}",
|
|
393
413
|
i=i,
|
|
394
414
|
train_loss_val=loss_container.train_loss_values[i - 1],
|
|
395
415
|
)
|
|
416
|
+
|
|
417
|
+
# get ready to return the parameters at last iteration...
|
|
418
|
+
# (by default arbitrary choice, this could be None)
|
|
419
|
+
validation_parameters = optimization.last_non_nan_params
|
|
396
420
|
if validation is not None:
|
|
397
421
|
jax.debug.print(
|
|
398
422
|
"validation loss value = {validation_loss_val}",
|
|
399
423
|
validation_loss_val=validation_crit_values[i - 1],
|
|
400
424
|
)
|
|
425
|
+
if optimization_extra.early_stopping:
|
|
426
|
+
jax.debug.print(
|
|
427
|
+
"\n Returning a set of best parameters from early stopping"
|
|
428
|
+
" as last argument!\n"
|
|
429
|
+
" Best parameters from iteration {best_iter_id}"
|
|
430
|
+
" with validation loss criterion = {best_val_criterion}",
|
|
431
|
+
best_iter_id=optimization_extra.best_iter_id,
|
|
432
|
+
best_val_criterion=optimization_extra.best_val_criterion,
|
|
433
|
+
)
|
|
434
|
+
# ...but if early stopping, return the parameters at the best_iter_id
|
|
435
|
+
validation_parameters = optimization_extra.best_val_params
|
|
401
436
|
|
|
402
437
|
return (
|
|
403
438
|
optimization.last_non_nan_params,
|
|
@@ -408,7 +443,7 @@ def solve(
|
|
|
408
443
|
optimization.opt_state,
|
|
409
444
|
stored_objects.stored_params,
|
|
410
445
|
validation_crit_values if validation is not None else None,
|
|
411
|
-
|
|
446
|
+
validation_parameters,
|
|
412
447
|
)
|
|
413
448
|
|
|
414
449
|
|
|
@@ -531,7 +566,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
|
|
|
531
566
|
string is not a valid JAX type that can be fed into the operands
|
|
532
567
|
"""
|
|
533
568
|
if verbose:
|
|
534
|
-
jax.debug.print(f"
|
|
569
|
+
jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
|
|
535
570
|
return False
|
|
536
571
|
|
|
537
572
|
def continue_while_loop(_):
|
jinns/solver/_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from jinns.data._DataGenerators import (
|
|
2
|
+
DataGeneratorODE,
|
|
3
|
+
CubicMeshPDEStatio,
|
|
4
|
+
CubicMeshPDENonStatio,
|
|
5
|
+
DataGeneratorParameter,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _check_batch_size(other_data, main_data, attr_name):
|
|
10
|
+
if isinstance(main_data, DataGeneratorODE):
|
|
11
|
+
if main_data.temporal_batch_size is not None:
|
|
12
|
+
if getattr(other_data, attr_name) != main_data.temporal_batch_size:
|
|
13
|
+
raise ValueError(
|
|
14
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
15
|
+
f" to {main_data.__class__}.temporal_batch_size for correct"
|
|
16
|
+
" vectorization"
|
|
17
|
+
)
|
|
18
|
+
else:
|
|
19
|
+
if main_data.nt is not None:
|
|
20
|
+
if getattr(other_data, attr_name) != main_data.nt:
|
|
21
|
+
raise ValueError(
|
|
22
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
23
|
+
f" to {main_data.__class__}.nt for correct"
|
|
24
|
+
" vectorization"
|
|
25
|
+
)
|
|
26
|
+
if isinstance(main_data, CubicMeshPDEStatio) and not isinstance(
|
|
27
|
+
main_data, CubicMeshPDENonStatio
|
|
28
|
+
):
|
|
29
|
+
if main_data.omega_batch_size is not None:
|
|
30
|
+
if getattr(other_data, attr_name) != main_data.omega_batch_size:
|
|
31
|
+
raise ValueError(
|
|
32
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
33
|
+
f" to {main_data.__class__}.omega_batch_size for correct"
|
|
34
|
+
" vectorization"
|
|
35
|
+
)
|
|
36
|
+
else:
|
|
37
|
+
if main_data.n is not None:
|
|
38
|
+
if getattr(other_data, attr_name) != main_data.n:
|
|
39
|
+
raise ValueError(
|
|
40
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
41
|
+
f" to {main_data.__class__}.n for correct"
|
|
42
|
+
" vectorization"
|
|
43
|
+
)
|
|
44
|
+
if main_data.omega_border_batch_size is not None:
|
|
45
|
+
if getattr(other_data, attr_name) != main_data.omega_border_batch_size:
|
|
46
|
+
raise ValueError(
|
|
47
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
48
|
+
f" to {main_data.__class__}.omega_border_batch_size for correct"
|
|
49
|
+
" vectorization"
|
|
50
|
+
)
|
|
51
|
+
else:
|
|
52
|
+
if main_data.nb is not None:
|
|
53
|
+
if getattr(other_data, attr_name) != main_data.nb:
|
|
54
|
+
raise ValueError(
|
|
55
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
56
|
+
f" to {main_data.__class__}.nb for correct"
|
|
57
|
+
" vectorization"
|
|
58
|
+
)
|
|
59
|
+
if isinstance(main_data, CubicMeshPDENonStatio):
|
|
60
|
+
if main_data.domain_batch_size is not None:
|
|
61
|
+
if getattr(other_data, attr_name) != main_data.domain_batch_size:
|
|
62
|
+
raise ValueError(
|
|
63
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
64
|
+
f" to {main_data.__class__}.domain_batch_size for correct"
|
|
65
|
+
" vectorization"
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
if main_data.n is not None:
|
|
69
|
+
if getattr(other_data, attr_name) != main_data.n:
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
72
|
+
f" to {main_data.__class__}.n for correct"
|
|
73
|
+
" vectorization"
|
|
74
|
+
)
|
|
75
|
+
if main_data.border_batch_size is not None:
|
|
76
|
+
if getattr(other_data, attr_name) != main_data.border_batch_size:
|
|
77
|
+
raise ValueError(
|
|
78
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
79
|
+
f" to {main_data.__class__}.border_batch_size for correct"
|
|
80
|
+
" vectorization"
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
if main_data.nb is not None:
|
|
84
|
+
if main_data.dim > 1 and getattr(other_data, attr_name) != (
|
|
85
|
+
main_data.nb // 2**main_data.dim
|
|
86
|
+
):
|
|
87
|
+
raise ValueError(
|
|
88
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
89
|
+
f" to ({main_data.__class__}.nb // 2**{main_data.__class__}.dim)"
|
|
90
|
+
" for correct vectorization"
|
|
91
|
+
)
|
|
92
|
+
if main_data.initial_batch_size is not None:
|
|
93
|
+
if getattr(other_data, attr_name) != main_data.initial_batch_size:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
96
|
+
f" to {main_data.__class__}.initial_batch_size for correct"
|
|
97
|
+
" vectorization"
|
|
98
|
+
)
|
|
99
|
+
else:
|
|
100
|
+
if main_data.ni is not None:
|
|
101
|
+
if getattr(other_data, attr_name) != main_data.ni:
|
|
102
|
+
raise ValueError(
|
|
103
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
104
|
+
f" to {main_data.__class__}.ni for correct"
|
|
105
|
+
" vectorization"
|
|
106
|
+
)
|
|
107
|
+
if isinstance(main_data, DataGeneratorParameter):
|
|
108
|
+
if main_data.param_batch_size is not None:
|
|
109
|
+
if getattr(other_data, attr_name) != main_data.param_batch_size:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
112
|
+
f" to {main_data.__class__}.param_batch_size for correct"
|
|
113
|
+
" vectorization"
|
|
114
|
+
)
|
|
115
|
+
else:
|
|
116
|
+
if main_data.n is not None:
|
|
117
|
+
if getattr(other_data, attr_name) != main_data.n:
|
|
118
|
+
raise ValueError(
|
|
119
|
+
f"{other_data.__class__}.{attr_name} must be equal"
|
|
120
|
+
f" to {main_data.__class__}.n for correct"
|
|
121
|
+
" vectorization"
|
|
122
|
+
)
|
jinns/utils/__init__.py
CHANGED
jinns/utils/_containers.py
CHANGED
|
@@ -38,7 +38,9 @@ class OptimizationContainer(eqx.Module):
|
|
|
38
38
|
|
|
39
39
|
class OptimizationExtraContainer(eqx.Module):
|
|
40
40
|
curr_seq: int
|
|
41
|
-
|
|
41
|
+
best_iter_id: int # the best iteration number (that which achieves best_val_params and best_val_params)
|
|
42
|
+
best_val_criterion: float # the best validation criterion at early stopping
|
|
43
|
+
best_val_params: Params # the best parameter values at early stopping
|
|
42
44
|
early_stopping: Bool = False
|
|
43
45
|
|
|
44
46
|
|
jinns/utils/_hyperpinn.py
CHANGED
|
@@ -16,7 +16,7 @@ import equinox as eqx
|
|
|
16
16
|
import numpy as onp
|
|
17
17
|
|
|
18
18
|
from jinns.utils._pinn import PINN, _MLP
|
|
19
|
-
from jinns.parameters._params import Params
|
|
19
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
def _get_param_nb(
|
|
@@ -114,6 +114,7 @@ class HYPERPINN(PINN):
|
|
|
114
114
|
)
|
|
115
115
|
self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.params)
|
|
116
116
|
|
|
117
|
+
@property
|
|
117
118
|
def init_params(self) -> Params:
|
|
118
119
|
"""
|
|
119
120
|
Returns an initial set of parameters
|
|
@@ -138,14 +139,20 @@ class HYPERPINN(PINN):
|
|
|
138
139
|
is_leaf=lambda x: isinstance(x, jnp.ndarray),
|
|
139
140
|
)
|
|
140
141
|
|
|
141
|
-
def
|
|
142
|
+
def __call__(
|
|
142
143
|
self,
|
|
143
144
|
inputs: Float[Array, "input_dim"],
|
|
144
|
-
params: Params | PyTree,
|
|
145
|
+
params: Params | ParamsDict | PyTree,
|
|
145
146
|
) -> Float[Array, "output_dim"]:
|
|
146
147
|
"""
|
|
147
|
-
Evaluate the
|
|
148
|
+
Evaluate the HyperPINN on some inputs with some params.
|
|
148
149
|
"""
|
|
150
|
+
if len(inputs.shape) == 0:
|
|
151
|
+
# This can happen often when the user directly provides some
|
|
152
|
+
# collocation points (eg for plotting, whithout using
|
|
153
|
+
# DataGenerators)
|
|
154
|
+
inputs = inputs[None]
|
|
155
|
+
|
|
149
156
|
try:
|
|
150
157
|
hyper = eqx.combine(params.nn_params, self.static_hyper)
|
|
151
158
|
except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
@@ -190,7 +197,7 @@ def create_HYPERPINN(
|
|
|
190
197
|
slice_solution: slice = None,
|
|
191
198
|
shared_pinn_outputs: slice = None,
|
|
192
199
|
eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
|
|
193
|
-
) -> HYPERPINN | list[HYPERPINN]:
|
|
200
|
+
) -> tuple[HYPERPINN | list[HYPERPINN], PyTree | list[PyTree]]:
|
|
194
201
|
r"""
|
|
195
202
|
Utility function to create a standard PINN neural network with the equinox
|
|
196
203
|
library.
|
|
@@ -274,6 +281,9 @@ def create_HYPERPINN(
|
|
|
274
281
|
A HYPERPINN instance or, when `shared_pinn_ouput` is not None,
|
|
275
282
|
a list of HYPERPINN instances with the same structure is returned,
|
|
276
283
|
only differing by there final slicing of the network output.
|
|
284
|
+
hyperpinn.init_params
|
|
285
|
+
The initial set of parameters for the HyperPINN or a list of the latter
|
|
286
|
+
when `shared_pinn_ouput` is not None.
|
|
277
287
|
|
|
278
288
|
|
|
279
289
|
Raises
|
|
@@ -389,7 +399,7 @@ def create_HYPERPINN(
|
|
|
389
399
|
output_slice=output_slice,
|
|
390
400
|
)
|
|
391
401
|
hyperpinns.append(hyperpinn)
|
|
392
|
-
return hyperpinns
|
|
402
|
+
return hyperpinns, [h.init_params for h in hyperpinns]
|
|
393
403
|
with warnings.catch_warnings():
|
|
394
404
|
# Catch the equinox warning because we put the number of
|
|
395
405
|
# parameters as static while being jnp.Array. This this time
|
|
@@ -407,4 +417,4 @@ def create_HYPERPINN(
|
|
|
407
417
|
hypernet_input_size=hypernet_input_size,
|
|
408
418
|
output_slice=None,
|
|
409
419
|
)
|
|
410
|
-
return hyperpinn
|
|
420
|
+
return hyperpinn, hyperpinn.init_params
|
jinns/utils/_pinn.py
CHANGED
|
@@ -10,7 +10,7 @@ import equinox as eqx
|
|
|
10
10
|
|
|
11
11
|
from jaxtyping import Array, Key, PyTree, Float
|
|
12
12
|
|
|
13
|
-
from jinns.parameters._params import Params
|
|
13
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class _MLP(eqx.Module):
|
|
@@ -128,40 +128,27 @@ class PINN(eqx.Module):
|
|
|
128
128
|
def __post_init__(self, mlp):
|
|
129
129
|
self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)
|
|
130
130
|
|
|
131
|
+
@property
|
|
131
132
|
def init_params(self) -> PyTree:
|
|
132
133
|
"""
|
|
133
134
|
Returns an initial set of parameters
|
|
134
135
|
"""
|
|
135
136
|
return self.params
|
|
136
137
|
|
|
137
|
-
def __call__(
|
|
138
|
-
"""
|
|
139
|
-
Calls `eval_nn` with rearranged arguments
|
|
140
|
-
"""
|
|
141
|
-
if self.eq_type == "ODE":
|
|
142
|
-
(t, params) = args
|
|
143
|
-
if len(t.shape) == 0:
|
|
144
|
-
t = t[..., None] # Add mandatory dimension which can be lacking
|
|
145
|
-
# (eg. for the ODE batches) but this dimension can already
|
|
146
|
-
# exists (eg. for user provided observation times)
|
|
147
|
-
return self.eval_nn(t, params)
|
|
148
|
-
if self.eq_type == "statio_PDE":
|
|
149
|
-
(x, params) = args
|
|
150
|
-
return self.eval_nn(x, params)
|
|
151
|
-
if self.eq_type == "nonstatio_PDE":
|
|
152
|
-
(t, x, params) = args
|
|
153
|
-
t_x = jnp.concatenate([t, x], axis=-1)
|
|
154
|
-
return self.eval_nn(t_x, params)
|
|
155
|
-
raise ValueError("Wrong value for self.eq_type")
|
|
156
|
-
|
|
157
|
-
def eval_nn(
|
|
138
|
+
def __call__(
|
|
158
139
|
self,
|
|
159
|
-
inputs: Float[Array, "
|
|
160
|
-
params: Params | PyTree,
|
|
140
|
+
inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
141
|
+
params: Params | ParamsDict | PyTree,
|
|
161
142
|
) -> Float[Array, "output_dim"]:
|
|
162
143
|
"""
|
|
163
144
|
Evaluate the PINN on some inputs with some params.
|
|
164
145
|
"""
|
|
146
|
+
if len(inputs.shape) == 0:
|
|
147
|
+
# This can happen often when the user directly provides some
|
|
148
|
+
# collocation points (eg for plotting, whithout using
|
|
149
|
+
# DataGenerators)
|
|
150
|
+
inputs = inputs[None]
|
|
151
|
+
|
|
165
152
|
try:
|
|
166
153
|
model = eqx.combine(params.nn_params, self.static)
|
|
167
154
|
except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
@@ -193,7 +180,7 @@ def create_PINN(
|
|
|
193
180
|
] = None,
|
|
194
181
|
shared_pinn_outputs: tuple[slice] = None,
|
|
195
182
|
slice_solution: slice = None,
|
|
196
|
-
) -> PINN | list[PINN]:
|
|
183
|
+
) -> tuple[PINN | list[PINN], PyTree | list[PyTree]]:
|
|
197
184
|
r"""
|
|
198
185
|
Utility function to create a standard PINN neural network with the equinox
|
|
199
186
|
library.
|
|
@@ -266,6 +253,9 @@ def create_PINN(
|
|
|
266
253
|
A PINN instance or, when `shared_pinn_ouput` is not None,
|
|
267
254
|
a list of PINN instances with the same structure is returned,
|
|
268
255
|
only differing by there final slicing of the network output.
|
|
256
|
+
pinn.init_params
|
|
257
|
+
An initial set of parameters for the PINN or a list of the latter
|
|
258
|
+
when `shared_pinn_ouput` is not None.
|
|
269
259
|
|
|
270
260
|
Raises
|
|
271
261
|
------
|
|
@@ -322,7 +312,7 @@ def create_PINN(
|
|
|
322
312
|
output_slice=output_slice,
|
|
323
313
|
)
|
|
324
314
|
pinns.append(pinn)
|
|
325
|
-
return pinns
|
|
315
|
+
return pinns, [p.init_params for p in pinns]
|
|
326
316
|
pinn = PINN(
|
|
327
317
|
mlp=mlp,
|
|
328
318
|
slice_solution=slice_solution,
|
|
@@ -331,4 +321,4 @@ def create_PINN(
|
|
|
331
321
|
output_transform=output_transform,
|
|
332
322
|
output_slice=None,
|
|
333
323
|
)
|
|
334
|
-
return pinn
|
|
324
|
+
return pinn, pinn.init_params
|