jinns-1.6.1-py3-none-any.whl → jinns-1.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
```diff
@@ -8,17 +8,23 @@ from __future__ import (
 ) # https://docs.python.org/3/library/typing.html#constant
 
 import time
-from typing import TYPE_CHECKING, Any
-from functools import partial
+from typing import TYPE_CHECKING, Any
 import optax
 import jax
-from jax import jit
 import jax.numpy as jnp
-from jaxtyping import Float, Array,
-import equinox as eqx
+from jaxtyping import Float, Array, PRNGKeyArray
 from jinns.solver._rar import init_rar, trigger_rar
-from jinns.
-
+from jinns.solver._utils import (
+    _check_batch_size,
+    _init_stored_weights_terms,
+    _init_stored_params,
+    _get_break_fun,
+    _loss_evaluate_and_gradient_step,
+    _build_get_batch,
+    _store_loss_and_params,
+    _print_fn,
+)
+from jinns.parameters._params import Params
 from jinns.utils._containers import (
     DataGeneratorContainer,
     OptimizationContainer,
@@ -26,32 +32,18 @@ from jinns.utils._containers import (
     LossContainer,
     StoredObjectContainer,
 )
-from jinns.data._utils import append_param_batch, append_obs_batch
 
 if TYPE_CHECKING:
-    from jinns.
-    from jinns.utils._types import AnyBatch
+    from jinns.utils._types import AnyLossComponents, SolveCarry
     from jinns.loss._abstract_loss import AbstractLoss
     from jinns.validation._validation import AbstractValidationModule
     from jinns.data._DataGeneratorParameter import DataGeneratorParameter
     from jinns.data._DataGeneratorObservations import DataGeneratorObservations
     from jinns.data._AbstractDataGenerator import AbstractDataGenerator
 
-main_carry: TypeAlias = tuple[
-    int,
-    AbstractLoss,
-    OptimizationContainer,
-    OptimizationExtraContainer,
-    DataGeneratorContainer,
-    AbstractValidationModule | None,
-    LossContainer,
-    StoredObjectContainer,
-    Float[Array, " n_iter"] | None,
-    PRNGKeyArray | None,
-]
-
 
 def solve(
+    *,
     n_iter: int,
     init_params: Params[Array],
     data: AbstractDataGenerator,
@@ -64,24 +56,27 @@ def solve(
     obs_data: DataGeneratorObservations | None = None,
     validation: AbstractValidationModule | None = None,
     obs_batch_sharding: jax.sharding.Sharding | None = None,
+    opt_state_field_for_acceleration: str | None = None,
     verbose: bool = True,
     ahead_of_time: bool = True,
     key: PRNGKeyArray | None = None,
 ) -> tuple[
     Params[Array],
     Float[Array, " n_iter"],
-
+    AnyLossComponents[Float[Array, " n_iter"]],
     AbstractDataGenerator,
     AbstractLoss,
     optax.OptState,
     Params[Array | None],
-
+    AnyLossComponents[Float[Array, " n_iter"]],
+    DataGeneratorObservations | None,
+    DataGeneratorParameter | None,
     Float[Array, " n_iter"] | None,
-    Params[Array],
+    Params[Array] | None,
 ]:
     """
     Performs the optimization process via stochastic gradient descent
-    algorithm. We minimize the function defined `loss.evaluate()` with
+    algorithm. We minimize the function defined in `loss.evaluate()` with
     respect to the learnable parameters of the problem whose initial values
     are given in `init_params`.
 
@@ -91,9 +86,9 @@ def solve(
     n_iter
         The maximum number of iterations in the optimization.
     init_params
-        The initial jinns.parameters.Params object.
+        The initial `jinns.parameters.Params` object.
     data
-        A
+        A `jinns.data.AbstractDataGenerator` object to retrieve batches of collocation points.
     loss
         The loss function to minimize.
     optimizer
@@ -102,22 +97,21 @@ def solve(
         Default 1000. The rate at which we print the loss value in the
         gradient step loop.
     opt_state
-
+        Default `None`. Provides an optional initial state to the optimizer.
     tracked_params
-        Default None
+        Default `None`. A `jinns.parameters.Params` object with non-`None` values for
         parameters that needs to be tracked along the iterations.
-
-
-        nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        The user can provide something like `tracked_params = jinns.parameters.Params(
+        nn_params=None, eq_params={"nu": True})` while `init_params.nn_params`
         being a complex data structure.
     param_data
-        Default None
+        Default `None`. A `jinns.data.DataGeneratorParameter` object which can be used to
        sample equation parameters.
     obs_data
-        Default None
+        Default `None`. A `jinns.data.DataGeneratorObservations`
         object which can be used to sample minibatches of observations.
     validation
-        Default None
+        Default `None`. Otherwise, a callable `eqx.Module` which implements a
         validation strategy. See documentation of `jinns.validation.
         _validation.AbstractValidationModule` for the general interface, and
         `jinns.validation._validation.ValidationLoss` for a practical
@@ -131,53 +125,70 @@ def solve(
         validation strategy of their choice, and to decide on the early
         stopping criterion.
     obs_batch_sharding
-        Default None
-
-
+        Default `None`. An optional sharding object to constraint the
+        `obs_batch`.
+        Typically, a `SingleDeviceSharding(gpu_device)` when `obs_data` has been
+        created with `sharding_device=SingleDeviceSharding(cpu_device)` to avoid
         loading on GPU huge datasets of observations.
+    opt_state_field_for_acceleration
+        A string. Default `None`, i.e. the optimizer without acceleration.
+        Because in some optimization scheme one can have what is called
+        acceleration where the loss is computed at some accelerated parameter
+        values, different from the actual parameter values. These accelerated
+        parameter can be stored in the optimizer state as a field. If this
+        field name is passed to `opt_state_field_for_acceleration` then the
+        gradient step will be done by evaluate gradients at parameter value
+        `opt_state.opt_state_field_for_acceleration`.
     verbose
-        Default True
+        Default `True`. If `False`, no output (loss or cause of
         exiting the optimization loop) will be produced.
     ahead_of_time
-        Default True
+        Default `True`. Separate the compilation of the main training loop from
         the execution to get both timings. You might need to avoid this
         behaviour if you need to perform JAX transforms over chunks of code
         containing `jinns.solve()` since AOT-compiled functions cannot be JAX
         transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
-        When False
+        When `False`, jinns does not provide any timing information (which would
         be nonsense in a JIT transformed `solve()` function).
     key
-        Default None
+        Default `None`. A JAX random key that can be used for diverse purpose in
         the main iteration loop.
 
     Returns
     -------
     params
-        The last non
-        optimization process
+        The last non-NaN value of the params at then end of the
+        optimization process.
     total_loss_values
-        An array of the total loss term along the gradient steps
+        An array of the total loss term along the gradient steps.
     stored_loss_terms
         A PyTree with attributes being arrays of all the values for each loss
-        term
+        term.
     data
-        The input
+        The data generator object passed as input, possibly modified.
     loss
-        The input
+        The loss object passed as input, possibly modified.
     opt_state
-        The final optimized state
+        The final optimized state.
     stored_params
-        A
-        signified in tracked_params argument)
+        A object with the stored values of the desired parameters (as
+        signified in `tracked_params` argument).
     stored_weights_terms
-        A PyTree with
-        weight. Note that if Loss.update_weight_method is None
+        A PyTree with leaves being arrays of all the values for each loss
+        weight. Note that if `Loss.update_weight_method is None`, we return
+        `None`,
         because loss weights are never updated and we can then save some
-        computations
+        computations.
+    param_data
+        The `jinns.data.DataGeneratorParameter` object passed as input or
+        `None`.
+    obs_data
+        The `jinns.data.DataGeneratorObservations` object passed as input or
+        `None`.
     validation_crit_values
-        An array containing the validation criterion values of the training
+        An array containing the validation criterion values of the training.
     best_val_params
-        The best parameters according to the validation criterion
+        The best parameters according to the validation criterion.
     """
     initialization_time = time.time()
     if n_iter < 1:
@@ -224,24 +235,12 @@ def solve(
     train_loss_values = jnp.zeros((n_iter))
     # depending on obs_batch_sharding we will get the simple get_batch or the
     # get_batch with device_put, the latter is not jittable
-    get_batch =
+    get_batch = _build_get_batch(obs_batch_sharding)
 
     # initialize parameter tracking
     if tracked_params is None:
         tracked_params = jax.tree.map(lambda p: None, init_params)
-    stored_params =
-        lambda tracked_param, param: (
-            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
-            if tracked_param is not None
-            else None
-        ),
-        tracked_params,
-        init_params,
-        is_leaf=lambda x: x is None, # None values in tracked_params will not
-        # be traversed. Thus the user can provide something like `tracked_params = jinns.parameters.Params(
-        # nn_params=None, eq_params={"nu": True})` while init_params.nn_params
-        # being a complex data structure
-    )
+    stored_params = _init_stored_params(tracked_params, init_params, n_iter)
 
     # initialize the dict for stored parameter values
     # we need to get a loss_term to init stuff
@@ -257,23 +256,7 @@ def solve(
 
     # initialize the PyTree for stored loss weights values
     if loss.update_weight_method is not None:
-        stored_weights_terms =
-            lambda pt: jax.tree.leaves(
-                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
-            ),
-            loss.loss_weights,
-            tuple(
-                jnp.zeros((n_iter))
-                for n in range(
-                    len(
-                        jax.tree.leaves(
-                            loss.loss_weights,
-                            is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
-                        )
-                    )
-                )
-            ),
-        )
+        stored_weights_terms = _init_stored_weights_terms(loss, n_iter)
     else:
         stored_weights_terms = None
     if loss.update_weight_method is not None and key is None:
@@ -289,6 +272,7 @@ def solve(
         params=init_params,
         last_non_nan_params=init_params,
         opt_state=opt_state,
+        # params_mask=params_mask,
     )
     optimization_extra = OptimizationExtraContainer(
         curr_seq=curr_seq,
@@ -326,7 +310,11 @@ def solve(
         key,
     )
 
-    def _one_iteration(carry:
+    def _one_iteration(carry: SolveCarry) -> SolveCarry:
+        # Note that optimizer are not part of the carry since
+        # the former is not tractable and the latter (while it could be
+        # hashable) must be static because of the equinox `filter_spec` (https://github.com/patrick-kidger/equinox/issues/1036)
+
         (
             i,
             loss,
@@ -344,43 +332,24 @@ def solve(
             train_data.data, train_data.param_data, train_data.obs_data
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if key is not None:
+            key, subkey = jax.random.split(key)
+        else:
+            subkey = None
+        (train_loss_value, params, last_non_nan_params, opt_state, loss, loss_terms) = (
+            _loss_evaluate_and_gradient_step(
+                i,
+                batch,
+                loss,
+                optimization.params,
+                optimization.last_non_nan_params,
+                optimization.opt_state,
+                optimizer,
+                loss_container,
+                subkey,
+                None,
+                opt_state_field_for_acceleration,
             )
-
-        # total grad
-        grads = loss.ponderate_and_sum_gradient(grad_terms)
-
-        # total loss
-        train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
-        # ---------------------------------------------------------------------
-
-        # gradient step
-        (
-            params,
-            opt_state,
-            last_non_nan_params,
-        ) = _gradient_step(
-            grads,
-            optimizer,
-            optimization.params,
-            optimization.opt_state,
-            optimization.last_non_nan_params,
         )
 
         # Print train loss value during optimization
@@ -462,7 +431,9 @@ def solve(
         return (
             i,
             loss,
-            OptimizationContainer(
+            OptimizationContainer(
+                params, last_non_nan_params, opt_state
+            ), # , params_mask),
             OptimizationExtraContainer(
                 curr_seq,
                 best_iter_id,
@@ -552,249 +523,13 @@ def solve(
         optimization.last_non_nan_params,
         loss_container.train_loss_values,
         loss_container.stored_loss_terms,
-        train_data.data,
-        loss,
+        train_data.data,
+        loss,
         optimization.opt_state,
         stored_objects.stored_params,
         loss_container.stored_weights_terms,
+        train_data.obs_data,
+        train_data.param_data,
         validation_crit_values if validation is not None else None,
         validation_parameters,
     )
-
-
-@partial(jit, static_argnames=["optimizer"])
-def _gradient_step(
-    grads: Params[Array],
-    optimizer: optax.GradientTransformation,
-    params: Params[Array],
-    opt_state: optax.OptState,
-    last_non_nan_params: Params[Array],
-) -> tuple[
-    Params[Array],
-    optax.OptState,
-    Params[Array],
-]:
-    """
-    optimizer cannot be jit-ted.
-    """
-
-    updates, opt_state = optimizer.update(
-        grads, # type: ignore
-        opt_state,
-        params, # type: ignore
-    ) # see optimizer.init for explaination for the ignore(s) here
-    params = optax.apply_updates(params, updates) # type: ignore
-
-    # check if any of the parameters is NaN
-    last_non_nan_params = jax.lax.cond(
-        _check_nan_in_pytree(params),
-        lambda _: last_non_nan_params,
-        lambda _: params,
-        None,
-    )
-
-    return (
-        params,
-        opt_state,
-        last_non_nan_params,
-    )
-
-
-@partial(jit, static_argnames=["prefix"])
-def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
-    # note that if the following is not jitted in the main lor loop, it is
-    # super slow
-    _ = jax.lax.cond(
-        i % print_loss_every == 0,
-        lambda _: jax.debug.print(
-            prefix + "Iteration {i}: loss value = {loss_val}",
-            i=i,
-            loss_val=loss_val,
-        ),
-        lambda _: None,
-        (None,),
-    )
-
-
-@jit
-def _store_loss_and_params(
-    i: int,
-    params: Params[Array],
-    stored_params: Params[Array | None],
-    loss_container: LossContainer,
-    train_loss_val: float,
-    loss_terms: PyTree[Array],
-    weight_terms: PyTree[Array],
-    tracked_params: Params,
-) -> tuple[StoredObjectContainer, LossContainer]:
-    stored_params = jax.tree_util.tree_map(
-        lambda stored_value, param, tracked_param: (
-            None
-            if stored_value is None
-            else jax.lax.cond(
-                tracked_param,
-                lambda ope: ope[0].at[i].set(ope[1]),
-                lambda ope: ope[0],
-                (stored_value, param),
-            )
-        ),
-        stored_params,
-        params,
-        tracked_params,
-        is_leaf=lambda x: x is None,
-    )
-    stored_loss_terms = jax.tree_util.tree_map(
-        lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
-        loss_container.stored_loss_terms,
-        loss_terms,
-    )
-
-    if loss_container.stored_weights_terms is not None:
-        stored_weights_terms = jax.tree_util.tree_map(
-            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
-            jax.tree.leaves(
-                loss_container.stored_weights_terms,
-                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
-            ),
-            jax.tree.leaves(
-                weight_terms,
-                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
-            ),
-        )
-        stored_weights_terms = eqx.tree_at(
-            lambda pt: jax.tree.leaves(
-                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
-            ),
-            loss_container.stored_weights_terms,
-            stored_weights_terms,
-        )
-    else:
-        stored_weights_terms = None
-
-    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
-    loss_container = LossContainer(
-        stored_loss_terms, stored_weights_terms, train_loss_values
-    )
-    stored_objects = StoredObjectContainer(stored_params)
-    return stored_objects, loss_container
-
-
-def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
-    """
-    Wrapper to get the break_fun with appropriate `n_iter`.
-    The verbose argument is here to control printing (or not) when exiting
-    the optimisation loop. It can be convenient is jinns.solve is itself
-    called in a loop and user want to avoid std output.
-    """
-
-    @jit
-    def break_fun(carry: tuple):
-        """
-        Function to break from the main optimization loop whe the following
-        conditions are met : maximum number of iterations, NaN
-        appearing in the parameters, and early stopping criterion.
-        """
-
-        def stop_while_loop(msg):
-            """
-            Note that the message is wrapped in the jax.lax.cond because a
-            string is not a valid JAX type that can be fed into the operands
-            """
-            if verbose:
-                jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
-            return False
-
-        def continue_while_loop(_):
-            return True
-
-        (i, _, optimization, optimization_extra, _, _, _, _, _, _) = carry
-
-        # Condition 1
-        bool_max_iter = jax.lax.cond(
-            i >= n_iter,
-            lambda _: stop_while_loop("max iteration is reached"),
-            continue_while_loop,
-            None,
-        )
-        # Condition 2
-        bool_nan_in_params = jax.lax.cond(
-            _check_nan_in_pytree(optimization.params),
-            lambda _: stop_while_loop(
-                "NaN values in parameters (returning last non NaN values)"
-            ),
-            continue_while_loop,
-            None,
-        )
-        # Condition 3
-        bool_early_stopping = jax.lax.cond(
-            optimization_extra.early_stopping,
-            lambda _: stop_while_loop("early stopping"),
-            continue_while_loop,
-            _,
-        )
-
-        # stop when one of the cond to continue is False
-        return jax.tree_util.tree_reduce(
-            lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
-            (bool_max_iter, bool_nan_in_params, bool_early_stopping),
-        )
-
-    return break_fun
-
-
-def _get_get_batch(
-    obs_batch_sharding: jax.sharding.Sharding | None,
-) -> Callable[
-    [
-        AbstractDataGenerator,
-        DataGeneratorParameter | None,
-        DataGeneratorObservations | None,
-    ],
-    tuple[
-        AnyBatch,
-        AbstractDataGenerator,
-        DataGeneratorParameter | None,
-        DataGeneratorObservations | None,
-    ],
-]:
-    """
-    Return the get_batch function that will be used either the jittable one or
-    the non-jittable one with sharding using jax.device.put()
-    """
-
-    def get_batch_sharding(data, param_data, obs_data):
-        """
-        This function is used at each loop but it cannot be jitted because of
-        device_put
-        """
-        data, batch = data.get_batch()
-        if param_data is not None:
-            param_data, param_batch = param_data.get_batch()
-            batch = append_param_batch(batch, param_batch)
-        if obs_data is not None:
-            # This is the part that motivated the transition from scan to for loop
-            # Indeed we need to be transit obs_batch from CPU to GPU when we have
-            # huge observations that cannot fit on GPU. Such transfer wasn't meant
-            # to be jitted, i.e. in a scan loop
-            obs_data, obs_batch = obs_data.get_batch()
-            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
-            batch = append_obs_batch(batch, obs_batch)
-        return batch, data, param_data, obs_data
-
-    @jit
-    def get_batch(data, param_data, obs_data):
-        """
-        Original get_batch with no sharding
-        """
-        data, batch = data.get_batch()
-        if param_data is not None:
-            param_data, param_batch = param_data.get_batch()
-            batch = append_param_batch(batch, param_batch)
-        if obs_data is not None:
-            obs_data, obs_batch = obs_data.get_batch()
-            batch = append_obs_batch(batch, obs_batch)
-        return batch, data, param_data, obs_data
-
-    if obs_batch_sharding is not None:
-        return get_batch_sharding
-    return get_batch
```