jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
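
The file list above shows parameter handling moving into a dedicated `jinns/parameters` package (`_params.py`, `_derivative_keys.py`), plotting moving from `jinns/data/_display.py` to `jinns/plot/_plot.py`, and the experimental seq2seq/optim modules being dropped. Below is a minimal Python sketch of building the `jinns.parameters.Params` container that the reworked `solve()` (diffed next) expects as `init_params`; the call form is taken from the new docstring, while the `nn_params` pytree and the `"nu"` entry are purely illustrative placeholders.

    import jax.numpy as jnp
    import jinns

    # Stand-in pytree for the network parameters; in practice this comes from
    # the PINN construction utilities (create_PINN and friends).
    nn_params = {"layer_0": {"w": jnp.zeros((8, 8)), "b": jnp.zeros((8,))}}

    init_params = jinns.parameters.Params(
        nn_params=nn_params,               # neural-network parameters
        eq_params={"nu": jnp.array(1.0)},  # differential-equation parameters ("nu" is illustrative)
    )
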
jinns/solver/_solve.py
CHANGED
@@ -3,26 +3,33 @@ This modules implements the main `solve()` function of jinns which
 handles the optimization process
 """
 
-import
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, NamedTuple, Dict, Union
 from functools import partial
 import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from
+from jaxtyping import Int, Bool, Float, Array
 from jinns.solver._rar import init_rar, trigger_rar
-from jinns.utils._utils import _check_nan_in_pytree
+from jinns.utils._utils import _check_nan_in_pytree
+from jinns.utils._containers import *
 from jinns.data._DataGenerators import (
     DataGeneratorODE,
     CubicMeshPDEStatio,
     CubicMeshPDENonStatio,
-    append_param_batch,
     append_obs_batch,
+    append_param_batch,
 )
-from jinns.utils._containers import *
 
+if TYPE_CHECKING:
+    from jinns.utils._types import *
 
-def check_batch_size(other_data, main_data, attr_name):
+
+def _check_batch_size(other_data, main_data, attr_name):
     if (
         (
             isinstance(main_data, DataGeneratorODE)
@@ -48,21 +55,32 @@ def check_batch_size(other_data, main_data, attr_name):
 
 
 def solve(
-    n_iter,
-    init_params,
-    data,
-    loss,
-    optimizer,
-    print_loss_every=1000,
-    opt_state=None,
-
-
-
-
-
-
-
-
+    n_iter: Int,
+    init_params: AnyParams,
+    data: AnyDataGenerator,
+    loss: AnyLoss,
+    optimizer: optax.GradientTransformation,
+    print_loss_every: Int = 1000,
+    opt_state: Union[NamedTuple, None] = None,
+    tracked_params: Params | ParamsDict | None = None,
+    param_data: DataGeneratorParameter | None = None,
+    obs_data: (
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
+    ) = None,
+    validation: AbstractValidationModule | None = None,
+    obs_batch_sharding: jax.sharding.Sharding | None = None,
+    verbose: Bool = True,
+) -> tuple[
+    Params | ParamsDict,
+    Float[Array, "n_iter"],
+    Dict[str, Float[Array, "n_iter"]],
+    AnyDataGenerator,
+    AnyLoss,
+    NamedTuple,
+    AnyParams,
+    Float[Array, "n_iter"],
+    AnyParams,
+]:
     """
     Performs the optimization process via stochastic gradient descent
     algorithm. We minimize the function defined `loss.evaluate()` with
@@ -73,52 +91,39 @@ def solve(
     Parameters
     ----------
     n_iter
-        The number of iterations in the optimization
+        The maximum number of iterations in the optimization.
     init_params
-        The initial
-        dictionaries: `eq_params` and `nn_params``, respectively the
-        differential equation parameters and the neural network parameter
+        The initial jinns.parameters.Params object.
     data
-        A DataGenerator object
-        method which returns a 3-tuple with (omega_grid, omega_border, time grid).
-        It must be jittable (e.g. implements via a pytree
-        registration)
+        A DataGenerator object to retrieve batches of collocation points.
     loss
-
-        object). It must be jittable (e.g. implements via a pytree
-        registration)
+        The loss function to minimize.
     optimizer
-        An
+        An optax optimizer.
     print_loss_every
-
+        Default 1000. The rate at which we print the loss value in the
         gradient step loop.
     opt_state
-
-
-
-
-
-        the
-
-
-        The seq2seq approach we reimplements is defined in
-        "Characterizing possible failure modes in physics-informed neural
-        networks", A. S. Krishnapriyan, NeurIPS 2021
-    tracked_params_key_list
-        Default None. Otherwise it is a list of list of strings
-        to access a leaf in params. Each selected leaf will be tracked
-        and stored at each iteration and returned by the solve function
+        Provide an optional initial state to the optimizer.
+    tracked_params
+        Default None. An eqx.Module of type Params with non-None values for
+        parameters that needs to be tracked along the iterations.
+        None values in tracked_params will not be traversed. Thus
+        the user can provide something like `tracked_params = jinns.parameters.Params(
+        nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        being a complex data structure.
     param_data
         Default None. A DataGeneratorParameter object which can be used to
         sample equation parameters.
     obs_data
-        Default None. A DataGeneratorObservations
-
+        Default None. A DataGeneratorObservations or
+        DataGeneratorObservationsMultiPINNs
+        object which can be used to sample minibatches of observations.
     validation
         Default None. Otherwise, a callable ``eqx.Module`` which implements a
-        validation strategy. See documentation of
+        validation strategy. See documentation of `jinns.validation.
         _validation.AbstractValidationModule` for the general interface, and
-
+        `jinns.validation._validation.ValidationLoss` for a practical
         implementation of a vanilla validation stategy on a validation set of
         collocation points.
 
@@ -132,15 +137,15 @@ def solve(
         Default None. An optional sharding object to constraint the obs_batch.
         Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
         created with sharding_device=SingleDeviceSharding(cpu_device) to avoid
-        loading on GPU huge datasets of observations
-    verbose
-
+        loading on GPU huge datasets of observations.
+    verbose
+        Default True. If False, no std output (loss or cause of
        exiting the optimization loop) will be produced.
 
     Returns
     -------
     params
-        The last non NaN value of the
+        The last non NaN value of the params at then end of the
         optimization process
     total_loss_values
         An array of the total loss term along the gradient steps
@@ -154,14 +159,18 @@ def solve(
     opt_state
         The final optimized state
     stored_params
-        A
-
+        A Params objects with the stored values of the desired parameters (as
+        signified in tracked_params argument)
+    validation_crit_values
+        An array containing the validation criterion values of the training
+    best_val_params
+        The best parameters according to the validation criterion
     """
     if param_data is not None:
-
+        _check_batch_size(param_data, data, "param_batch_size")
 
     if obs_data is not None:
-
+        _check_batch_size(obs_data, data, "obs_batch_size")
 
     if opt_state is None:
         opt_state = optimizer.init(init_params)
@@ -172,30 +181,32 @@ def solve(
 
     # Seq2seq
     curr_seq = 0
-    if seq2seq is not None:
-        assert (
-            data.method == "uniform"
-        ), "data.method must be uniform if using seq2seq learning !"
-        data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)
 
     train_loss_values = jnp.zeros((n_iter))
     # depending on obs_batch_sharding we will get the simple get_batch or the
     # get_batch with device_put, the latter is not jittable
-    get_batch =
+    get_batch = _get_get_batch(obs_batch_sharding)
 
     # initialize the dict for stored parameter values
     # we need to get a loss_term to init stuff
     batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
     _, loss_terms = loss(init_params, batch_ini)
-
-
-    tracked_params
+
+    # initialize parameter tracking
+    if tracked_params is None:
+        tracked_params = jax.tree.map(lambda p: None, init_params)
     stored_params = jax.tree_util.tree_map(
         lambda tracked_param, param: (
-            jnp.zeros((n_iter,) + param.shape)
+            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
+            if tracked_param is not None
+            else None
         ),
         tracked_params,
         init_params,
+        is_leaf=lambda x: x is None,  # None values in tracked_params will not
+        # be traversed. Thus the user can provide something like `tracked_params = jinns.parameters.Params(
+        # nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        # being a complex data structure
     )
 
     # initialize the dict for stored loss values
@@ -208,13 +219,12 @@ def solve(
     )
     optimization = OptimizationContainer(
         params=init_params,
-        last_non_nan_params=init_params
+        last_non_nan_params=init_params,
         opt_state=opt_state,
     )
     optimization_extra = OptimizationExtraContainer(
         curr_seq=curr_seq,
-
-        best_val_params=init_params.copy(),
+        best_val_params=init_params,
     )
     loss_container = LossContainer(
         stored_loss_terms=stored_loss_terms,
@@ -229,7 +239,7 @@ def solve(
     else:
         validation_crit_values = None
 
-    break_fun =
+    break_fun = _get_break_fun(n_iter, verbose)
 
     iteration = 0
     carry = (
@@ -244,7 +254,7 @@ def solve(
         validation_crit_values,
     )
 
-    def
+    def _one_iteration(carry: main_carry) -> main_carry:
         (
             i,
             loss,
@@ -269,7 +279,7 @@ def solve(
             params,
             opt_state,
             last_non_nan_params,
-        ) =
+        ) = _gradient_step(
             loss,
             optimizer,
             batch,
@@ -280,7 +290,7 @@ def solve(
 
         # Print train loss value during optimization
         if verbose:
-
+            _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
 
         if validation is not None:
             # there is a jax.lax.cond because we do not necesarily call the
@@ -306,7 +316,7 @@ def solve(
             )
             # Print validation loss value during optimization
             if verbose:
-
+                _print_fn(
                     i, validation_criterion, print_loss_every, prefix="[validation] "
                 )
             validation_crit_values = validation_crit_values.at[i].set(
@@ -329,19 +339,8 @@ def solve(
             i, loss, params, data, _rar_step_true, _rar_step_false
         )
 
-        # Trigger seq2seq
-        loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
-            i,
-            loss,
-            params,
-            data,
-            opt_state,
-            optimization_extra.curr_seq,
-            optimization_extra.seq2seq,
-        )
-
         # save loss value and selected parameters
-        stored_params, stored_loss_terms, train_loss_values =
+        stored_params, stored_loss_terms, train_loss_values = _store_loss_and_params(
             i,
             params,
             stored_objects.stored_params,
@@ -359,9 +358,7 @@ def solve(
             i,
             loss,
             OptimizationContainer(params, last_non_nan_params, opt_state),
-            OptimizationExtraContainer(
-                curr_seq, seq2seq, best_val_params, early_stopping
-            ),
+            OptimizationExtraContainer(curr_seq, best_val_params, early_stopping),
             DataGeneratorContainer(data, param_data, obs_data),
             validation,
             LossContainer(stored_loss_terms, train_loss_values),
@@ -374,9 +371,9 @@ def solve(
     # concern obs_batch, but it could lead to more complex scheme in the future
     if obs_batch_sharding is not None:
         while break_fun(carry):
-            carry =
+            carry = _one_iteration(carry)
     else:
-        carry = jax.lax.while_loop(break_fun,
+        carry = jax.lax.while_loop(break_fun, _one_iteration, carry)
 
     (
         i,
@@ -416,7 +413,21 @@ def solve(
 
 
 @partial(jit, static_argnames=["optimizer"])
-def
+def _gradient_step(
+    loss: AnyLoss,
+    optimizer: optax.GradientTransformation,
+    batch: AnyBatch,
+    params: AnyParams,
+    opt_state: NamedTuple,
+    last_non_nan_params: AnyParams,
+) -> tuple[
+    AnyLoss,
+    float,
+    Dict[str, float],
+    AnyParams,
+    NamedTuple,
+    AnyParams,
+]:
     """
     optimizer cannot be jit-ted.
     """
@@ -444,7 +455,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
 
 
 @partial(jit, static_argnames=["prefix"])
-def
+def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
     # note that if the following is not jitted in the main lor loop, it is
     # super slow
     _ = jax.lax.cond(
@@ -460,16 +471,18 @@ def print_fn(i, loss_val, print_loss_every, prefix=""):
 
 
 @jit
-def
-    i,
-    params,
-    stored_params,
-    stored_loss_terms,
-    train_loss_values,
-    train_loss_val,
-    loss_terms,
-    tracked_params,
-)
+def _store_loss_and_params(
+    i: Int,
+    params: AnyParams,
+    stored_params: AnyParams,
+    stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
+    train_loss_values: Float[Array, "n_iter"],
+    train_loss_val: float,
+    loss_terms: Dict[str, float],
+    tracked_params: AnyParams,
+) -> tuple[
+    Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
+]:
     stored_params = jax.tree_util.tree_map(
         lambda stored_value, param, tracked_param: (
             None
@@ -496,7 +509,7 @@ def store_loss_and_params(
     return (stored_params, stored_loss_terms, train_loss_values)
 
 
-def
+def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
     """
     Wrapper to get the break_fun with appropriate `n_iter`.
     The verbose argument is here to control printing (or not) when exiting
@@ -507,8 +520,8 @@ def get_break_fun(n_iter, verbose: str):
     @jit
     def break_fun(carry: tuple):
         """
-        Function to break from the main optimization loop
-
+        Function to break from the main optimization loop whe the following
+        conditions are met : maximum number of iterations, NaN
         appearing in the parameters, and early stopping criterion.
         """
 
@@ -559,43 +572,57 @@ def get_break_fun(n_iter, verbose: str):
     return break_fun
 
 
-def
+def _get_get_batch(
+    obs_batch_sharding: jax.sharding.Sharding,
+) -> Callable[
+    [
+        AnyDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+    ],
+    tuple[
+        AnyBatch,
+        AnyDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+    ],
+]:
     """
     Return the get_batch function that will be used either the jittable one or
-    the non-jittable one with sharding
+    the non-jittable one with sharding using jax.device.put()
     """
 
     def get_batch_sharding(data, param_data, obs_data):
         """
         This function is used at each loop but it cannot be jitted because of
         device_put
-
-        Note: return all that's modified or unwanted dirty undefined behaviour
         """
-        batch = data.get_batch()
+        data, batch = data.get_batch()
         if param_data is not None:
-
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
         if obs_data is not None:
             # This is the part that motivated the transition from scan to for loop
             # Indeed we need to be transit obs_batch from CPU to GPU when we have
             # huge observations that cannot fit on GPU. Such transfer wasn't meant
             # to be jitted, i.e. in a scan loop
-            obs_batch =
+            obs_data, obs_batch = obs_data.get_batch()
+            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
             batch = append_obs_batch(batch, obs_batch)
         return batch, data, param_data, obs_data
 
     @jit
     def get_batch(data, param_data, obs_data):
         """
-        Original get_batch with
-
-        Note: return all that's modified or unwanted dirty undefined behaviour
+        Original get_batch with no sharding
         """
-        batch = data.get_batch()
+        data, batch = data.get_batch()
         if param_data is not None:
-
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
        if obs_data is not None:
-
+            obs_data, obs_batch = obs_data.get_batch()
+            batch = append_obs_batch(batch, obs_batch)
        return batch, data, param_data, obs_data
 
    if obs_batch_sharding is not None:
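
The hunks above drop the seq2seq machinery and the old `tracked_params_key_list` argument, and give `solve()` a fully annotated signature with `tracked_params`, `validation`, `obs_batch_sharding` and `verbose` keywords plus a nine-element return tuple. The following is a hedged sketch of a 1.1.0-style call, assuming `init_params`, `train_data` and `loss` have been built elsewhere and that `solve` is reachable as `jinns.solve`; the unpacked names follow the annotated return type and the docstring order.

    import optax
    import jinns

    tx = optax.adam(1e-3)

    # Track only the "nu" equation parameter across iterations (None leaves are
    # skipped), mirroring the example given in the new docstring.
    tracked = jinns.parameters.Params(nn_params=None, eq_params={"nu": True})

    (
        params,                  # last non-NaN parameters
        total_loss_values,       # total loss per iteration
        stored_loss_terms,       # dict of per-term loss histories
        train_data,              # the data generator, also returned
        loss,                    # the loss object
        opt_state,               # final optimizer state
        stored_params,           # tracked parameter values per iteration
        validation_crit_values,  # validation criterion history (None without a validation module)
        best_val_params,         # best parameters w.r.t. the validation criterion
    ) = jinns.solve(
        n_iter=10_000,
        init_params=init_params,
        data=train_data,
        loss=loss,
        optimizer=tx,
        tracked_params=tracked,
        verbose=True,
    )
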
jinns/utils/__init__.py
CHANGED
@@ -1,10 +1,4 @@
-from .
-
-
-    log_euler_maruyama_density,
-)
-from ._pinn import create_PINN
-from ._spinn import create_SPINN
-from ._hyperpinn import create_HYPERPINN
-from ._optim import alternate_optimizer, delayed_optimizer
+from ._pinn import create_PINN, PINN
+from ._spinn import create_SPINN, SPINN
+from ._hyperpinn import create_HYPERPINN, HYPERPINN
 from ._save_load import save_pinn, load_pinn
jinns/utils/_containers.py
CHANGED
@@ -1,58 +1,51 @@
 """
-
+equinox Modules used as containers
 """
 
-from
-
-
-
-import
-from
-from
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-class ValidationContainer(NamedTuple):
-    loss: Union[
-        LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
-    ]
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Dict
+from jaxtyping import PyTree, Array, Float, Bool
+from optax import OptState
+import equinox as eqx
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *
+
+
+class DataGeneratorContainer(eqx.Module):
+    data: AnyDataGenerator
+    param_data: DataGeneratorParameter | None = None
+    obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
+        None
+    )
+
+
+class ValidationContainer(eqx.Module):
+    loss: AnyLoss | None
     data: DataGeneratorContainer
     hyperparams: PyTree = None
-    loss_values:
+    loss_values: Float[Array, "n_iter"] | None = None
 
 
-class OptimizationContainer(
-    params:
-    last_non_nan_params:
-    opt_state:
+class OptimizationContainer(eqx.Module):
+    params: Params
+    last_non_nan_params: Params
+    opt_state: OptState
 
 
-class OptimizationExtraContainer(
+class OptimizationExtraContainer(eqx.Module):
     curr_seq: int
-
-
-    early_stopping: bool = False
+    best_val_params: Params
+    early_stopping: Bool = False
 
 
-class LossContainer(
-    stored_loss_terms:
-    train_loss_values:
+class LossContainer(eqx.Module):
+    stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
+    train_loss_values: Float[Array, "n_iter"]
 
 
-class StoredObjectContainer(
-    stored_params:
+class StoredObjectContainer(eqx.Module):
+    stored_params: list | None