jinns 0.8.10__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +953 -1182
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +290 -307
  10. jinns/loss/_LossPDE.py +628 -1040
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +95 -96
  13. jinns/loss/{_Losses.py → _loss_utils.py} +104 -46
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +193 -45
  22. jinns/solver/_solve.py +199 -144
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -43
  25. jinns/utils/_hyperpinn.py +226 -127
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +117 -84
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +52 -144
  32. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/METADATA +5 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.8.10.dist-info/RECORD +0 -36
  41. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.8.10.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py CHANGED
@@ -3,26 +3,33 @@ This modules implements the main `solve()` function of jinns which
  handles the optimization process
  """

- import copy
+ from __future__ import (
+ annotations,
+ ) # https://docs.python.org/3/library/typing.html#constant
+
+ from typing import TYPE_CHECKING, NamedTuple, Dict, Union
  from functools import partial
  import optax
  import jax
  from jax import jit
  import jax.numpy as jnp
- from jinns.solver._seq2seq import trigger_seq2seq, initialize_seq2seq
+ from jaxtyping import Int, Bool, Float, Array
  from jinns.solver._rar import init_rar, trigger_rar
- from jinns.utils._utils import _check_nan_in_pytree, _tracked_parameters
+ from jinns.utils._utils import _check_nan_in_pytree
+ from jinns.utils._containers import *
  from jinns.data._DataGenerators import (
  DataGeneratorODE,
  CubicMeshPDEStatio,
  CubicMeshPDENonStatio,
- append_param_batch,
  append_obs_batch,
+ append_param_batch,
  )
- from jinns.utils._containers import *

+ if TYPE_CHECKING:
+ from jinns.utils._types import *

- def check_batch_size(other_data, main_data, attr_name):
+
+ def _check_batch_size(other_data, main_data, attr_name):
  if (
  (
  isinstance(main_data, DataGeneratorODE)
@@ -48,20 +55,32 @@ def check_batch_size(other_data, main_data, attr_name):


  def solve(
- n_iter,
- init_params,
- data,
- loss,
- optimizer,
- print_loss_every=1000,
- opt_state=None,
- seq2seq=None,
- tracked_params_key_list=None,
- param_data=None,
- obs_data=None,
- validation=None,
- obs_batch_sharding=None,
- ):
+ n_iter: Int,
+ init_params: AnyParams,
+ data: AnyDataGenerator,
+ loss: AnyLoss,
+ optimizer: optax.GradientTransformation,
+ print_loss_every: Int = 1000,
+ opt_state: Union[NamedTuple, None] = None,
+ tracked_params: Params | ParamsDict | None = None,
+ param_data: DataGeneratorParameter | None = None,
+ obs_data: (
+ DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
+ ) = None,
+ validation: AbstractValidationModule | None = None,
+ obs_batch_sharding: jax.sharding.Sharding | None = None,
+ verbose: Bool = True,
+ ) -> tuple[
+ Params | ParamsDict,
+ Float[Array, "n_iter"],
+ Dict[str, Float[Array, "n_iter"]],
+ AnyDataGenerator,
+ AnyLoss,
+ NamedTuple,
+ AnyParams,
+ Float[Array, "n_iter"],
+ AnyParams,
+ ]:
  """
  Performs the optimization process via stochastic gradient descent
  algorithm. We minimize the function defined `loss.evaluate()` with
@@ -72,52 +91,39 @@ def solve(
  Parameters
  ----------
  n_iter
- The number of iterations in the optimization
+ The maximum number of iterations in the optimization.
  init_params
- The initial dictionary of parameters. Typically, it is a dictionary of
- dictionaries: `eq_params` and `nn_params``, respectively the
- differential equation parameters and the neural network parameter
+ The initial jinns.parameters.Params object.
  data
- A DataGenerator object which implements a `get_batch()`
- method which returns a 3-tuple with (omega_grid, omega_border, time grid).
- It must be jittable (e.g. implements via a pytree
- registration)
+ A DataGenerator object to retrieve batches of collocation points.
  loss
- A loss object (e.g. a LossODE, SystemLossODE, LossPDEStatio [...]
- object). It must be jittable (e.g. implements via a pytree
- registration)
+ The loss function to minimize.
  optimizer
- An `optax` optimizer (e.g. `optax.adam`).
+ An optax optimizer.
  print_loss_every
- Integer. Default 100. The rate at which we print the loss value in the
+ Default 1000. The rate at which we print the loss value in the
  gradient step loop.
  opt_state
- Default None. Provide an optional initial optional state to the
- optimizer. Not valid for all optimizers.
- seq2seq
- Default None. A dictionary with keys 'times_steps'
- and 'iter_steps' which mush have same length. The first represents
- the time steps which represents the different time interval upon
- which we perform the incremental learning. The second represents
- the number of iteration we perform in each time interval.
- The seq2seq approach we reimplements is defined in
- "Characterizing possible failure modes in physics-informed neural
- networks", A. S. Krishnapriyan, NeurIPS 2021
- tracked_params_key_list
- Default None. Otherwise it is a list of list of strings
- to access a leaf in params. Each selected leaf will be tracked
- and stored at each iteration and returned by the solve function
+ Provide an optional initial state to the optimizer.
+ tracked_params
+ Default None. An eqx.Module of type Params with non-None values for
+ parameters that needs to be tracked along the iterations.
+ None values in tracked_params will not be traversed. Thus
+ the user can provide something like `tracked_params = jinns.parameters.Params(
+ nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+ being a complex data structure.
  param_data
  Default None. A DataGeneratorParameter object which can be used to
  sample equation parameters.
  obs_data
- Default None. A DataGeneratorObservations object which can be used to
- sample minibatches of observations
+ Default None. A DataGeneratorObservations or
+ DataGeneratorObservationsMultiPINNs
+ object which can be used to sample minibatches of observations.
  validation
  Default None. Otherwise, a callable ``eqx.Module`` which implements a
- validation strategy. See documentation of :obj:`~jinns.validation.
+ validation strategy. See documentation of `jinns.validation.
  _validation.AbstractValidationModule` for the general interface, and
- :obj:`~jinns.validation._validation.ValidationLoss` for a practical
+ `jinns.validation._validation.ValidationLoss` for a practical
  implementation of a vanilla validation stategy on a validation set of
  collocation points.

@@ -131,12 +137,15 @@ def solve(
  Default None. An optional sharding object to constraint the obs_batch.
  Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
  created with sharding_device=SingleDeviceSharding(cpu_device) to avoid
- loading on GPU huge datasets of observations
+ loading on GPU huge datasets of observations.
+ verbose
+ Default True. If False, no std output (loss or cause of
+ exiting the optimization loop) will be produced.

  Returns
  -------
  params
- The last non NaN value of the dictionaries of parameters at then end of the
+ The last non NaN value of the params at then end of the
  optimization process
  total_loss_values
  An array of the total loss term along the gradient steps
@@ -150,14 +159,18 @@ def solve(
  opt_state
  The final optimized state
  stored_params
- A dictionary. At each key an array of the values of the parameters
- given in tracked_params_key_list is stored
+ A Params objects with the stored values of the desired parameters (as
+ signified in tracked_params argument)
+ validation_crit_values
+ An array containing the validation criterion values of the training
+ best_val_params
+ The best parameters according to the validation criterion
  """
  if param_data is not None:
- check_batch_size(param_data, data, "param_batch_size")
+ _check_batch_size(param_data, data, "param_batch_size")

  if obs_data is not None:
- check_batch_size(obs_data, data, "obs_batch_size")
+ _check_batch_size(obs_data, data, "obs_batch_size")

  if opt_state is None:
  opt_state = optimizer.init(init_params)
@@ -168,30 +181,32 @@ def solve(

  # Seq2seq
  curr_seq = 0
- if seq2seq is not None:
- assert (
- data.method == "uniform"
- ), "data.method must be uniform if using seq2seq learning !"
- data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)

  train_loss_values = jnp.zeros((n_iter))
  # depending on obs_batch_sharding we will get the simple get_batch or the
  # get_batch with device_put, the latter is not jittable
- get_batch = get_get_batch(obs_batch_sharding)
+ get_batch = _get_get_batch(obs_batch_sharding)

  # initialize the dict for stored parameter values
  # we need to get a loss_term to init stuff
  batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
  _, loss_terms = loss(init_params, batch_ini)
- if tracked_params_key_list is None:
- tracked_params_key_list = []
- tracked_params = _tracked_parameters(init_params, tracked_params_key_list)
+
+ # initialize parameter tracking
+ if tracked_params is None:
+ tracked_params = jax.tree.map(lambda p: None, init_params)
  stored_params = jax.tree_util.tree_map(
  lambda tracked_param, param: (
- jnp.zeros((n_iter,) + param.shape) if tracked_param else None
+ jnp.zeros((n_iter,) + jnp.asarray(param).shape)
+ if tracked_param is not None
+ else None
  ),
  tracked_params,
  init_params,
+ is_leaf=lambda x: x is None, # None values in tracked_params will not
+ # be traversed. Thus the user can provide something like `tracked_params = jinns.parameters.Params(
+ # nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+ # being a complex data structure
  )

  # initialize the dict for stored loss values
@@ -203,11 +218,13 @@ def solve(
  data=data, param_data=param_data, obs_data=obs_data
  )
  optimization = OptimizationContainer(
- params=init_params, last_non_nan_params=init_params.copy(), opt_state=opt_state
+ params=init_params,
+ last_non_nan_params=init_params,
+ opt_state=opt_state,
  )
  optimization_extra = OptimizationExtraContainer(
  curr_seq=curr_seq,
- seq2seq=seq2seq,
+ best_val_params=init_params,
  )
  loss_container = LossContainer(
  stored_loss_terms=stored_loss_terms,
@@ -222,7 +239,7 @@ def solve(
  else:
  validation_crit_values = None

- break_fun = get_break_fun(n_iter)
+ break_fun = _get_break_fun(n_iter, verbose)

  iteration = 0
  carry = (
@@ -237,7 +254,7 @@ def solve(
  validation_crit_values,
  )

- def one_iteration(carry):
+ def _one_iteration(carry: main_carry) -> main_carry:
  (
  i,
  loss,
@@ -262,7 +279,7 @@ def solve(
  params,
  opt_state,
  last_non_nan_params,
- ) = gradient_step(
+ ) = _gradient_step(
  loss,
  optimizer,
  batch,
@@ -272,7 +289,8 @@ def solve(
  )

  # Print train loss value during optimization
- print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
+ if verbose:
+ _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")

  if validation is not None:
  # there is a jax.lax.cond because we do not necesarily call the
@@ -281,6 +299,7 @@ def solve(
  validation, # always return `validation` for in-place mutation
  early_stopping,
  validation_criterion,
+ update_best_params,
  ) = jax.lax.cond(
  i % validation.call_every == 0,
  lambda operands: operands[0](*operands[1:]), # validation.__call__()
@@ -288,6 +307,7 @@ def solve(
  operands[0],
  False,
  validation_crit_values[i - 1],
+ False,
  ),
  (
  validation, # validation must be in operands
@@ -295,31 +315,32 @@ def solve(
  ),
  )
  # Print validation loss value during optimization
- print_fn(i, validation_criterion, print_loss_every, prefix="[validation] ")
+ if verbose:
+ _print_fn(
+ i, validation_criterion, print_loss_every, prefix="[validation] "
+ )
  validation_crit_values = validation_crit_values.at[i].set(
  validation_criterion
  )
+
+ # update best_val_params w.r.t val_loss if needed
+ best_val_params = jax.lax.cond(
+ update_best_params,
+ lambda _: params, # update with current value
+ lambda operands: operands[0].best_val_params, # unchanged
+ (optimization_extra,),
+ )
  else:
  early_stopping = False
+ best_val_params = params

  # Trigger RAR
  loss, params, data = trigger_rar(
  i, loss, params, data, _rar_step_true, _rar_step_false
  )

- # Trigger seq2seq
- loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
- i,
- loss,
- params,
- data,
- opt_state,
- optimization_extra.curr_seq,
- optimization_extra.seq2seq,
- )
-
  # save loss value and selected parameters
- stored_params, stored_loss_terms, train_loss_values = store_loss_and_params(
+ stored_params, stored_loss_terms, train_loss_values = _store_loss_and_params(
  i,
  params,
  stored_objects.stored_params,
@@ -329,13 +350,15 @@ def solve(
  loss_terms,
  tracked_params,
  )
+
+ # increment iteration number
  i += 1

  return (
  i,
  loss,
  OptimizationContainer(params, last_non_nan_params, opt_state),
- OptimizationExtraContainer(curr_seq, seq2seq, early_stopping),
+ OptimizationExtraContainer(curr_seq, best_val_params, early_stopping),
  DataGeneratorContainer(data, param_data, obs_data),
  validation,
  LossContainer(stored_loss_terms, train_loss_values),
@@ -348,9 +371,9 @@ def solve(
  # concern obs_batch, but it could lead to more complex scheme in the future
  if obs_batch_sharding is not None:
  while break_fun(carry):
- carry = one_iteration(carry)
+ carry = _one_iteration(carry)
  else:
- carry = jax.lax.while_loop(break_fun, one_iteration, carry)
+ carry = jax.lax.while_loop(break_fun, _one_iteration, carry)

  (
  i,
@@ -364,41 +387,47 @@ def solve(
  validation_crit_values,
  ) = carry

- jax.debug.print(
- "Final iteration {i}: train loss value = {train_loss_val}",
- i=i,
- train_loss_val=loss_container.train_loss_values[i - 1],
- )
+ if verbose:
+ jax.debug.print(
+ "Final iteration {i}: train loss value = {train_loss_val}",
+ i=i,
+ train_loss_val=loss_container.train_loss_values[i - 1],
+ )
  if validation is not None:
  jax.debug.print(
  "validation loss value = {validation_loss_val}",
  validation_loss_val=validation_crit_values[i - 1],
  )

- if validation is None:
- return (
- optimization.last_non_nan_params,
- loss_container.train_loss_values,
- loss_container.stored_loss_terms,
- train_data.data,
- loss,
- optimization.opt_state,
- stored_objects.stored_params,
- )
  return (
  optimization.last_non_nan_params,
  loss_container.train_loss_values,
  loss_container.stored_loss_terms,
- train_data.data,
- loss,
+ train_data.data, # return the DataGenerator if needed (no in-place modif)
+ loss, # return the Loss if needed (no-inplace modif)
  optimization.opt_state,
  stored_objects.stored_params,
- validation_crit_values,
+ validation_crit_values if validation is not None else None,
+ optimization_extra.best_val_params if validation is not None else None,
  )


  @partial(jit, static_argnames=["optimizer"])
- def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
+ def _gradient_step(
+ loss: AnyLoss,
+ optimizer: optax.GradientTransformation,
+ batch: AnyBatch,
+ params: AnyParams,
+ opt_state: NamedTuple,
+ last_non_nan_params: AnyParams,
+ ) -> tuple[
+ AnyLoss,
+ float,
+ Dict[str, float],
+ AnyParams,
+ NamedTuple,
+ AnyParams,
+ ]:
  """
  optimizer cannot be jit-ted.
  """
@@ -426,7 +455,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params


  @partial(jit, static_argnames=["prefix"])
- def print_fn(i, loss_val, print_loss_every, prefix=""):
+ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
  # note that if the following is not jitted in the main lor loop, it is
  # super slow
  _ = jax.lax.cond(
@@ -442,26 +471,33 @@ def print_fn(i, loss_val, print_loss_every, prefix=""):


  @jit
- def store_loss_and_params(
- i,
- params,
- stored_params,
- stored_loss_terms,
- train_loss_values,
- train_loss_val,
- loss_terms,
- tracked_params,
- ):
+ def _store_loss_and_params(
+ i: Int,
+ params: AnyParams,
+ stored_params: AnyParams,
+ stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
+ train_loss_values: Float[Array, "n_iter"],
+ train_loss_val: float,
+ loss_terms: Dict[str, float],
+ tracked_params: AnyParams,
+ ) -> tuple[
+ Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
+ ]:
  stored_params = jax.tree_util.tree_map(
- lambda stored_value, param, tracked_param: jax.lax.cond(
- tracked_param,
- lambda ope: ope[0].at[i].set(ope[1]),
- lambda ope: ope[0],
- (stored_value, param),
+ lambda stored_value, param, tracked_param: (
+ None
+ if stored_value is None
+ else jax.lax.cond(
+ tracked_param,
+ lambda ope: ope[0].at[i].set(ope[1]),
+ lambda ope: ope[0],
+ (stored_value, param),
+ )
  ),
  stored_params,
  params,
  tracked_params,
+ is_leaf=lambda x: x is None,
  )
  stored_loss_terms = jax.tree_util.tree_map(
  lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
@@ -473,16 +509,20 @@ def store_loss_and_params(
  return (stored_params, stored_loss_terms, train_loss_values)


- def get_break_fun(n_iter):
+ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
  """
- Wrapper to get the break_fun with appropriate `n_iter`
+ Wrapper to get the break_fun with appropriate `n_iter`.
+ The verbose argument is here to control printing (or not) when exiting
+ the optimisation loop. It can be convenient is jinns.solve is itself
+ called in a loop and user want to avoid std output.
  """

  @jit
- def break_fun(carry):
+ def break_fun(carry: tuple):
  """
- Function to break from the main optimization loop
- We check several conditions
+ Function to break from the main optimization loop whe the following
+ conditions are met : maximum number of iterations, NaN
+ appearing in the parameters, and early stopping criterion.
  """

  def stop_while_loop(msg):
@@ -490,7 +530,8 @@ def get_break_fun(n_iter):
  Note that the message is wrapped in the jax.lax.cond because a
  string is not a valid JAX type that can be fed into the operands
  """
- jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
+ if verbose:
+ jax.debug.print(f"Stopping main optimization loop, cause: {msg}")
  return False

  def continue_while_loop(_):
@@ -531,43 +572,57 @@ def get_break_fun(n_iter):
  return break_fun


- def get_get_batch(obs_batch_sharding):
+ def _get_get_batch(
+ obs_batch_sharding: jax.sharding.Sharding,
+ ) -> Callable[
+ [
+ AnyDataGenerator,
+ DataGeneratorParameter | None,
+ DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+ ],
+ tuple[
+ AnyBatch,
+ AnyDataGenerator,
+ DataGeneratorParameter | None,
+ DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+ ],
+ ]:
  """
  Return the get_batch function that will be used either the jittable one or
- the non-jittable one with sharding
+ the non-jittable one with sharding using jax.device.put()
  """

  def get_batch_sharding(data, param_data, obs_data):
  """
  This function is used at each loop but it cannot be jitted because of
  device_put
-
- Note: return all that's modified or unwanted dirty undefined behaviour
  """
- batch = data.get_batch()
+ data, batch = data.get_batch()
  if param_data is not None:
- batch = append_param_batch(batch, param_data.get_batch())
+ param_data, param_batch = param_data.get_batch()
+ batch = append_param_batch(batch, param_batch)
  if obs_data is not None:
  # This is the part that motivated the transition from scan to for loop
  # Indeed we need to be transit obs_batch from CPU to GPU when we have
  # huge observations that cannot fit on GPU. Such transfer wasn't meant
  # to be jitted, i.e. in a scan loop
- obs_batch = jax.device_put(obs_data.get_batch(), obs_batch_sharding)
+ obs_data, obs_batch = obs_data.get_batch()
+ obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
  batch = append_obs_batch(batch, obs_batch)
  return batch, data, param_data, obs_data

  @jit
  def get_batch(data, param_data, obs_data):
  """
- Original get_batch with not sharding
-
- Note: return all that's modified or unwanted dirty undefined behaviour
+ Original get_batch with no sharding
  """
- batch = data.get_batch()
+ data, batch = data.get_batch()
  if param_data is not None:
- batch = append_param_batch(batch, param_data.get_batch())
+ param_data, param_batch = param_data.get_batch()
+ batch = append_param_batch(batch, param_batch)
  if obs_data is not None:
- batch = append_obs_batch(batch, obs_data.get_batch())
+ obs_data, obs_batch = obs_data.get_batch()
+ batch = append_obs_batch(batch, obs_batch)
  return batch, data, param_data, obs_data

  if obs_batch_sharding is not None:
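
Taken together, the changes to `solve()` above replace the `seq2seq` and `tracked_params_key_list` keywords with a typed signature, a `tracked_params` Params object, a `verbose` flag, and two extra return values (`validation_crit_values`, `best_val_params`). The following minimal sketch shows how a 1.0.0 call site might look; `init_params`, `train_data` and `loss` are placeholders assumed to be built elsewhere with the jinns API, and `jinns.solve` is assumed to remain the public entry point.

    import optax
    import jinns

    # Placeholders: assumed to be created beforehand with the jinns 1.0.0 API
    init_params = ...  # a jinns.parameters.Params object
    train_data = ...   # a DataGenerator (e.g. DataGeneratorODE, CubicMeshPDEStatio)
    loss = ...         # a Loss object (e.g. LossODE, LossPDEStatio)

    # Track only the equation parameter "nu"; None leaves are not traversed,
    # as described in the tracked_params docstring above.
    tracked_params = jinns.parameters.Params(nn_params=None, eq_params={"nu": True})

    (
        params,                  # last non-NaN parameters
        total_loss_values,       # total loss value per iteration
        stored_loss_terms,       # per-term loss histories
        train_data,              # DataGenerator, returned (no in-place modification)
        loss,                    # Loss, returned (no in-place modification)
        opt_state,               # final optimizer state
        stored_params,           # tracked parameter values along iterations
        validation_crit_values,  # None here, since no validation module is passed
        best_val_params,         # None here, since no validation module is passed
    ) = jinns.solve(
        n_iter=10_000,
        init_params=init_params,
        data=train_data,
        loss=loss,
        optimizer=optax.adam(1e-3),
        tracked_params=tracked_params,
        verbose=True,
    )

Callers written against 0.8.x that passed `seq2seq=` or `tracked_params_key_list=` must be updated, since both keywords are gone, and the return tuple now always has nine elements.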
jinns/utils/__init__.py CHANGED
@@ -1,10 +1,4 @@
- from ._utils import (
- euler_maruyama,
- euler_maruyama_density,
- log_euler_maruyama_density,
- )
- from ._pinn import create_PINN
- from ._spinn import create_SPINN
- from ._hyperpinn import create_HYPERPINN
- from ._optim import alternate_optimizer, delayed_optimizer
+ from ._pinn import create_PINN, PINN
+ from ._spinn import create_SPINN, SPINN
+ from ._hyperpinn import create_HYPERPINN, HYPERPINN
  from ._save_load import save_pinn, load_pinn
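
Under 1.0.0, the `jinns.utils` namespace re-exports the model classes alongside their factory functions, while the `euler_maruyama*` helpers and the `_optim` utilities are dropped (`jinns/utils/_optim.py` is removed in the file list above). A sketch of the imports implied by this diff:

    # Factory functions are kept; the PINN / SPINN / HYPERPINN classes are now
    # importable from jinns.utils as well.
    from jinns.utils import (
        create_PINN, PINN,
        create_SPINN, SPINN,
        create_HYPERPINN, HYPERPINN,
        save_pinn, load_pinn,
    )

Code that imported `euler_maruyama`, `alternate_optimizer` or `delayed_optimizer` from `jinns.utils` will need to be updated.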