jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +953 -1182
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +321 -168
- jinns/loss/_LossODE.py +290 -307
- jinns/loss/_LossPDE.py +628 -1040
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +95 -96
- jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +94 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +193 -45
- jinns/solver/_solve.py +199 -144
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -43
- jinns/utils/_hyperpinn.py +226 -127
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +117 -84
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +52 -144
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
- jinns-1.0.0.dist-info/RECORD +38 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.8.10.dist-info/RECORD +0 -36
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
- {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
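Beyond the per-file diffs, the listing shows the 1.0.0 reorganisation: parameters move to a new `jinns.parameters` subpackage, plotting moves from `jinns/data/_display.py` to `jinns.plot`, `loss/_Losses.py` becomes `loss/_loss_utils.py`, and the experimental sinus/spectral PINNs, the seq2seq solver and the `_optim` helpers are removed. A minimal import sketch of the new layout (`Params` is the only symbol confirmed by the docstrings quoted below; the bare subpackage imports only assume the listed `__init__.py` files exist):

    import jinns.parameters  # new: jinns/parameters/_params.py, _derivative_keys.py
    import jinns.plot        # new home of the plotting helpers (was jinns/data/_display.py)
    from jinns.parameters import Params  # used as jinns.parameters.Params in the docstrings below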
jinns/solver/_solve.py
CHANGED
@@ -3,26 +3,33 @@ This modules implements the main `solve()` function of jinns which
 handles the optimization process
 """
 
-import
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, NamedTuple, Dict, Union
 from functools import partial
 import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from
+from jaxtyping import Int, Bool, Float, Array
 from jinns.solver._rar import init_rar, trigger_rar
-from jinns.utils._utils import _check_nan_in_pytree
+from jinns.utils._utils import _check_nan_in_pytree
+from jinns.utils._containers import *
 from jinns.data._DataGenerators import (
     DataGeneratorODE,
     CubicMeshPDEStatio,
     CubicMeshPDENonStatio,
-    append_param_batch,
     append_obs_batch,
+    append_param_batch,
 )
-from jinns.utils._containers import *
 
+if TYPE_CHECKING:
+    from jinns.utils._types import *
 
-
+
+def _check_batch_size(other_data, main_data, attr_name):
     if (
         (
             isinstance(main_data, DataGeneratorODE)
@@ -48,20 +55,32 @@ def check_batch_size(other_data, main_data, attr_name):
 
 
 def solve(
-    n_iter,
-    init_params,
-    data,
-    loss,
-    optimizer,
-    print_loss_every=1000,
-    opt_state=None,
-
-
-
-
-
-
-
+    n_iter: Int,
+    init_params: AnyParams,
+    data: AnyDataGenerator,
+    loss: AnyLoss,
+    optimizer: optax.GradientTransformation,
+    print_loss_every: Int = 1000,
+    opt_state: Union[NamedTuple, None] = None,
+    tracked_params: Params | ParamsDict | None = None,
+    param_data: DataGeneratorParameter | None = None,
+    obs_data: (
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
+    ) = None,
+    validation: AbstractValidationModule | None = None,
+    obs_batch_sharding: jax.sharding.Sharding | None = None,
+    verbose: Bool = True,
+) -> tuple[
+    Params | ParamsDict,
+    Float[Array, "n_iter"],
+    Dict[str, Float[Array, "n_iter"]],
+    AnyDataGenerator,
+    AnyLoss,
+    NamedTuple,
+    AnyParams,
+    Float[Array, "n_iter"],
+    AnyParams,
+]:
     """
     Performs the optimization process via stochastic gradient descent
     algorithm. We minimize the function defined `loss.evaluate()` with
@@ -72,52 +91,39 @@ def solve(
     Parameters
     ----------
     n_iter
-        The number of iterations in the optimization
+        The maximum number of iterations in the optimization.
     init_params
-        The initial
-        dictionaries: `eq_params` and `nn_params``, respectively the
-        differential equation parameters and the neural network parameter
+        The initial jinns.parameters.Params object.
     data
-        A DataGenerator object
-        method which returns a 3-tuple with (omega_grid, omega_border, time grid).
-        It must be jittable (e.g. implements via a pytree
-        registration)
+        A DataGenerator object to retrieve batches of collocation points.
     loss
-
-        object). It must be jittable (e.g. implements via a pytree
-        registration)
+        The loss function to minimize.
     optimizer
-        An
+        An optax optimizer.
     print_loss_every
-
+        Default 1000. The rate at which we print the loss value in the
         gradient step loop.
     opt_state
-
-
-
-
-
-        the
-
-
-        The seq2seq approach we reimplements is defined in
-        "Characterizing possible failure modes in physics-informed neural
-        networks", A. S. Krishnapriyan, NeurIPS 2021
-    tracked_params_key_list
-        Default None. Otherwise it is a list of list of strings
-        to access a leaf in params. Each selected leaf will be tracked
-        and stored at each iteration and returned by the solve function
+        Provide an optional initial state to the optimizer.
+    tracked_params
+        Default None. An eqx.Module of type Params with non-None values for
+        parameters that needs to be tracked along the iterations.
+        None values in tracked_params will not be traversed. Thus
+        the user can provide something like `tracked_params = jinns.parameters.Params(
+        nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        being a complex data structure.
     param_data
         Default None. A DataGeneratorParameter object which can be used to
         sample equation parameters.
     obs_data
-        Default None. A DataGeneratorObservations
-
+        Default None. A DataGeneratorObservations or
+        DataGeneratorObservationsMultiPINNs
+        object which can be used to sample minibatches of observations.
     validation
         Default None. Otherwise, a callable ``eqx.Module`` which implements a
-        validation strategy. See documentation of
+        validation strategy. See documentation of `jinns.validation.
         _validation.AbstractValidationModule` for the general interface, and
-
+        `jinns.validation._validation.ValidationLoss` for a practical
         implementation of a vanilla validation stategy on a validation set of
         collocation points.
 
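Putting the new signature and the parameter descriptions together, a 1.0.0 call looks roughly like the sketch below. The `my_*` objects are placeholders for a `jinns.parameters.Params`, a DataGenerator and a Loss built elsewhere; the unpacked names follow the return annotation and the Returns section of this diff:

    import optax
    import jinns

    # my_init_params, my_data, my_loss are assumed to have been built beforehand
    # (a jinns.parameters.Params object, a DataGenerator and a Loss respectively).
    (
        params,                  # last non-NaN parameters
        train_loss_values,       # total loss per iteration
        stored_loss_terms,       # dict of individual loss terms per iteration
        data,                    # the DataGenerator, returned since there is no in-place modification
        loss,                    # the Loss, returned for the same reason
        opt_state,               # final optimizer state
        stored_params,           # stored values of the tracked parameters
        validation_crit_values,  # None when no validation module is given
        best_val_params,         # None when no validation module is given
    ) = jinns.solve(
        n_iter=10_000,
        init_params=my_init_params,
        data=my_data,
        loss=my_loss,
        optimizer=optax.adam(1e-3),
        print_loss_every=1_000,
        tracked_params=jinns.parameters.Params(nn_params=None, eq_params={"nu": True}),
        verbose=True,
    )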
@@ -131,12 +137,15 @@ def solve(
         Default None. An optional sharding object to constraint the obs_batch.
         Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
         created with sharding_device=SingleDeviceSharding(cpu_device) to avoid
-        loading on GPU huge datasets of observations
+        loading on GPU huge datasets of observations.
+    verbose
+        Default True. If False, no std output (loss or cause of
+        exiting the optimization loop) will be produced.
 
     Returns
     -------
     params
-        The last non NaN value of the
+        The last non NaN value of the params at then end of the
         optimization process
     total_loss_values
         An array of the total loss term along the gradient steps
@@ -150,14 +159,18 @@ def solve(
     opt_state
         The final optimized state
     stored_params
-        A
-
+        A Params objects with the stored values of the desired parameters (as
+        signified in tracked_params argument)
+    validation_crit_values
+        An array containing the validation criterion values of the training
+    best_val_params
+        The best parameters according to the validation criterion
     """
     if param_data is not None:
-
+        _check_batch_size(param_data, data, "param_batch_size")
 
     if obs_data is not None:
-
+        _check_batch_size(obs_data, data, "obs_batch_size")
 
     if opt_state is None:
         opt_state = optimizer.init(init_params)
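The `obs_batch_sharding` / `sharding_device` pair documented above maps onto plain `jax.sharding` objects. A sketch of the intended setup (the `DataGeneratorObservations` construction is elided because only its `sharding_device` keyword is documented here; the GPU line assumes a GPU backend is present):

    import jax
    from jax.sharding import SingleDeviceSharding

    cpu_sharding = SingleDeviceSharding(jax.devices("cpu")[0])
    gpu_sharding = SingleDeviceSharding(jax.devices("gpu")[0])  # requires a GPU backend

    # obs_data = DataGeneratorObservations(..., sharding_device=cpu_sharding)  # large data stays on host
    # jinns.solve(..., obs_data=obs_data, obs_batch_sharding=gpu_sharding)     # minibatches moved to GPU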
@@ -168,30 +181,32 @@ def solve(
 
     # Seq2seq
     curr_seq = 0
-    if seq2seq is not None:
-        assert (
-            data.method == "uniform"
-        ), "data.method must be uniform if using seq2seq learning !"
-        data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)
 
     train_loss_values = jnp.zeros((n_iter))
     # depending on obs_batch_sharding we will get the simple get_batch or the
     # get_batch with device_put, the latter is not jittable
-    get_batch =
+    get_batch = _get_get_batch(obs_batch_sharding)
 
     # initialize the dict for stored parameter values
     # we need to get a loss_term to init stuff
     batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
     _, loss_terms = loss(init_params, batch_ini)
-
-
-    tracked_params
+
+    # initialize parameter tracking
+    if tracked_params is None:
+        tracked_params = jax.tree.map(lambda p: None, init_params)
     stored_params = jax.tree_util.tree_map(
         lambda tracked_param, param: (
-            jnp.zeros((n_iter,) + param.shape)
+            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
+            if tracked_param is not None
+            else None
         ),
         tracked_params,
         init_params,
+        is_leaf=lambda x: x is None,  # None values in tracked_params will not
+        # be traversed. Thus the user can provide something like `tracked_params = jinns.parameters.Params(
+        # nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        # being a complex data structure
     )
 
     # initialize the dict for stored loss values
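The `is_leaf=lambda x: x is None` trick above is what lets a partial `Params(nn_params=None, eq_params={"nu": True})` be mapped against a full parameter pytree: the `None` is treated as a leaf and matched against the whole corresponding subtree instead of being descended into. A self-contained sketch with plain dicts standing in for the jinns containers:

    import jax
    import jax.numpy as jnp

    n_iter = 5
    init_params = {"nn_params": {"w": jnp.ones((3,)), "b": jnp.zeros(())},
                   "eq_params": {"nu": jnp.array(0.3)}}
    tracked = {"nn_params": None, "eq_params": {"nu": True}}

    stored = jax.tree_util.tree_map(
        lambda t, p: None if t is None else jnp.zeros((n_iter,) + jnp.asarray(p).shape),
        tracked,
        init_params,
        is_leaf=lambda x: x is None,  # stop at None instead of recursing into it
    )
    # stored == {"nn_params": None, "eq_params": {"nu": zeros((5,))}}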
@@ -203,11 +218,13 @@ def solve(
         data=data, param_data=param_data, obs_data=obs_data
     )
     optimization = OptimizationContainer(
-        params=init_params,
+        params=init_params,
+        last_non_nan_params=init_params,
+        opt_state=opt_state,
     )
     optimization_extra = OptimizationExtraContainer(
         curr_seq=curr_seq,
-
+        best_val_params=init_params,
     )
     loss_container = LossContainer(
         stored_loss_terms=stored_loss_terms,
@@ -222,7 +239,7 @@ def solve(
     else:
         validation_crit_values = None
 
-    break_fun =
+    break_fun = _get_break_fun(n_iter, verbose)
 
     iteration = 0
     carry = (
@@ -237,7 +254,7 @@ def solve(
         validation_crit_values,
     )
 
-    def
+    def _one_iteration(carry: main_carry) -> main_carry:
         (
             i,
             loss,
@@ -262,7 +279,7 @@ def solve(
             params,
             opt_state,
             last_non_nan_params,
-        ) =
+        ) = _gradient_step(
             loss,
             optimizer,
             batch,
@@ -272,7 +289,8 @@ def solve(
         )
 
         # Print train loss value during optimization
-
+        if verbose:
+            _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
 
         if validation is not None:
             # there is a jax.lax.cond because we do not necesarily call the
@@ -281,6 +299,7 @@ def solve(
                 validation,  # always return `validation` for in-place mutation
                 early_stopping,
                 validation_criterion,
+                update_best_params,
             ) = jax.lax.cond(
                 i % validation.call_every == 0,
                 lambda operands: operands[0](*operands[1:]),  # validation.__call__()
@@ -288,6 +307,7 @@ def solve(
                     operands[0],
                     False,
                     validation_crit_values[i - 1],
+                    False,
                 ),
                 (
                     validation,  # validation must be in operands
@@ -295,31 +315,32 @@ def solve(
                 ),
             )
             # Print validation loss value during optimization
-
+            if verbose:
+                _print_fn(
+                    i, validation_criterion, print_loss_every, prefix="[validation] "
+                )
             validation_crit_values = validation_crit_values.at[i].set(
                 validation_criterion
             )
+
+            # update best_val_params w.r.t val_loss if needed
+            best_val_params = jax.lax.cond(
+                update_best_params,
+                lambda _: params,  # update with current value
+                lambda operands: operands[0].best_val_params,  # unchanged
+                (optimization_extra,),
+            )
         else:
             early_stopping = False
+            best_val_params = params
 
         # Trigger RAR
         loss, params, data = trigger_rar(
             i, loss, params, data, _rar_step_true, _rar_step_false
         )
 
-        # Trigger seq2seq
-        loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
-            i,
-            loss,
-            params,
-            data,
-            opt_state,
-            optimization_extra.curr_seq,
-            optimization_extra.seq2seq,
-        )
-
         # save loss value and selected parameters
-        stored_params, stored_loss_terms, train_loss_values =
+        stored_params, stored_loss_terms, train_loss_values = _store_loss_and_params(
             i,
             params,
             stored_objects.stored_params,
@@ -329,13 +350,15 @@ def solve(
             loss_terms,
             tracked_params,
         )
+
+        # increment iteration number
         i += 1
 
         return (
             i,
             loss,
             OptimizationContainer(params, last_non_nan_params, opt_state),
-            OptimizationExtraContainer(curr_seq,
+            OptimizationExtraContainer(curr_seq, best_val_params, early_stopping),
             DataGeneratorContainer(data, param_data, obs_data),
             validation,
             LossContainer(stored_loss_terms, train_loss_values),
@@ -348,9 +371,9 @@ def solve(
     # concern obs_batch, but it could lead to more complex scheme in the future
     if obs_batch_sharding is not None:
         while break_fun(carry):
-            carry =
+            carry = _one_iteration(carry)
     else:
-        carry = jax.lax.while_loop(break_fun,
+        carry = jax.lax.while_loop(break_fun, _one_iteration, carry)
 
     (
         i,
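The dispatch in this hunk keeps the fully jitted `jax.lax.while_loop` path whenever no observation sharding is requested; the `jax.device_put` transfer inside `get_batch_sharding` (shown further down) is what forces the plain Python `while`. For reference, `jax.lax.while_loop(cond_fun, body_fun, init_val)` repeatedly applies `body_fun` to the carry while `cond_fun(carry)` is true; a minimal standalone example:

    import jax

    # carry = (i, total): accumulate 0 + 1 + 2 + 3 + 4 while i < 5
    final_carry = jax.lax.while_loop(
        lambda carry: carry[0] < 5,                         # plays the role of break_fun
        lambda carry: (carry[0] + 1, carry[1] + carry[0]),  # plays the role of _one_iteration
        (0, 0),
    )
    # final_carry == (5, 10)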
@@ -364,41 +387,47 @@ def solve(
         validation_crit_values,
     ) = carry
 
-
-
-
-
-
+    if verbose:
+        jax.debug.print(
+            "Final iteration {i}: train loss value = {train_loss_val}",
+            i=i,
+            train_loss_val=loss_container.train_loss_values[i - 1],
+        )
     if validation is not None:
         jax.debug.print(
             "validation loss value = {validation_loss_val}",
             validation_loss_val=validation_crit_values[i - 1],
         )
 
-    if validation is None:
-        return (
-            optimization.last_non_nan_params,
-            loss_container.train_loss_values,
-            loss_container.stored_loss_terms,
-            train_data.data,
-            loss,
-            optimization.opt_state,
-            stored_objects.stored_params,
-        )
     return (
         optimization.last_non_nan_params,
         loss_container.train_loss_values,
         loss_container.stored_loss_terms,
-        train_data.data,
-        loss,
+        train_data.data,  # return the DataGenerator if needed (no in-place modif)
+        loss,  # return the Loss if needed (no-inplace modif)
         optimization.opt_state,
         stored_objects.stored_params,
-        validation_crit_values,
+        validation_crit_values if validation is not None else None,
+        optimization_extra.best_val_params if validation is not None else None,
     )
 
 
 @partial(jit, static_argnames=["optimizer"])
-def
+def _gradient_step(
+    loss: AnyLoss,
+    optimizer: optax.GradientTransformation,
+    batch: AnyBatch,
+    params: AnyParams,
+    opt_state: NamedTuple,
+    last_non_nan_params: AnyParams,
+) -> tuple[
+    AnyLoss,
+    float,
+    Dict[str, float],
+    AnyParams,
+    NamedTuple,
+    AnyParams,
+]:
     """
     optimizer cannot be jit-ted.
     """
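The next hunk only renames `gradient_step` to `_print_fn`'s neighbour `_gradient_step` counterparts in the module; the body of `_gradient_step` itself is unchanged and therefore not shown in this diff. For orientation, a generic optax step of the kind such a function wraps (a sketch, not the jinns implementation; `loss_fn` is assumed to return `(value, aux)` the way `loss(params, batch)` returns the loss and its terms above):

    import jax
    import optax

    def generic_gradient_step(loss_fn, optimizer, batch, params, opt_state):
        # value_and_grad over the parameters; aux carries the individual loss terms
        (loss_val, loss_terms), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return loss_val, loss_terms, params, opt_state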
@@ -426,7 +455,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
 
 
 @partial(jit, static_argnames=["prefix"])
-def
+def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
     # note that if the following is not jitted in the main lor loop, it is
     # super slow
     _ = jax.lax.cond(
@@ -442,26 +471,33 @@ def print_fn(i, loss_val, print_loss_every, prefix=""):
 
 
 @jit
-def
-    i,
-    params,
-    stored_params,
-    stored_loss_terms,
-    train_loss_values,
-    train_loss_val,
-    loss_terms,
-    tracked_params,
-)
+def _store_loss_and_params(
+    i: Int,
+    params: AnyParams,
+    stored_params: AnyParams,
+    stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
+    train_loss_values: Float[Array, "n_iter"],
+    train_loss_val: float,
+    loss_terms: Dict[str, float],
+    tracked_params: AnyParams,
+) -> tuple[
+    Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
+]:
     stored_params = jax.tree_util.tree_map(
-        lambda stored_value, param, tracked_param:
-
-
-
-
+        lambda stored_value, param, tracked_param: (
+            None
+            if stored_value is None
+            else jax.lax.cond(
+                tracked_param,
+                lambda ope: ope[0].at[i].set(ope[1]),
+                lambda ope: ope[0],
+                (stored_value, param),
+            )
         ),
         stored_params,
         params,
         tracked_params,
+        is_leaf=lambda x: x is None,
     )
     stored_loss_terms = jax.tree_util.tree_map(
         lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
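Both `_print_fn` and the parameter storage above lean on the same `jax.lax.cond` pattern: a traced boolean selects one of two pure branches over shared operands, which keeps the per-iteration bookkeeping jittable. A minimal standalone illustration of the conditional `.at[i].set()` update (not jinns code):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def maybe_store(buf, i, value, do_store):
        # write value at index i only when do_store is True, otherwise return buf unchanged
        return jax.lax.cond(
            do_store,
            lambda ope: ope[0].at[ope[1]].set(ope[2]),
            lambda ope: ope[0],
            (buf, i, value),
        )

    buf = jnp.zeros((4,))
    buf = maybe_store(buf, 2, 7.0, True)   # -> [0., 0., 7., 0.]
    buf = maybe_store(buf, 3, 9.0, False)  # -> unchanged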
@@ -473,16 +509,20 @@ def store_loss_and_params(
     return (stored_params, stored_loss_terms, train_loss_values)
 
 
-def
+def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
     """
-    Wrapper to get the break_fun with appropriate `n_iter
+    Wrapper to get the break_fun with appropriate `n_iter`.
+    The verbose argument is here to control printing (or not) when exiting
+    the optimisation loop. It can be convenient is jinns.solve is itself
+    called in a loop and user want to avoid std output.
     """
 
     @jit
-    def break_fun(carry):
+    def break_fun(carry: tuple):
         """
-        Function to break from the main optimization loop
-
+        Function to break from the main optimization loop whe the following
+        conditions are met : maximum number of iterations, NaN
+        appearing in the parameters, and early stopping criterion.
         """
 
         def stop_while_loop(msg):
@@ -490,7 +530,8 @@ def get_break_fun(n_iter):
             Note that the message is wrapped in the jax.lax.cond because a
             string is not a valid JAX type that can be fed into the operands
             """
-
+            if verbose:
+                jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
             return False
 
         def continue_while_loop(_):
@@ -531,43 +572,57 @@ def get_break_fun(n_iter):
     return break_fun
 
 
-def
+def _get_get_batch(
+    obs_batch_sharding: jax.sharding.Sharding,
+) -> Callable[
+    [
+        AnyDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+    ],
+    tuple[
+        AnyBatch,
+        AnyDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+    ],
+]:
     """
     Return the get_batch function that will be used either the jittable one or
-    the non-jittable one with sharding
+    the non-jittable one with sharding using jax.device.put()
     """
 
     def get_batch_sharding(data, param_data, obs_data):
         """
         This function is used at each loop but it cannot be jitted because of
         device_put
-
-        Note: return all that's modified or unwanted dirty undefined behaviour
         """
-        batch = data.get_batch()
+        data, batch = data.get_batch()
         if param_data is not None:
-
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
         if obs_data is not None:
             # This is the part that motivated the transition from scan to for loop
             # Indeed we need to be transit obs_batch from CPU to GPU when we have
             # huge observations that cannot fit on GPU. Such transfer wasn't meant
             # to be jitted, i.e. in a scan loop
-            obs_batch =
+            obs_data, obs_batch = obs_data.get_batch()
+            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
             batch = append_obs_batch(batch, obs_batch)
         return batch, data, param_data, obs_data
 
     @jit
     def get_batch(data, param_data, obs_data):
         """
-        Original get_batch with
-
-        Note: return all that's modified or unwanted dirty undefined behaviour
+        Original get_batch with no sharding
         """
-        batch = data.get_batch()
+        data, batch = data.get_batch()
         if param_data is not None:
-
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
        if obs_data is not None:
-
+            obs_data, obs_batch = obs_data.get_batch()
+            batch = append_obs_batch(batch, obs_batch)
         return batch, data, param_data, obs_data
 
     if obs_batch_sharding is not None:
jinns/utils/__init__.py
CHANGED
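The hunk below trims the public surface of `jinns.utils`: the architecture classes are now exported next to their factory functions, while `log_euler_maruyama_density` and the `_optim` helpers (`alternate_optimizer`, `delayed_optimizer`) disappear. The imports that remain in 1.0.0, as read from the added lines:

    from jinns.utils import create_PINN, PINN
    from jinns.utils import create_SPINN, SPINN
    from jinns.utils import create_HYPERPINN, HYPERPINN
    from jinns.utils import save_pinn, load_pinn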
@@ -1,10 +1,4 @@
-from .
-
-
-    log_euler_maruyama_density,
-)
-from ._pinn import create_PINN
-from ._spinn import create_SPINN
-from ._hyperpinn import create_HYPERPINN
-from ._optim import alternate_optimizer, delayed_optimizer
+from ._pinn import create_PINN, PINN
+from ._spinn import create_SPINN, SPINN
+from ._hyperpinn import create_HYPERPINN, HYPERPINN
 from ._save_load import save_pinn, load_pinn
|