jinns 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
- jinns/__init__.py +1 -0
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- jinns/loss/_LossODE.py +6 -0
- jinns/loss/_LossPDE.py +18 -18
- jinns/solver/_solve.py +264 -121
- jinns/utils/_containers.py +57 -0
- jinns/validation/__init__.py +1 -0
- jinns/validation/_validation.py +214 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/METADATA +1 -1
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/RECORD +16 -11
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/LICENSE +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/WHEEL +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py CHANGED

@@ -3,9 +3,9 @@ This modules implements the main `solve()` function of jinns which
 handles the optimization process
 """

+import copy
+from functools import partial
 import optax
-from tqdm import tqdm
-from jax_tqdm import scan_tqdm
 import jax
 from jax import jit
 import jax.numpy as jnp
@@ -19,8 +19,32 @@ from jinns.data._DataGenerators import (
     append_param_batch,
     append_obs_batch,
 )
+from jinns.utils._containers import *

-
+
+def check_batch_size(other_data, main_data, attr_name):
+    if (
+        (
+            isinstance(main_data, DataGeneratorODE)
+            and getattr(other_data, attr_name) != main_data.temporal_batch_size
+        )
+        or (
+            isinstance(main_data, CubicMeshPDEStatio)
+            and not isinstance(main_data, CubicMeshPDENonStatio)
+            and getattr(other_data, attr_name) != main_data.omega_batch_size
+        )
+        or (
+            isinstance(main_data, CubicMeshPDENonStatio)
+            and getattr(other_data, attr_name)
+            != main_data.omega_batch_size * main_data.temporal_batch_size
+        )
+    ):
+        raise ValueError(
+            "Optional other_data.param_batch_size must be"
+            " equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
+            " the product of both dependeing on the type of the main"
+            " datagenerator"
+        )


 def solve(
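Note (not part of the diff): the refactored check enforces one expected batch size per data-generator type; for the non-stationary PDE case it is the product of the spatial and temporal batch sizes. A minimal worked example with illustrative numbers (not jinns defaults):

omega_batch_size = 32        # spatial batch size of the main data generator
temporal_batch_size = 8      # temporal batch size of the main data generator
# For a CubicMeshPDENonStatio, an optional generator must match the product:
expected_batch_size = omega_batch_size * temporal_batch_size
assert expected_batch_size == 256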
@@ -35,6 +59,7 @@ def solve(
     tracked_params_key_list=None,
     param_data=None,
     obs_data=None,
+    validation=None,
     obs_batch_sharding=None,
 ):
     """
@@ -88,6 +113,20 @@ def solve(
     obs_data
         Default None. A DataGeneratorObservations object which can be used to
         sample minibatches of observations
+    validation
+        Default None. Otherwise, a callable ``eqx.Module`` which implements a
+        validation strategy. See documentation of :obj:`~jinns.validation.
+        _validation.AbstractValidationModule` for the general interface, and
+        :obj:`~jinns.validation._validation.ValidationLoss` for a practical
+        implementation of a vanilla validation stategy on a validation set of
+        collocation points.
+
+        **Note**: The ``__call__(self, params)`` method should have
+        the latter prescribed signature and return ``(validation [eqx.Module],
+        early_stop [bool], validation_criterion [Array])``. It is called every
+        ``validation.call_every`` iteration. Users are free to design any
+        validation strategy of their choice, and to decide on the early
+        stopping criterion.
     obs_batch_sharding
         Default None. An optional sharding object to constraint the obs_batch.
         Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
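Note (not part of the diff): to make the documented interface concrete, below is a minimal sketch of a user-defined validation module with patience-based early stopping. Only the ``call_every`` attribute and the ``__call__(self, params)`` contract come from the docstring above; the class name ``PatienceValidation`` and the fields ``patience``, ``counter``, ``best`` and ``val_loss_fn`` are invented for the illustration. In practice one would rather subclass ``AbstractValidationModule`` or use the provided ``ValidationLoss``.

import jax
import jax.numpy as jnp
import equinox as eqx
from typing import Callable


class PatienceValidation(eqx.Module):
    # Hypothetical module, not jinns code: it only illustrates the interface
    # expected by solve(): a call_every attribute and a __call__(self, params)
    # returning (updated module, early_stop, validation_criterion).
    call_every: int
    patience: int
    counter: jax.Array
    best: jax.Array
    val_loss_fn: Callable  # any callable params -> scalar criterion

    def __call__(self, params):
        criterion = self.val_loss_fn(params)
        improved = criterion < self.best
        new_counter = jnp.where(improved, 0, self.counter + 1)
        new_best = jnp.minimum(criterion, self.best)
        # eqx.Module is immutable, so return an updated copy of self
        new_self = PatienceValidation(
            call_every=self.call_every,
            patience=self.patience,
            counter=new_counter,
            best=new_best,
            val_loss_fn=self.val_loss_fn,
        )
        early_stop = new_counter >= self.patience
        return new_self, early_stop, criterion


validation = PatienceValidation(
    call_every=100,
    patience=10,
    counter=jnp.array(0),
    best=jnp.array(jnp.inf),
    val_loss_fn=lambda params: jnp.sum(params["theta"] ** 2),  # placeholder
)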
@@ -114,39 +153,14 @@ def solve(
         A dictionary. At each key an array of the values of the parameters
         given in tracked_params_key_list is stored
     """
-    params = init_params
-    last_non_nan_params = init_params.copy()
-
     if param_data is not None:
-
-
-
-
-            and obs_data.obs_batch_size != data.temporal_batch_size
-        )
-        or (
-            isinstance(data, CubicMeshPDEStatio)
-            and not isinstance(data, CubicMeshPDENonStatio)
-            and param_data.param_batch_size != data.omega_batch_size
-            and obs_data.obs_batch_size != data.omega_batch_size
-        )
-        or (
-            isinstance(data, CubicMeshPDENonStatio)
-            and param_data.param_batch_size
-            != data.omega_batch_size * data.temporal_batch_size
-            and obs_data.obs_batch_size
-            != data.omega_batch_size * data.temporal_batch_size
-        )
-    ):
-        raise ValueError(
-            "Optional param_data.param_batch_size must be"
-            " equal to data.temporal_batch_size or data.omega_batch_size or"
-            " the product of both dependeing on the type of the main"
-            " datagenerator"
-        )
+        check_batch_size(param_data, data, "param_batch_size")
+
+    if obs_data is not None:
+        check_batch_size(obs_data, data, "obs_batch_size")

     if opt_state is None:
-        opt_state = optimizer.init(
+        opt_state = optimizer.init(init_params)

     # RAR sampling init (ouside scanned function to avoid dynamic slice error)
     # If RAR is not used the _rar_step_*() are juste None and data is unchanged
@@ -160,7 +174,7 @@ def solve(
     ), "data.method must be uniform if using seq2seq learning !"
     data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)

-
+    train_loss_values = jnp.zeros((n_iter))
     # depending on obs_batch_sharding we will get the simple get_batch or the
     # get_batch with device_put, the latter is not jittable
     get_batch = get_get_batch(obs_batch_sharding)
@@ -168,68 +182,125 @@ def solve(
     # initialize the dict for stored parameter values
     # we need to get a loss_term to init stuff
     batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
-    _, loss_terms = loss(
+    _, loss_terms = loss(init_params, batch_ini)
     if tracked_params_key_list is None:
         tracked_params_key_list = []
-    tracked_params = _tracked_parameters(
+    tracked_params = _tracked_parameters(init_params, tracked_params_key_list)
     stored_params = jax.tree_util.tree_map(
         lambda tracked_param, param: (
             jnp.zeros((n_iter,) + param.shape) if tracked_param else None
         ),
         tracked_params,
-
+        init_params,
     )

     # initialize the dict for stored loss values
     stored_loss_terms = jax.tree_util.tree_map(
-        lambda
+        lambda _: jnp.zeros((n_iter)), loss_terms
+    )
+
+    train_data = DataGeneratorContainer(
+        data=data, param_data=param_data, obs_data=obs_data
+    )
+    optimization = OptimizationContainer(
+        params=init_params, last_non_nan_params=init_params.copy(), opt_state=opt_state
+    )
+    optimization_extra = OptimizationExtraContainer(
+        curr_seq=curr_seq,
+        seq2seq=seq2seq,
     )
+    loss_container = LossContainer(
+        stored_loss_terms=stored_loss_terms,
+        train_loss_values=train_loss_values,
+    )
+    stored_objects = StoredObjectContainer(
+        stored_params=stored_params,
+    )
+
+    if validation is not None:
+        validation_crit_values = jnp.zeros(n_iter)
+    else:
+        validation_crit_values = None
+
+    break_fun = get_break_fun(n_iter)

+    iteration = 0
     carry = (
-
-        init_params.copy(),
-        data,
-        curr_seq,
-        seq2seq,
-        stored_params,
-        stored_loss_terms,
+        iteration,
         loss,
-
-
-
-
+        optimization,
+        optimization_extra,
+        train_data,
+        validation,
+        loss_container,
+        stored_objects,
+        validation_crit_values,
     )

-    def one_iteration(carry
+    def one_iteration(carry):
         (
-
-            last_non_nan_params,
-            data,
-            curr_seq,
-            seq2seq,
-            stored_params,
-            stored_loss_terms,
+            i,
             loss,
-
-
-
-
+            optimization,
+            optimization_extra,
+            train_data,
+            validation,
+            loss_container,
+            stored_objects,
+            validation_crit_values,
         ) = carry
-        batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)

+        batch, data, param_data, obs_data = get_batch(
+            train_data.data, train_data.param_data, train_data.obs_data
+        )
+
+        # Gradient step
         (
             loss,
-
+            train_loss_value,
             loss_terms,
             params,
             opt_state,
             last_non_nan_params,
         ) = gradient_step(
-            loss,
+            loss,
+            optimizer,
+            batch,
+            optimization.params,
+            optimization.opt_state,
+            optimization.last_non_nan_params,
         )

-        # Print loss during optimization
-        print_fn(i,
+        # Print train loss value during optimization
+        print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
+
+        if validation is not None:
+            # there is a jax.lax.cond because we do not necesarily call the
+            # validation step every iteration
+            (
+                validation,  # always return `validation` for in-place mutation
+                early_stopping,
+                validation_criterion,
+            ) = jax.lax.cond(
+                i % validation.call_every == 0,
+                lambda operands: operands[0](*operands[1:]),  # validation.__call__()
+                lambda operands: (
+                    operands[0],
+                    False,
+                    validation_crit_values[i - 1],
+                ),
+                (
+                    validation,  # validation must be in operands
+                    params,
+                ),
+            )
+            # Print validation loss value during optimization
+            print_fn(i, validation_criterion, print_loss_every, prefix="[validation] ")
+            validation_crit_values = validation_crit_values.at[i].set(
+                validation_criterion
+            )
+        else:
+            early_stopping = False

         # Trigger RAR
         loss, params, data = trigger_rar(
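Note (not part of the diff): the ``jax.lax.cond`` above lets the jitted loop body evaluate the validation module only every ``call_every`` iterations; both branches must return pytrees of identical structure, which is why the "skip" branch forwards the module unchanged together with the previously stored criterion. A stripped-down sketch of the same pattern with made-up names (not jinns code):

import jax
import jax.numpy as jnp


def maybe_validate(i, call_every, params, previous_value, validate_fn):
    # Run validate_fn(params) only when i is a multiple of call_every,
    # otherwise reuse the previously stored value (same dtype/shape).
    return jax.lax.cond(
        i % call_every == 0,
        lambda p: validate_fn(p),
        lambda p: previous_value,
        params,
    )


value = maybe_validate(
    i=jnp.array(10),
    call_every=5,
    params=jnp.array([1.0, 2.0]),
    previous_value=jnp.array(0.0),
    validate_fn=lambda p: jnp.sum(p**2),
)

In the actual code the validation module itself is also part of the operands and of both branch outputs, so that its internal state can be updated in a purely functional way.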
@@ -238,84 +309,98 @@ def solve(

         # Trigger seq2seq
         loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
-            i,
+            i,
+            loss,
+            params,
+            data,
+            opt_state,
+            optimization_extra.curr_seq,
+            optimization_extra.seq2seq,
         )

         # save loss value and selected parameters
-        stored_params, stored_loss_terms,
+        stored_params, stored_loss_terms, train_loss_values = store_loss_and_params(
             i,
             params,
-            stored_params,
-            stored_loss_terms,
-
-
+            stored_objects.stored_params,
+            loss_container.stored_loss_terms,
+            loss_container.train_loss_values,
+            train_loss_value,
             loss_terms,
             tracked_params,
         )
+        i += 1
+
         return (
-
-            last_non_nan_params,
-            data,
-            curr_seq,
-            seq2seq,
-            stored_params,
-            stored_loss_terms,
+            i,
             loss,
-
-
-
-
-
+            OptimizationContainer(params, last_non_nan_params, opt_state),
+            OptimizationExtraContainer(curr_seq, seq2seq, early_stopping),
+            DataGeneratorContainer(data, param_data, obs_data),
+            validation,
+            LossContainer(stored_loss_terms, train_loss_values),
+            StoredObjectContainer(stored_params),
+            validation_crit_values,
+        )

-    # Main optimization loop. We use the
-    # if no mixing devices. Otherwise we use the
+    # Main optimization loop. We use the LAX while loop (fully jitted) version
+    # if no mixing devices. Otherwise we use the standard while loop. Here devices only
     # concern obs_batch, but it could lead to more complex scheme in the future
     if obs_batch_sharding is not None:
-
-        carry
+        while break_fun(carry):
+            carry = one_iteration(carry)
     else:
-        carry
-            scan_tqdm(n_iter)(one_iteration),
-            carry,
-            jnp.arange(n_iter),
-        )
+        carry = jax.lax.while_loop(break_fun, one_iteration, carry)

     (
-
-        last_non_nan_params,
-        data,
-        curr_seq,
-        seq2seq,
-        stored_params,
-        stored_loss_terms,
+        i,
         loss,
-
-
-
-
+        optimization,
+        optimization_extra,
+        train_data,
+        validation,
+        loss_container,
+        stored_objects,
+        validation_crit_values,
     ) = carry

     jax.debug.print(
-        "
-        i=
-
+        "Final iteration {i}: train loss value = {train_loss_val}",
+        i=i,
+        train_loss_val=loss_container.train_loss_values[i - 1],
     )
+    if validation is not None:
+        jax.debug.print(
+            "validation loss value = {validation_loss_val}",
+            validation_loss_val=validation_crit_values[i - 1],
+        )

+    if validation is None:
+        return (
+            optimization.last_non_nan_params,
+            loss_container.train_loss_values,
+            loss_container.stored_loss_terms,
+            train_data.data,
+            loss,
+            optimization.opt_state,
+            stored_objects.stored_params,
+        )
     return (
-        last_non_nan_params,
-
-        stored_loss_terms,
-        data,
+        optimization.last_non_nan_params,
+        loss_container.train_loss_values,
+        loss_container.stored_loss_terms,
+        train_data.data,
         loss,
-        opt_state,
-        stored_params,
+        optimization.opt_state,
+        stored_objects.stored_params,
+        validation_crit_values,
     )


 @partial(jit, static_argnames=["optimizer"])
 def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
     """
-
+    optimizer cannot be jit-ted.
     """
     value_grad_loss = jax.value_and_grad(loss, has_aux=True)
     (loss_val, loss_terms), grads = value_grad_loss(params, batch)
@@ -340,14 +425,14 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
     )


-@jit
-def print_fn(i, loss_val, print_loss_every):
+@partial(jit, static_argnames=["prefix"])
+def print_fn(i, loss_val, print_loss_every, prefix=""):
     # note that if the following is not jitted in the main lor loop, it is
     # super slow
     _ = jax.lax.cond(
         i % print_loss_every == 0,
         lambda _: jax.debug.print(
-            "Iteration {i}: loss value = {loss_val}",
+            prefix + "Iteration {i}: loss value = {loss_val}",
             i=i,
             loss_val=loss_val,
         ),
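Note (not part of the diff): ``prefix`` is a Python string and strings are not valid JAX types, hence ``static_argnames=["prefix"]``: each distinct prefix is baked in at trace time (one retrace per value) instead of being traced. A minimal standalone illustration (assumed example, not jinns code):

import jax
from functools import partial


@partial(jax.jit, static_argnames=["prefix"])
def tagged_print(i, loss_val, prefix=""):
    # prefix is a compile-time constant, so the string concatenation below
    # happens at trace time; only i and loss_val are traced values.
    jax.debug.print(
        prefix + "Iteration {i}: loss value = {loss_val}", i=i, loss_val=loss_val
    )
    return loss_val


tagged_print(0, 1.5, prefix="[train] ")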
@@ -362,8 +447,8 @@ def store_loss_and_params(
     params,
     stored_params,
     stored_loss_terms,
-
-
+    train_loss_values,
+    train_loss_val,
     loss_terms,
     tracked_params,
 ):
@@ -384,8 +469,66 @@ def store_loss_and_params(
         loss_terms,
     )

-
-    return stored_params, stored_loss_terms,
+    train_loss_values = train_loss_values.at[i].set(train_loss_val)
+    return (stored_params, stored_loss_terms, train_loss_values)
+
+
+def get_break_fun(n_iter):
+    """
+    Wrapper to get the break_fun with appropriate `n_iter`
+    """
+
+    @jit
+    def break_fun(carry):
+        """
+        Function to break from the main optimization loop
+        We check several conditions
+        """
+
+        def stop_while_loop(msg):
+            """
+            Note that the message is wrapped in the jax.lax.cond because a
+            string is not a valid JAX type that can be fed into the operands
+            """
+            jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
+            return False
+
+        def continue_while_loop(_):
+            return True
+
+        (i, _, optimization, optimization_extra, _, _, _, _, _) = carry
+
+        # Condition 1
+        bool_max_iter = jax.lax.cond(
+            i >= n_iter,
+            lambda _: stop_while_loop("max iteration is reached"),
+            continue_while_loop,
+            None,
+        )
+        # Condition 2
+        bool_nan_in_params = jax.lax.cond(
+            _check_nan_in_pytree(optimization.params),
+            lambda _: stop_while_loop(
+                "NaN values in parameters " "(returning last non NaN values)"
+            ),
+            continue_while_loop,
+            None,
+        )
+        # Condition 3
+        bool_early_stopping = jax.lax.cond(
+            optimization_extra.early_stopping,
+            lambda _: stop_while_loop("early stopping"),
+            continue_while_loop,
+            _,
+        )
+
+        # stop when one of the cond to continue is False
+        return jax.tree_util.tree_reduce(
+            lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
+            (bool_max_iter, bool_nan_in_params, bool_early_stopping),
+        )
+
+    return break_fun


 def get_get_batch(obs_batch_sharding):
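Note (not part of the diff): ``break_fun`` is the condition function handed to ``jax.lax.while_loop`` in the main loop above; it must return a scalar boolean that stays True while the loop should continue, which is why the three conditions are reduced with a logical AND. A toy version of the same pattern with made-up names (not jinns code):

import jax
import jax.numpy as jnp


def cond_fun(carry):
    # Continue while the iteration budget is not exhausted AND the state
    # holds no NaN (mirrors conditions 1 and 2 of break_fun).
    i, x = carry
    return jnp.logical_and(i < 100, jnp.logical_not(jnp.any(jnp.isnan(x))))


def body_fun(carry):
    i, x = carry
    return i + 1, 0.9 * x


final_i, final_x = jax.lax.while_loop(cond_fun, body_fun, (jnp.array(0), jnp.ones(3)))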
jinns/utils/_containers.py ADDED

@@ -0,0 +1,57 @@
+"""
+NamedTuples definition
+"""
+
+from typing import Union, NamedTuple
+from jaxtyping import PyTree
+from jax.typing import ArrayLike
+import optax
+import jax.numpy as jnp
+from jinns.loss._LossODE import LossODE, SystemLossODE
+from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
+from jinns.data._DataGenerators import (
+    DataGeneratorODE,
+    CubicMeshPDEStatio,
+    CubicMeshPDENonStatio,
+    DataGeneratorParameter,
+    DataGeneratorObservations,
+    DataGeneratorObservationsMultiPINNs,
+)
+
+
+class DataGeneratorContainer(NamedTuple):
+    data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
+    param_data: Union[DataGeneratorParameter, None] = None
+    obs_data: Union[
+        DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
+    ] = None
+
+
+class ValidationContainer(NamedTuple):
+    loss: Union[
+        LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
+    ]
+    data: DataGeneratorContainer
+    hyperparams: PyTree = None
+    loss_values: Union[ArrayLike, None] = None
+
+
+class OptimizationContainer(NamedTuple):
+    params: dict
+    last_non_nan_params: dict
+    opt_state: optax.OptState
+
+
+class OptimizationExtraContainer(NamedTuple):
+    curr_seq: int
+    seq2seq: Union[dict, None]
+    early_stopping: bool = False
+
+
+class LossContainer(NamedTuple):
+    stored_loss_terms: dict
+    train_loss_values: ArrayLike
+
+
+class StoredObjectContainer(NamedTuple):
+    stored_params: Union[list, None]
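Note (not part of the diff): these containers are plain ``NamedTuple``s, so they traverse the jitted training loop as ordinary pytrees while giving named access in the code. A small illustration of how they are built and read, assuming jinns 0.8.8 is installed (placeholder values, not jinns defaults):

import jax.numpy as jnp
from jinns.utils._containers import OptimizationContainer, LossContainer

n_iter = 100
init_params = {"nn_params": jnp.zeros(3)}

optimization = OptimizationContainer(
    params=init_params,
    last_non_nan_params=init_params.copy(),
    opt_state=None,  # placeholder; solve() passes optimizer.init(init_params)
)
loss_container = LossContainer(
    stored_loss_terms={"dyn_loss": jnp.zeros(n_iter)},
    train_loss_values=jnp.zeros(n_iter),
)

print(optimization.params, loss_container.train_loss_values.shape)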
jinns/validation/__init__.py ADDED

@@ -0,0 +1 @@
+from ._validation import AbstractValidationModule, ValidationLoss
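Note (not part of the diff): this one-line ``__init__.py`` is what makes the new validation API importable directly from the subpackage:

from jinns.validation import AbstractValidationModule, ValidationLoss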