jinns-0.8.6-py3-none-any.whl → jinns-0.8.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -3,9 +3,9 @@ This modules implements the main `solve()` function of jinns which
  handles the optimization process
  """

+ import copy
+ from functools import partial
  import optax
- from tqdm import tqdm
- from jax_tqdm import scan_tqdm
  import jax
  from jax import jit
  import jax.numpy as jnp
@@ -19,8 +19,32 @@ from jinns.data._DataGenerators import (
      append_param_batch,
      append_obs_batch,
  )
+ from jinns.utils._containers import *

- from functools import partial
+
+ def check_batch_size(other_data, main_data, attr_name):
+     if (
+         (
+             isinstance(main_data, DataGeneratorODE)
+             and getattr(other_data, attr_name) != main_data.temporal_batch_size
+         )
+         or (
+             isinstance(main_data, CubicMeshPDEStatio)
+             and not isinstance(main_data, CubicMeshPDENonStatio)
+             and getattr(other_data, attr_name) != main_data.omega_batch_size
+         )
+         or (
+             isinstance(main_data, CubicMeshPDENonStatio)
+             and getattr(other_data, attr_name)
+             != main_data.omega_batch_size * main_data.temporal_batch_size
+         )
+     ):
+         raise ValueError(
+             "Optional other_data.param_batch_size must be"
+             " equal to main_data.temporal_batch_size or main_data.omega_batch_size or"
+             " the product of both dependeing on the type of the main"
+             " datagenerator"
+         )


  def solve(
@@ -35,6 +59,7 @@ def solve(
      tracked_params_key_list=None,
      param_data=None,
      obs_data=None,
+     validation=None,
      obs_batch_sharding=None,
  ):
      """
@@ -88,6 +113,20 @@ def solve(
      obs_data
          Default None. A DataGeneratorObservations object which can be used to
          sample minibatches of observations
+     validation
+         Default None. Otherwise, a callable ``eqx.Module`` which implements a
+         validation strategy. See documentation of :obj:`~jinns.validation.
+         _validation.AbstractValidationModule` for the general interface, and
+         :obj:`~jinns.validation._validation.ValidationLoss` for a practical
+         implementation of a vanilla validation stategy on a validation set of
+         collocation points.
+
+         **Note**: The ``__call__(self, params)`` method should have
+         the latter prescribed signature and return ``(validation [eqx.Module],
+         early_stop [bool], validation_criterion [Array])``. It is called every
+         ``validation.call_every`` iteration. Users are free to design any
+         validation strategy of their choice, and to decide on the early
+         stopping criterion.
      obs_batch_sharding
          Default None. An optional sharding object to constraint the obs_batch.
          Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
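For orientation, a module satisfying the validation contract documented above can be quite small. The sketch below is illustrative only and is not code shipped in the package: the class name, the patience/counter/best_criterion fields and the validation_loss_fn/validation_batch attributes are hypothetical, while `call_every` and the `__call__(self, params) -> (module, early_stop, criterion)` return contract follow the docstring; the package's own ValidationLoss is the practical implementation referenced there.

from typing import Callable

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array


class PatienceValidation(eqx.Module):
    # hypothetical validation module; only `call_every` and the
    # __call__(self, params) return contract come from the docstring above
    call_every: int
    patience: int
    counter: Array
    best_criterion: Array
    validation_loss_fn: Callable  # e.g. a loss evaluated on held-out collocation points
    validation_batch: Array

    def __call__(self, params):
        criterion = self.validation_loss_fn(params, self.validation_batch)
        improved = criterion < self.best_criterion
        counter = jnp.where(improved, 0, self.counter + 1)
        best_criterion = jnp.minimum(criterion, self.best_criterion)
        # return an updated copy of the module (eqx.Module instances are
        # immutable), the early stopping flag and the validation criterion
        new_self = eqx.tree_at(
            lambda m: (m.counter, m.best_criterion),
            self,
            (counter, best_criterion),
        )
        early_stop = counter >= self.patience
        return new_self, early_stop, criterion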
@@ -114,39 +153,14 @@ def solve(
          A dictionary. At each key an array of the values of the parameters
          given in tracked_params_key_list is stored
      """
-     params = init_params
-     last_non_nan_params = init_params.copy()
-
      if param_data is not None:
-         if (
-             (
-                 isinstance(data, DataGeneratorODE)
-                 and param_data.param_batch_size != data.temporal_batch_size
-                 and obs_data.obs_batch_size != data.temporal_batch_size
-             )
-             or (
-                 isinstance(data, CubicMeshPDEStatio)
-                 and not isinstance(data, CubicMeshPDENonStatio)
-                 and param_data.param_batch_size != data.omega_batch_size
-                 and obs_data.obs_batch_size != data.omega_batch_size
-             )
-             or (
-                 isinstance(data, CubicMeshPDENonStatio)
-                 and param_data.param_batch_size
-                 != data.omega_batch_size * data.temporal_batch_size
-                 and obs_data.obs_batch_size
-                 != data.omega_batch_size * data.temporal_batch_size
-             )
-         ):
-             raise ValueError(
-                 "Optional param_data.param_batch_size must be"
-                 " equal to data.temporal_batch_size or data.omega_batch_size or"
-                 " the product of both dependeing on the type of the main"
-                 " datagenerator"
-             )
+         check_batch_size(param_data, data, "param_batch_size")
+
+     if obs_data is not None:
+         check_batch_size(obs_data, data, "obs_batch_size")

      if opt_state is None:
-         opt_state = optimizer.init(params)
+         opt_state = optimizer.init(init_params)

      # RAR sampling init (ouside scanned function to avoid dynamic slice error)
      # If RAR is not used the _rar_step_*() are juste None and data is unchanged
@@ -160,7 +174,7 @@ def solve(
          ), "data.method must be uniform if using seq2seq learning !"
          data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)

-     total_loss_values = jnp.zeros((n_iter))
+     train_loss_values = jnp.zeros((n_iter))
      # depending on obs_batch_sharding we will get the simple get_batch or the
      # get_batch with device_put, the latter is not jittable
      get_batch = get_get_batch(obs_batch_sharding)
@@ -168,68 +182,125 @@ def solve(
      # initialize the dict for stored parameter values
      # we need to get a loss_term to init stuff
      batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
-     _, loss_terms = loss(params, batch_ini)
+     _, loss_terms = loss(init_params, batch_ini)
      if tracked_params_key_list is None:
          tracked_params_key_list = []
-     tracked_params = _tracked_parameters(params, tracked_params_key_list)
+     tracked_params = _tracked_parameters(init_params, tracked_params_key_list)
      stored_params = jax.tree_util.tree_map(
          lambda tracked_param, param: (
              jnp.zeros((n_iter,) + param.shape) if tracked_param else None
          ),
          tracked_params,
-         params,
+         init_params,
      )

      # initialize the dict for stored loss values
      stored_loss_terms = jax.tree_util.tree_map(
-         lambda x: jnp.zeros((n_iter)), loss_terms
+         lambda _: jnp.zeros((n_iter)), loss_terms
+     )
+
+     train_data = DataGeneratorContainer(
+         data=data, param_data=param_data, obs_data=obs_data
+     )
+     optimization = OptimizationContainer(
+         params=init_params, last_non_nan_params=init_params.copy(), opt_state=opt_state
+     )
+     optimization_extra = OptimizationExtraContainer(
+         curr_seq=curr_seq,
+         seq2seq=seq2seq,
      )
+     loss_container = LossContainer(
+         stored_loss_terms=stored_loss_terms,
+         train_loss_values=train_loss_values,
+     )
+     stored_objects = StoredObjectContainer(
+         stored_params=stored_params,
+     )
+
+     if validation is not None:
+         validation_crit_values = jnp.zeros(n_iter)
+     else:
+         validation_crit_values = None
+
+     break_fun = get_break_fun(n_iter)

+     iteration = 0
      carry = (
-         init_params,
-         init_params.copy(),
-         data,
-         curr_seq,
-         seq2seq,
-         stored_params,
-         stored_loss_terms,
+         iteration,
          loss,
-         param_data,
-         obs_data,
-         opt_state,
-         total_loss_values,
+         optimization,
+         optimization_extra,
+         train_data,
+         validation,
+         loss_container,
+         stored_objects,
+         validation_crit_values,
      )

-     def one_iteration(carry, i):
+     def one_iteration(carry):
          (
-             params,
-             last_non_nan_params,
-             data,
-             curr_seq,
-             seq2seq,
-             stored_params,
-             stored_loss_terms,
+             i,
              loss,
-             param_data,
-             obs_data,
-             opt_state,
-             total_loss_values,
+             optimization,
+             optimization_extra,
+             train_data,
+             validation,
+             loss_container,
+             stored_objects,
+             validation_crit_values,
          ) = carry
-         batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)

+         batch, data, param_data, obs_data = get_batch(
+             train_data.data, train_data.param_data, train_data.obs_data
+         )
+
+         # Gradient step
          (
              loss,
-             loss_val,
+             train_loss_value,
              loss_terms,
              params,
              opt_state,
              last_non_nan_params,
          ) = gradient_step(
-             loss, optimizer, batch, params, opt_state, last_non_nan_params
+             loss,
+             optimizer,
+             batch,
+             optimization.params,
+             optimization.opt_state,
+             optimization.last_non_nan_params,
          )

-         # Print loss during optimization
-         print_fn(i, loss_val, print_loss_every)
+         # Print train loss value during optimization
+         print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
+
+         if validation is not None:
+             # there is a jax.lax.cond because we do not necesarily call the
+             # validation step every iteration
+             (
+                 validation,  # always return `validation` for in-place mutation
+                 early_stopping,
+                 validation_criterion,
+             ) = jax.lax.cond(
+                 i % validation.call_every == 0,
+                 lambda operands: operands[0](*operands[1:]),  # validation.__call__()
+                 lambda operands: (
+                     operands[0],
+                     False,
+                     validation_crit_values[i - 1],
+                 ),
+                 (
+                     validation,  # validation must be in operands
+                     params,
+                 ),
+             )
+             # Print validation loss value during optimization
+             print_fn(i, validation_criterion, print_loss_every, prefix="[validation] ")
+             validation_crit_values = validation_crit_values.at[i].set(
+                 validation_criterion
+             )
+         else:
+             early_stopping = False

          # Trigger RAR
          loss, params, data = trigger_rar(
@@ -238,84 +309,98 @@ def solve(

          # Trigger seq2seq
          loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
-             i, loss, params, data, opt_state, curr_seq, seq2seq
+             i,
+             loss,
+             params,
+             data,
+             opt_state,
+             optimization_extra.curr_seq,
+             optimization_extra.seq2seq,
          )

          # save loss value and selected parameters
-         stored_params, stored_loss_terms, total_loss_values = store_loss_and_params(
+         stored_params, stored_loss_terms, train_loss_values = store_loss_and_params(
              i,
              params,
-             stored_params,
-             stored_loss_terms,
-             total_loss_values,
-             loss_val,
+             stored_objects.stored_params,
+             loss_container.stored_loss_terms,
+             loss_container.train_loss_values,
+             train_loss_value,
              loss_terms,
              tracked_params,
          )
+         i += 1
+
          return (
-             params,
-             last_non_nan_params,
-             data,
-             curr_seq,
-             seq2seq,
-             stored_params,
-             stored_loss_terms,
+             i,
              loss,
-             param_data,
-             obs_data,
-             opt_state,
-             total_loss_values,
-         ), None
+             OptimizationContainer(params, last_non_nan_params, opt_state),
+             OptimizationExtraContainer(curr_seq, seq2seq, early_stopping),
+             DataGeneratorContainer(data, param_data, obs_data),
+             validation,
+             LossContainer(stored_loss_terms, train_loss_values),
+             StoredObjectContainer(stored_params),
+             validation_crit_values,
+         )

-     # Main optimization loop. We use the fully scanned (fully jitted) version
-     # if no mixing devices. Otherwise we use the for loop. Here devices only
+     # Main optimization loop. We use the LAX while loop (fully jitted) version
+     # if no mixing devices. Otherwise we use the standard while loop. Here devices only
      # concern obs_batch, but it could lead to more complex scheme in the future
      if obs_batch_sharding is not None:
-         for i in tqdm(range(n_iter)):
-             carry, _ = one_iteration(carry, i)
+         while break_fun(carry):
+             carry = one_iteration(carry)
      else:
-         carry, _ = jax.lax.scan(
-             scan_tqdm(n_iter)(one_iteration),
-             carry,
-             jnp.arange(n_iter),
-         )
+         carry = jax.lax.while_loop(break_fun, one_iteration, carry)

      (
-         init_params,
-         last_non_nan_params,
-         data,
-         curr_seq,
-         seq2seq,
-         stored_params,
-         stored_loss_terms,
+         i,
          loss,
-         param_data,
-         obs_data,
-         opt_state,
-         total_loss_values,
+         optimization,
+         optimization_extra,
+         train_data,
+         validation,
+         loss_container,
+         stored_objects,
+         validation_crit_values,
      ) = carry

      jax.debug.print(
-         "Iteration {i}: loss value = {total_loss_val}",
-         i=n_iter,
-         total_loss_val=total_loss_values[-1],
+         "Final iteration {i}: train loss value = {train_loss_val}",
+         i=i,
+         train_loss_val=loss_container.train_loss_values[i - 1],
      )
+     if validation is not None:
+         jax.debug.print(
+             "validation loss value = {validation_loss_val}",
+             validation_loss_val=validation_crit_values[i - 1],
+         )

+     if validation is None:
+         return (
+             optimization.last_non_nan_params,
+             loss_container.train_loss_values,
+             loss_container.stored_loss_terms,
+             train_data.data,
+             loss,
+             optimization.opt_state,
+             stored_objects.stored_params,
+         )
      return (
-         last_non_nan_params,
-         total_loss_values,
-         stored_loss_terms,
-         data,
+         optimization.last_non_nan_params,
+         loss_container.train_loss_values,
+         loss_container.stored_loss_terms,
+         train_data.data,
          loss,
-         opt_state,
-         stored_params,
+         optimization.opt_state,
+         stored_objects.stored_params,
+         validation_crit_values,
      )


  @partial(jit, static_argnames=["optimizer"])
  def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
      """
-     loss and optimizer cannot be jit-ted.
+     optimizer cannot be jit-ted.
      """
      value_grad_loss = jax.value_and_grad(loss, has_aux=True)
      (loss_val, loss_terms), grads = value_grad_loss(params, batch)
@@ -340,14 +425,14 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
      )


- @jit
- def print_fn(i, loss_val, print_loss_every):
+ @partial(jit, static_argnames=["prefix"])
+ def print_fn(i, loss_val, print_loss_every, prefix=""):
      # note that if the following is not jitted in the main lor loop, it is
      # super slow
      _ = jax.lax.cond(
          i % print_loss_every == 0,
          lambda _: jax.debug.print(
-             "Iteration {i}: loss value = {loss_val}",
+             prefix + "Iteration {i}: loss value = {loss_val}",
              i=i,
              loss_val=loss_val,
          ),
@@ -362,8 +447,8 @@ def store_loss_and_params(
      params,
      stored_params,
      stored_loss_terms,
-     total_loss_values,
-     loss_val,
+     train_loss_values,
+     train_loss_val,
      loss_terms,
      tracked_params,
  ):
@@ -384,8 +469,66 @@ def store_loss_and_params(
          loss_terms,
      )

-     total_loss_values = total_loss_values.at[i].set(loss_val)
-     return stored_params, stored_loss_terms, total_loss_values
+     train_loss_values = train_loss_values.at[i].set(train_loss_val)
+     return (stored_params, stored_loss_terms, train_loss_values)
+
+
+ def get_break_fun(n_iter):
+     """
+     Wrapper to get the break_fun with appropriate `n_iter`
+     """
+
+     @jit
+     def break_fun(carry):
+         """
+         Function to break from the main optimization loop
+         We check several conditions
+         """
+
+         def stop_while_loop(msg):
+             """
+             Note that the message is wrapped in the jax.lax.cond because a
+             string is not a valid JAX type that can be fed into the operands
+             """
+             jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
+             return False
+
+         def continue_while_loop(_):
+             return True
+
+         (i, _, optimization, optimization_extra, _, _, _, _, _) = carry
+
+         # Condition 1
+         bool_max_iter = jax.lax.cond(
+             i >= n_iter,
+             lambda _: stop_while_loop("max iteration is reached"),
+             continue_while_loop,
+             None,
+         )
+         # Condition 2
+         bool_nan_in_params = jax.lax.cond(
+             _check_nan_in_pytree(optimization.params),
+             lambda _: stop_while_loop(
+                 "NaN values in parameters " "(returning last non NaN values)"
+             ),
+             continue_while_loop,
+             None,
+         )
+         # Condition 3
+         bool_early_stopping = jax.lax.cond(
+             optimization_extra.early_stopping,
+             lambda _: stop_while_loop("early stopping"),
+             continue_while_loop,
+             _,
+         )
+
+         # stop when one of the cond to continue is False
+         return jax.tree_util.tree_reduce(
+             lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
+             (bool_max_iter, bool_nan_in_params, bool_early_stopping),
+         )
+
+     return break_fun


  def get_get_batch(obs_batch_sharding):
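Taken together, these changes replace the scanned training loop with a jax.lax.while_loop guarded by break_fun, rename total_loss_values to train_loss_values, and add the optional validation argument; when it is supplied, solve() returns an extra array holding the history of the validation criterion. A rough calling sketch under these assumptions follows (my_loss, train_data, init_params and my_validation are placeholders for objects built with the usual jinns APIs, not values defined here):

import optax
from jinns.solver._solve import solve  # typically re-exported at the package level

tx = optax.adam(learning_rate=1e-3)

# Without validation: the same 7 outputs as before, with the train loss
# history now called train_loss_values.
(
    params,
    train_loss_values,
    stored_loss_terms,
    data,
    loss,
    opt_state,
    stored_params,
) = solve(
    n_iter=5_000,
    init_params=init_params,  # placeholder
    data=train_data,          # placeholder data generator
    optimizer=tx,
    loss=my_loss,             # placeholder jinns loss object
)

# With validation: an 8th output holding the validation criterion values.
(*train_outputs, validation_crit_values) = solve(
    n_iter=5_000,
    init_params=init_params,
    data=train_data,
    optimizer=tx,
    loss=my_loss,
    validation=my_validation,  # a module following the interface documented above
)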
jinns/utils/_containers.py ADDED
@@ -0,0 +1,57 @@
+ """
+ NamedTuples definition
+ """
+
+ from typing import Union, NamedTuple
+ from jaxtyping import PyTree
+ from jax.typing import ArrayLike
+ import optax
+ import jax.numpy as jnp
+ from jinns.loss._LossODE import LossODE, SystemLossODE
+ from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
+ from jinns.data._DataGenerators import (
+     DataGeneratorODE,
+     CubicMeshPDEStatio,
+     CubicMeshPDENonStatio,
+     DataGeneratorParameter,
+     DataGeneratorObservations,
+     DataGeneratorObservationsMultiPINNs,
+ )
+
+
+ class DataGeneratorContainer(NamedTuple):
+     data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
+     param_data: Union[DataGeneratorParameter, None] = None
+     obs_data: Union[
+         DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
+     ] = None
+
+
+ class ValidationContainer(NamedTuple):
+     loss: Union[
+         LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
+     ]
+     data: DataGeneratorContainer
+     hyperparams: PyTree = None
+     loss_values: Union[ArrayLike, None] = None
+
+
+ class OptimizationContainer(NamedTuple):
+     params: dict
+     last_non_nan_params: dict
+     opt_state: optax.OptState
+
+
+ class OptimizationExtraContainer(NamedTuple):
+     curr_seq: int
+     seq2seq: Union[dict, None]
+     early_stopping: bool = False
+
+
+ class LossContainer(NamedTuple):
+     stored_loss_terms: dict
+     train_loss_values: ArrayLike
+
+
+ class StoredObjectContainer(NamedTuple):
+     stored_params: Union[list, None]
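These NamedTuples bundle the training state that the new solve() threads through the jax.lax.while_loop carry. Because NamedTuples are registered JAX pytree nodes, the carry stays compatible with jit, while_loop and cond while its fields remain addressable by name. A small self-contained illustration of that pattern (standalone sketch, not package code):

from typing import NamedTuple

import jax
import jax.numpy as jnp


class LossContainer(NamedTuple):
    # mirrors the container defined above: per-term loss histories plus the
    # global train loss history
    stored_loss_terms: dict
    train_loss_values: jnp.ndarray


loss_container = LossContainer(
    stored_loss_terms={"dyn_loss": jnp.zeros(10)},
    train_loss_values=jnp.zeros(10),
)

# NamedTuples are pytree nodes, so they can travel through jax.lax.while_loop
# or jax.lax.cond carries and be mapped over with jax.tree_util
print(len(jax.tree_util.tree_leaves(loss_container)))  # 2

# functional, name-based updates instead of in-place mutation
loss_container = loss_container._replace(
    train_loss_values=loss_container.train_loss_values.at[0].set(1.0)
)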
jinns/validation/__init__.py ADDED
@@ -0,0 +1 @@
+ from ._validation import AbstractValidationModule, ValidationLoss