jinns 1.3.0-py3-none-any.whl → 1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +74 -69
- jinns/loss/_LossODE.py +132 -348
- jinns/loss/_LossPDE.py +262 -549
- jinns/loss/__init__.py +32 -6
- jinns/loss/_abstract_loss.py +128 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_components.py +43 -0
- jinns/loss/_loss_utils.py +85 -179
- jinns/loss/_loss_weight_updates.py +202 -0
- jinns/loss/_loss_weights.py +64 -40
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +207 -92
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +16 -10
- jinns/utils/_types.py +20 -54
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
- jinns-1.5.0.dist-info/RECORD +55 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
```diff
@@ -8,56 +8,76 @@ from __future__ import (
 ) # https://docs.python.org/3/library/typing.html#constant
 
 import time
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Any, TypeAlias, Callable
 from functools import partial
 import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jaxtyping import
+from jaxtyping import Float, Array, PyTree, Key
+import equinox as eqx
 from jinns.solver._rar import init_rar, trigger_rar
 from jinns.utils._utils import _check_nan_in_pytree
 from jinns.solver._utils import _check_batch_size
-from jinns.utils._containers import
-
-
-
-
-
-    append_param_batch,
+from jinns.utils._containers import (
+    DataGeneratorContainer,
+    OptimizationContainer,
+    OptimizationExtraContainer,
+    LossContainer,
+    StoredObjectContainer,
 )
+from jinns.data._utils import append_param_batch, append_obs_batch
 
 if TYPE_CHECKING:
-    from jinns.
+    from jinns.parameters._params import Params
+    from jinns.utils._types import AnyBatch
+    from jinns.loss._abstract_loss import AbstractLoss
+    from jinns.validation._validation import AbstractValidationModule
+    from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+main_carry: TypeAlias = tuple[
+    int,
+    AbstractLoss,
+    OptimizationContainer,
+    OptimizationExtraContainer,
+    DataGeneratorContainer,
+    AbstractValidationModule | None,
+    LossContainer,
+    StoredObjectContainer,
+    Float[Array, " n_iter"] | None,
+    Key | None,
+]
 
 
 def solve(
-    n_iter:
-    init_params:
-    data:
-    loss:
+    n_iter: int,
+    init_params: Params[Array],
+    data: AbstractDataGenerator,
+    loss: AbstractLoss,
     optimizer: optax.GradientTransformation,
-    print_loss_every:
-    opt_state:
-    tracked_params: Params |
+    print_loss_every: int = 1000,
+    opt_state: optax.OptState | None = None,
+    tracked_params: Params[Any | None] | None = None,
     param_data: DataGeneratorParameter | None = None,
-    obs_data:
-        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
-    ) = None,
+    obs_data: DataGeneratorObservations | None = None,
     validation: AbstractValidationModule | None = None,
     obs_batch_sharding: jax.sharding.Sharding | None = None,
-    verbose:
-    ahead_of_time:
+    verbose: bool = True,
+    ahead_of_time: bool = True,
+    key: Key = None,
 ) -> tuple[
-    Params
-    Float[Array, "n_iter"],
-
-
-
-
-
-
-
+    Params[Array],
+    Float[Array, " n_iter"],
+    PyTree,
+    AbstractDataGenerator,
+    AbstractLoss,
+    optax.OptState,
+    Params[Array | None],
+    PyTree,
+    Float[Array, " n_iter"] | None,
+    Params[Array],
 ]:
     """
     Performs the optimization process via stochastic gradient descent
@@ -94,8 +114,7 @@ def solve(
         Default None. A DataGeneratorParameter object which can be used to
         sample equation parameters.
     obs_data
-        Default None. A DataGeneratorObservations
-        DataGeneratorObservationsMultiPINNs
+        Default None. A DataGeneratorObservations
         object which can be used to sample minibatches of observations.
     validation
         Default None. Otherwise, a callable ``eqx.Module`` which implements a
@@ -127,6 +146,9 @@ def solve(
         transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
         When False, jinns does not provide any timing information (which would
         be nonsense in a JIT transformed `solve()` function).
+    key
+        Default None. A JAX random key that can be used for diverse purpose in
+        the main iteration loop.
 
     Returns
     -------
@@ -136,8 +158,8 @@ def solve(
     total_loss_values
         An array of the total loss term along the gradient steps
     stored_loss_terms
-        A
-        term
+        A PyTree with attributes being arrays of all the values for each loss
+        term
     data
         The input data object
     loss
@@ -147,11 +169,19 @@ def solve(
     stored_params
         A Params objects with the stored values of the desired parameters (as
         signified in tracked_params argument)
+    stored_weights_terms
+        A PyTree with attributes being arrays of all the values for each loss
+        weight. Note that if Loss.update_weight_method is None, we return None,
+        because loss weights are never updated and we can then save some
+        computations
     validation_crit_values
         An array containing the validation criterion values of the training
     best_val_params
         The best parameters according to the validation criterion
     """
+    if n_iter < 1:
+        raise ValueError("Cannot run jinns.solve for n_iter<1")
+
     if param_data is not None:
         if param_data.param_batch_size is not None:
             # We need to check that batch sizes will all be compliant for
@@ -171,11 +201,21 @@ def solve(
             _check_batch_size(obs_data, param_data, "n")
 
     if opt_state is None:
-        opt_state = optimizer.init(init_params)
+        opt_state = optimizer.init(init_params)  # type: ignore
+        # our Params are eqx.Module (dataclass + PyTree), PyTree is
+        # compatible with optax transform but not dataclass, this leads to a
+        # type hint error: we could prevent this by ensuring with the eqx.filter that
+        # we have only floating points optimizable params given to optax
+        # see https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
+        # opt_state = optimizer.init(
+        #    eqx.filter(init_params, eqx.is_inexact_array)
+        # )
+        # but this seems like a hack and there is no better way
+        # https://github.com/google-deepmind/optax/issues/384
 
     # RAR sampling init (ouside scanned function to avoid dynamic slice error)
     # If RAR is not used the _rar_step_*() are juste None and data is unchanged
-    data, _rar_step_true, _rar_step_false = init_rar(data)
+    data, _rar_step_true, _rar_step_false = init_rar(data)  # type: ignore
 
     # Seq2seq
     curr_seq = 0
@@ -207,11 +247,38 @@ def solve(
         # being a complex data structure
     )
 
-    # initialize the
+    # initialize the PyTree for stored loss values
     stored_loss_terms = jax.tree_util.tree_map(
         lambda _: jnp.zeros((n_iter)), loss_terms
     )
 
+    # initialize the PyTree for stored loss weights values
+    if loss.update_weight_method is not None:
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss.loss_weights,
+            tuple(
+                jnp.zeros((n_iter))
+                for n in range(
+                    len(
+                        jax.tree.leaves(
+                            loss.loss_weights,
+                            is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+                        )
+                    )
+                )
+            ),
+        )
+    else:
+        stored_weights_terms = None
+    if loss.update_weight_method is not None and key is None:
+        raise ValueError(
+            "`key` argument must be passed to jinns.solve when"
+            " `loss.update_weight_method` is not None"
+        )
+
     train_data = DataGeneratorContainer(
         data=data, param_data=param_data, obs_data=obs_data
     )
@@ -228,6 +295,7 @@ def solve(
     )
     loss_container = LossContainer(
         stored_loss_terms=stored_loss_terms,
+        stored_weights_terms=stored_weights_terms,
         train_loss_values=train_loss_values,
     )
     stored_objects = StoredObjectContainer(
@@ -252,6 +320,7 @@ def solve(
         loss_container,
         stored_objects,
         validation_crit_values,
+        key,
     )
 
     def _one_iteration(carry: main_carry) -> main_carry:
@@ -265,24 +334,47 @@ def solve(
             loss_container,
             stored_objects,
             validation_crit_values,
+            key,
         ) = carry
 
         batch, data, param_data, obs_data = get_batch(
             train_data.data, train_data.param_data, train_data.obs_data
         )
 
-        #
+        # ---------------------------------------------------------------------
+        # The following part is the equivalent of a
+        # > train_loss_value, grads = jax.values_and_grad(total_loss.evaluate)(params, ...)
+        # but it is decomposed on individual loss terms so that we can use it
+        # if needed for updating loss weights.
+        # Since the total loss is a weighted sum of individual loss terms, so
+        # are its total gradients.
+
+        # Compute individual losses and individual gradients
+        loss_terms, grad_terms = loss.evaluate_by_terms(optimization.params, batch)
+
+        if loss.update_weight_method is not None:
+            key, subkey = jax.random.split(key)  # type: ignore because key can
+            # still be None currently
+            # avoid computations of tree_at if no updates
+            loss = loss.update_weights(
+                i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
+            )
+
+        # total grad
+        grads = loss.ponderate_and_sum_gradient(grad_terms)
+
+        # total loss
+        train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
+        # ---------------------------------------------------------------------
+
+        # gradient step
         (
-            loss,
-            train_loss_value,
-            loss_terms,
             params,
             opt_state,
             last_non_nan_params,
         ) = _gradient_step(
-
+            grads,
             optimizer,
-            batch,
             optimization.params,
             optimization.opt_state,
             optimization.last_non_nan_params,
@@ -292,7 +384,7 @@ def solve(
         if verbose:
             _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
 
-        if validation is not None:
+        if validation is not None and validation_crit_values is not None:
             # there is a jax.lax.cond because we do not necesarily call the
             # validation step every iteration
             (
@@ -306,7 +398,7 @@ def solve(
                 lambda operands: (
                     operands[0],
                     False,
-                    validation_crit_values[i - 1],
+                    validation_crit_values[i - 1],  # type: ignore don't know why it can still be None
                     False,
                 ),
                 (
@@ -350,14 +442,14 @@ def solve(
         )
 
         # save loss value and selected parameters
-
+        stored_objects, loss_container = _store_loss_and_params(
             i,
             params,
             stored_objects.stored_params,
-            loss_container
-            loss_container.train_loss_values,
+            loss_container,
             train_loss_value,
             loss_terms,
+            loss.loss_weights,
             tracked_params,
         )
 
@@ -377,9 +469,10 @@ def solve(
             ),
             DataGeneratorContainer(data, param_data, obs_data),
             validation,
-
-
+            loss_container,
+            stored_objects,
             validation_crit_values,
+            key,
         )
 
     # Main optimization loop. We use the LAX while loop (fully jitted) version
@@ -419,6 +512,7 @@ def solve(
         loss_container,
         stored_objects,
         validation_crit_values,
+        key,
     ) = carry
 
     if verbose:
@@ -431,7 +525,7 @@ def solve(
     # get ready to return the parameters at last iteration...
     # (by default arbitrary choice, this could be None)
     validation_parameters = optimization.last_non_nan_params
-    if validation is not None:
+    if validation is not None and validation_crit_values is not None:
         jax.debug.print(
             "validation loss value = {validation_loss_val}",
             validation_loss_val=validation_crit_values[i - 1],
@@ -456,6 +550,7 @@ def solve(
         loss,  # return the Loss if needed (no-inplace modif)
         optimization.opt_state,
         stored_objects.stored_params,
+        loss_container.stored_weights_terms,
         validation_crit_values if validation is not None else None,
         validation_parameters,
     )
@@ -463,27 +558,26 @@ def solve(
 
 @partial(jit, static_argnames=["optimizer"])
 def _gradient_step(
-
+    grads: Params[Array],
     optimizer: optax.GradientTransformation,
-
-
-
-    last_non_nan_params: AnyParams,
+    params: Params[Array],
+    opt_state: optax.OptState,
+    last_non_nan_params: Params[Array],
 ) -> tuple[
-
-
-
-    AnyParams,
-    NamedTuple,
-    AnyParams,
+    Params[Array],
+    optax.OptState,
+    Params[Array],
 ]:
     """
     optimizer cannot be jit-ted.
     """
-
-
-
-
+
+    updates, opt_state = optimizer.update(
+        grads,  # type: ignore
+        opt_state,
+        params,  # type: ignore
+    )  # see optimizer.init for explaination for the ignore(s) here
+    params = optax.apply_updates(params, updates)  # type: ignore
 
     # check if any of the parameters is NaN
     last_non_nan_params = jax.lax.cond(
@@ -494,9 +588,6 @@ def _gradient_step(
     )
 
     return (
-        loss,
-        loss_val,
-        loss_terms,
         params,
         opt_state,
         last_non_nan_params,
@@ -504,7 +595,7 @@ def _gradient_step(
 
 
 @partial(jit, static_argnames=["prefix"])
-def _print_fn(i:
+def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
     # note that if the following is not jitted in the main lor loop, it is
     # super slow
     _ = jax.lax.cond(
@@ -521,17 +612,15 @@ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
 
 @jit
 def _store_loss_and_params(
-    i:
-    params:
-    stored_params:
-
-    train_loss_values: Float[Array, "n_iter"],
+    i: int,
+    params: Params[Array],
+    stored_params: Params[Array | None],
+    loss_container: LossContainer,
     train_loss_val: float,
-    loss_terms:
-
-
-
-]:
+    loss_terms: PyTree[Array],
+    weight_terms: PyTree[Array],
+    tracked_params: Params,
+) -> tuple[StoredObjectContainer, LossContainer]:
     stored_params = jax.tree_util.tree_map(
         lambda stored_value, param, tracked_param: (
             None
@@ -550,15 +639,41 @@ def _store_loss_and_params(
     )
     stored_loss_terms = jax.tree_util.tree_map(
         lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
-        stored_loss_terms,
+        loss_container.stored_loss_terms,
         loss_terms,
     )
 
-
-
+    if loss_container.stored_weights_terms is not None:
+        stored_weights_terms = jax.tree_util.tree_map(
+            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
+            jax.tree.leaves(
+                loss_container.stored_weights_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+            jax.tree.leaves(
+                weight_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+        )
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss_container.stored_weights_terms,
+            stored_weights_terms,
+        )
+    else:
+        stored_weights_terms = None
+
+    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
+    loss_container = LossContainer(
+        stored_loss_terms, stored_weights_terms, train_loss_values
+    )
+    stored_objects = StoredObjectContainer(stored_params)
+    return stored_objects, loss_container
 
 
-def _get_break_fun(n_iter:
+def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
     """
     Wrapper to get the break_fun with appropriate `n_iter`.
     The verbose argument is here to control printing (or not) when exiting
@@ -586,7 +701,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
         def continue_while_loop(_):
             return True
 
-        (i, _, optimization, optimization_extra, _, _, _, _, _) = carry
+        (i, _, optimization, optimization_extra, _, _, _, _, _, _) = carry
 
         # Condition 1
         bool_max_iter = jax.lax.cond(
@@ -599,7 +714,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
         bool_nan_in_params = jax.lax.cond(
            _check_nan_in_pytree(optimization.params),
            lambda _: stop_while_loop(
-                "NaN values in parameters
+                "NaN values in parameters (returning last non NaN values)"
            ),
            continue_while_loop,
            None,
@@ -622,18 +737,18 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
 
 
 def _get_get_batch(
-    obs_batch_sharding: jax.sharding.Sharding,
+    obs_batch_sharding: jax.sharding.Sharding | None,
 ) -> Callable[
     [
-
+        AbstractDataGenerator,
         DataGeneratorParameter | None,
-        DataGeneratorObservations |
+        DataGeneratorObservations | None,
     ],
     tuple[
         AnyBatch,
-
+        AbstractDataGenerator,
         DataGeneratorParameter | None,
-        DataGeneratorObservations |
+        DataGeneratorObservations | None,
    ],
 ]:
     """
```
jinns/solver/_utils.py
CHANGED
```diff
@@ -1,9 +1,7 @@
-from jinns.data.
-
-
-
-    DataGeneratorParameter,
-)
+from jinns.data._DataGeneratorODE import DataGeneratorODE
+from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
+from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
+from jinns.data._DataGeneratorParameter import DataGeneratorParameter
 
 
 def _check_batch_size(other_data, main_data, attr_name):
```
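This import change mirrors the removal of the monolithic `jinns/data/_DataGenerators.py` in favour of one private module per data generator (see the file list at the top). A minimal migration sketch for code that imported from the old private module is shown below; only the paths appearing in the diff are shown, and the public `jinns.data` namespace (whose `__init__.py` also changed) is presumably the more stable import point for downstream code.

```python
# jinns 1.3.0: the data generators lived in one private module
# from jinns.data._DataGenerators import (..., DataGeneratorParameter)

# jinns 1.5.0: one private module per class, as in _utils.py above
from jinns.data._DataGeneratorODE import DataGeneratorODE
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
```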
jinns/utils/__init__.py
CHANGED
jinns/utils/_containers.py
CHANGED
```diff
@@ -6,28 +6,31 @@ from __future__ import (
     annotations,
 ) # https://docs.python.org/3/library/typing.html#constant
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 from jaxtyping import PyTree, Array, Float, Bool
 from optax import OptState
 import equinox as eqx
 
+from jinns.parameters._params import Params
+
 if TYPE_CHECKING:
-    from jinns.
+    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+    from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+    from jinns.utils._types import AnyLoss
 
 
 class DataGeneratorContainer(eqx.Module):
-    data:
+    data: AbstractDataGenerator
     param_data: DataGeneratorParameter | None = None
-    obs_data: DataGeneratorObservations |
-        None
-    )
+    obs_data: DataGeneratorObservations | None = None
 
 
 class ValidationContainer(eqx.Module):
     loss: AnyLoss | None
     data: DataGeneratorContainer
     hyperparams: PyTree = None
-    loss_values: Float[Array, "n_iter"] | None = None
+    loss_values: Float[Array, " n_iter"] | None = None
 
 
 class OptimizationContainer(eqx.Module):
@@ -45,9 +48,12 @@ class OptimizationExtraContainer(eqx.Module):
 
 
 class LossContainer(eqx.Module):
-
-
+    # PyTree below refers to ODEComponents or PDEStatioComponents or
+    # PDENonStatioComponents
+    stored_loss_terms: PyTree[Float[Array, " n_iter"]]
+    stored_weights_terms: PyTree[Float[Array, " n_iter"]]
+    train_loss_values: Float[Array, " n_iter"]
 
 
 class StoredObjectContainer(eqx.Module):
-    stored_params:
+    stored_params: Params[Array | None]
```