jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
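
Note on the main API change before the per-file diffs: the reworked `jinns.solve` signature in `jinns/solver/_solve.py` below replaces the jaxtyping scalar aliases with built-in `int`/`bool`, types `opt_state` as `optax.OptState`, adds a `key` argument (required whenever `loss.update_weight_method` is not None), and appends a `stored_weights_terms` entry to the returned tuple. A minimal, hedged usage sketch follows; `make_problem()` is a placeholder for user code that builds the `Params`, data generator and loss objects (their construction is not covered by this diff), and the unpacking order follows the return annotation shown below.

import jax
import optax
import jinns

# Placeholder: build a Params instance, a data generator and a loss with the
# jinns 1.5.0 API (their constructors are not part of this diff).
init_params, data, loss = make_problem()

key = jax.random.PRNGKey(0)  # required when loss.update_weight_method is not None

(
    params,                  # trained Params[Array]
    total_loss_values,       # (n_iter,) array of the total loss
    stored_loss_terms,       # PyTree of (n_iter,) arrays, one per loss term
    data,
    loss,
    opt_state,
    stored_params,           # tracked parameter values
    stored_weights_terms,    # PyTree of (n_iter,) arrays of loss weights, or None
    validation_crit_values,  # None when no validation module is given
    best_val_params,
) = jinns.solve(
    n_iter=5_000,
    init_params=init_params,
    data=data,
    loss=loss,
    optimizer=optax.adam(1e-3),
    key=key,
)
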
jinns/solver/_solve.py CHANGED
@@ -8,56 +8,76 @@ from __future__ import (
 )  # https://docs.python.org/3/library/typing.html#constant
 
 import time
-from typing import TYPE_CHECKING, NamedTuple, Dict, Union
+from typing import TYPE_CHECKING, Any, TypeAlias, Callable
 from functools import partial
 import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jaxtyping import Int, Bool, Float, Array
+from jaxtyping import Float, Array, PyTree, Key
+import equinox as eqx
 from jinns.solver._rar import init_rar, trigger_rar
 from jinns.utils._utils import _check_nan_in_pytree
 from jinns.solver._utils import _check_batch_size
-from jinns.utils._containers import *
-from jinns.data._DataGenerators import (
-    DataGeneratorODE,
-    CubicMeshPDEStatio,
-    CubicMeshPDENonStatio,
-    append_obs_batch,
-    append_param_batch,
+from jinns.utils._containers import (
+    DataGeneratorContainer,
+    OptimizationContainer,
+    OptimizationExtraContainer,
+    LossContainer,
+    StoredObjectContainer,
 )
+from jinns.data._utils import append_param_batch, append_obs_batch
 
 if TYPE_CHECKING:
-    from jinns.utils._types import *
+    from jinns.parameters._params import Params
+    from jinns.utils._types import AnyBatch
+    from jinns.loss._abstract_loss import AbstractLoss
+    from jinns.validation._validation import AbstractValidationModule
+    from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+main_carry: TypeAlias = tuple[
+    int,
+    AbstractLoss,
+    OptimizationContainer,
+    OptimizationExtraContainer,
+    DataGeneratorContainer,
+    AbstractValidationModule | None,
+    LossContainer,
+    StoredObjectContainer,
+    Float[Array, " n_iter"] | None,
+    Key | None,
+]
 
 
 def solve(
-    n_iter: Int,
-    init_params: AnyParams,
-    data: AnyDataGenerator,
-    loss: AnyLoss,
+    n_iter: int,
+    init_params: Params[Array],
+    data: AbstractDataGenerator,
+    loss: AbstractLoss,
     optimizer: optax.GradientTransformation,
-    print_loss_every: Int = 1000,
-    opt_state: Union[NamedTuple, None] = None,
-    tracked_params: Params | ParamsDict | None = None,
+    print_loss_every: int = 1000,
+    opt_state: optax.OptState | None = None,
+    tracked_params: Params[Any | None] | None = None,
     param_data: DataGeneratorParameter | None = None,
-    obs_data: (
-        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
-    ) = None,
+    obs_data: DataGeneratorObservations | None = None,
     validation: AbstractValidationModule | None = None,
     obs_batch_sharding: jax.sharding.Sharding | None = None,
-    verbose: Bool = True,
-    ahead_of_time: Bool = True,
+    verbose: bool = True,
+    ahead_of_time: bool = True,
+    key: Key = None,
 ) -> tuple[
-    Params | ParamsDict,
-    Float[Array, "n_iter"],
-    Dict[str, Float[Array, "n_iter"]],
-    AnyDataGenerator,
-    AnyLoss,
-    NamedTuple,
-    AnyParams,
-    Float[Array, "n_iter"],
-    AnyParams,
+    Params[Array],
+    Float[Array, " n_iter"],
+    PyTree,
+    AbstractDataGenerator,
+    AbstractLoss,
+    optax.OptState,
+    Params[Array | None],
+    PyTree,
+    Float[Array, " n_iter"] | None,
+    Params[Array],
 ]:
     """
     Performs the optimization process via stochastic gradient descent
@@ -94,8 +114,7 @@ def solve(
         Default None. A DataGeneratorParameter object which can be used to
         sample equation parameters.
     obs_data
-        Default None. A DataGeneratorObservations or
-        DataGeneratorObservationsMultiPINNs
+        Default None. A DataGeneratorObservations
         object which can be used to sample minibatches of observations.
     validation
         Default None. Otherwise, a callable ``eqx.Module`` which implements a
@@ -127,6 +146,9 @@ def solve(
         transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
         When False, jinns does not provide any timing information (which would
        be nonsense in a JIT transformed `solve()` function).
+    key
+        Default None. A JAX random key that can be used for diverse purpose in
+        the main iteration loop.
 
     Returns
     -------
@@ -136,8 +158,8 @@ def solve(
     total_loss_values
         An array of the total loss term along the gradient steps
     stored_loss_terms
-        A dictionary. At each key an array of the values of a given loss
-        term is stored
+        A PyTree with attributes being arrays of all the values for each loss
+        term
     data
         The input data object
     loss
@@ -147,11 +169,19 @@ def solve(
     stored_params
         A Params objects with the stored values of the desired parameters (as
         signified in tracked_params argument)
+    stored_weights_terms
+        A PyTree with attributes being arrays of all the values for each loss
+        weight. Note that if Loss.update_weight_method is None, we return None,
+        because loss weights are never updated and we can then save some
+        computations
     validation_crit_values
         An array containing the validation criterion values of the training
     best_val_params
         The best parameters according to the validation criterion
     """
+    if n_iter < 1:
+        raise ValueError("Cannot run jinns.solve for n_iter<1")
+
     if param_data is not None:
         if param_data.param_batch_size is not None:
             # We need to check that batch sizes will all be compliant for
@@ -171,11 +201,21 @@ def solve(
             _check_batch_size(obs_data, param_data, "n")
 
     if opt_state is None:
-        opt_state = optimizer.init(init_params)
+        opt_state = optimizer.init(init_params)  # type: ignore
+        # our Params are eqx.Module (dataclass + PyTree), PyTree is
+        # compatible with optax transform but not dataclass, this leads to a
+        # type hint error: we could prevent this by ensuring with the eqx.filter that
+        # we have only floating points optimizable params given to optax
+        # see https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
+        # opt_state = optimizer.init(
+        #     eqx.filter(init_params, eqx.is_inexact_array)
+        # )
+        # but this seems like a hack and there is no better way
+        # https://github.com/google-deepmind/optax/issues/384
 
     # RAR sampling init (ouside scanned function to avoid dynamic slice error)
     # If RAR is not used the _rar_step_*() are juste None and data is unchanged
-    data, _rar_step_true, _rar_step_false = init_rar(data)
+    data, _rar_step_true, _rar_step_false = init_rar(data)  # type: ignore
 
     # Seq2seq
     curr_seq = 0
@@ -207,11 +247,38 @@ def solve(
         # being a complex data structure
     )
 
-    # initialize the dict for stored loss values
+    # initialize the PyTree for stored loss values
    stored_loss_terms = jax.tree_util.tree_map(
         lambda _: jnp.zeros((n_iter)), loss_terms
     )
 
+    # initialize the PyTree for stored loss weights values
+    if loss.update_weight_method is not None:
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss.loss_weights,
+            tuple(
+                jnp.zeros((n_iter))
+                for n in range(
+                    len(
+                        jax.tree.leaves(
+                            loss.loss_weights,
+                            is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+                        )
+                    )
+                )
+            ),
+        )
+    else:
+        stored_weights_terms = None
+    if loss.update_weight_method is not None and key is None:
+        raise ValueError(
+            "`key` argument must be passed to jinns.solve when"
+            " `loss.update_weight_method` is not None"
+        )
+
     train_data = DataGeneratorContainer(
         data=data, param_data=param_data, obs_data=obs_data
     )
@@ -228,6 +295,7 @@ def solve(
     )
     loss_container = LossContainer(
         stored_loss_terms=stored_loss_terms,
+        stored_weights_terms=stored_weights_terms,
         train_loss_values=train_loss_values,
     )
     stored_objects = StoredObjectContainer(
@@ -252,6 +320,7 @@ def solve(
         loss_container,
         stored_objects,
         validation_crit_values,
+        key,
     )
 
     def _one_iteration(carry: main_carry) -> main_carry:
@@ -265,24 +334,47 @@ def solve(
             loss_container,
             stored_objects,
             validation_crit_values,
+            key,
         ) = carry
 
         batch, data, param_data, obs_data = get_batch(
             train_data.data, train_data.param_data, train_data.obs_data
         )
 
-        # Gradient step
+        # ---------------------------------------------------------------------
+        # The following part is the equivalent of a
+        # > train_loss_value, grads = jax.values_and_grad(total_loss.evaluate)(params, ...)
+        # but it is decomposed on individual loss terms so that we can use it
+        # if needed for updating loss weights.
+        # Since the total loss is a weighted sum of individual loss terms, so
+        # are its total gradients.
+
+        # Compute individual losses and individual gradients
+        loss_terms, grad_terms = loss.evaluate_by_terms(optimization.params, batch)
+
+        if loss.update_weight_method is not None:
+            key, subkey = jax.random.split(key)  # type: ignore because key can
+            # still be None currently
+            # avoid computations of tree_at if no updates
+            loss = loss.update_weights(
+                i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
+            )
+
+        # total grad
+        grads = loss.ponderate_and_sum_gradient(grad_terms)
+
+        # total loss
+        train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
+        # ---------------------------------------------------------------------
+
+        # gradient step
         (
-            loss,
-            train_loss_value,
-            loss_terms,
             params,
             opt_state,
             last_non_nan_params,
         ) = _gradient_step(
-            loss,
+            grads,
             optimizer,
-            batch,
             optimization.params,
             optimization.opt_state,
             optimization.last_non_nan_params,
@@ -292,7 +384,7 @@ def solve(
         if verbose:
             _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
 
-        if validation is not None:
+        if validation is not None and validation_crit_values is not None:
             # there is a jax.lax.cond because we do not necesarily call the
             # validation step every iteration
             (
@@ -306,7 +398,7 @@ def solve(
                 lambda operands: (
                     operands[0],
                     False,
-                    validation_crit_values[i - 1],
+                    validation_crit_values[i - 1],  # type: ignore don't know why it can still be None
                     False,
                 ),
                 (
@@ -350,14 +442,14 @@ def solve(
             )
 
         # save loss value and selected parameters
-        stored_params, stored_loss_terms, train_loss_values = _store_loss_and_params(
+        stored_objects, loss_container = _store_loss_and_params(
             i,
             params,
             stored_objects.stored_params,
-            loss_container.stored_loss_terms,
-            loss_container.train_loss_values,
+            loss_container,
             train_loss_value,
             loss_terms,
+            loss.loss_weights,
             tracked_params,
         )
 
@@ -377,9 +469,10 @@ def solve(
             ),
             DataGeneratorContainer(data, param_data, obs_data),
             validation,
-            LossContainer(stored_loss_terms, train_loss_values),
-            StoredObjectContainer(stored_params),
+            loss_container,
+            stored_objects,
             validation_crit_values,
+            key,
         )
 
     # Main optimization loop. We use the LAX while loop (fully jitted) version
@@ -419,6 +512,7 @@ def solve(
         loss_container,
         stored_objects,
         validation_crit_values,
+        key,
     ) = carry
 
     if verbose:
@@ -431,7 +525,7 @@ def solve(
     # get ready to return the parameters at last iteration...
     # (by default arbitrary choice, this could be None)
     validation_parameters = optimization.last_non_nan_params
-    if validation is not None:
+    if validation is not None and validation_crit_values is not None:
        jax.debug.print(
             "validation loss value = {validation_loss_val}",
             validation_loss_val=validation_crit_values[i - 1],
@@ -456,6 +550,7 @@ def solve(
         loss,  # return the Loss if needed (no-inplace modif)
         optimization.opt_state,
         stored_objects.stored_params,
+        loss_container.stored_weights_terms,
         validation_crit_values if validation is not None else None,
         validation_parameters,
     )
@@ -463,27 +558,26 @@ def solve(
 
 @partial(jit, static_argnames=["optimizer"])
 def _gradient_step(
-    loss: AnyLoss,
+    grads: Params[Array],
     optimizer: optax.GradientTransformation,
-    batch: AnyBatch,
-    params: AnyParams,
-    opt_state: NamedTuple,
-    last_non_nan_params: AnyParams,
+    params: Params[Array],
+    opt_state: optax.OptState,
+    last_non_nan_params: Params[Array],
 ) -> tuple[
-    AnyLoss,
-    float,
-    Dict[str, float],
-    AnyParams,
-    NamedTuple,
-    AnyParams,
+    Params[Array],
+    optax.OptState,
+    Params[Array],
 ]:
     """
     optimizer cannot be jit-ted.
     """
-    value_grad_loss = jax.value_and_grad(loss, has_aux=True)
-    (loss_val, loss_terms), grads = value_grad_loss(params, batch)
-    updates, opt_state = optimizer.update(grads, opt_state, params)
-    params = optax.apply_updates(params, updates)
+
+    updates, opt_state = optimizer.update(
+        grads,  # type: ignore
+        opt_state,
+        params,  # type: ignore
+    )  # see optimizer.init for explaination for the ignore(s) here
+    params = optax.apply_updates(params, updates)  # type: ignore
 
     # check if any of the parameters is NaN
     last_non_nan_params = jax.lax.cond(
@@ -494,9 +588,6 @@ def _gradient_step(
     )
 
     return (
-        loss,
-        loss_val,
-        loss_terms,
         params,
         opt_state,
         last_non_nan_params,
@@ -504,7 +595,7 @@ def _gradient_step(
 
 
 @partial(jit, static_argnames=["prefix"])
-def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
+def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
     # note that if the following is not jitted in the main lor loop, it is
     # super slow
     _ = jax.lax.cond(
@@ -521,17 +612,15 @@ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
 
 @jit
 def _store_loss_and_params(
-    i: Int,
-    params: AnyParams,
-    stored_params: AnyParams,
-    stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
-    train_loss_values: Float[Array, "n_iter"],
+    i: int,
+    params: Params[Array],
+    stored_params: Params[Array | None],
+    loss_container: LossContainer,
     train_loss_val: float,
-    loss_terms: Dict[str, float],
-    tracked_params: AnyParams,
-) -> tuple[
-    Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
-]:
+    loss_terms: PyTree[Array],
+    weight_terms: PyTree[Array],
+    tracked_params: Params,
+) -> tuple[StoredObjectContainer, LossContainer]:
     stored_params = jax.tree_util.tree_map(
         lambda stored_value, param, tracked_param: (
             None
@@ -550,15 +639,41 @@ def _store_loss_and_params(
     )
     stored_loss_terms = jax.tree_util.tree_map(
         lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
-        stored_loss_terms,
+        loss_container.stored_loss_terms,
         loss_terms,
     )
 
-    train_loss_values = train_loss_values.at[i].set(train_loss_val)
-    return (stored_params, stored_loss_terms, train_loss_values)
+    if loss_container.stored_weights_terms is not None:
+        stored_weights_terms = jax.tree_util.tree_map(
+            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
+            jax.tree.leaves(
+                loss_container.stored_weights_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+            jax.tree.leaves(
+                weight_terms,
+                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
+            ),
+        )
+        stored_weights_terms = eqx.tree_at(
+            lambda pt: jax.tree.leaves(
+                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
+            ),
+            loss_container.stored_weights_terms,
+            stored_weights_terms,
+        )
+    else:
+        stored_weights_terms = None
+
+    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
+    loss_container = LossContainer(
+        stored_loss_terms, stored_weights_terms, train_loss_values
+    )
+    stored_objects = StoredObjectContainer(stored_params)
+    return stored_objects, loss_container
 
 
-def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
+def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
     """
     Wrapper to get the break_fun with appropriate `n_iter`.
     The verbose argument is here to control printing (or not) when exiting
@@ -586,7 +701,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
         def continue_while_loop(_):
             return True
 
-        (i, _, optimization, optimization_extra, _, _, _, _, _) = carry
+        (i, _, optimization, optimization_extra, _, _, _, _, _, _) = carry
 
         # Condition 1
         bool_max_iter = jax.lax.cond(
@@ -599,7 +714,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
         bool_nan_in_params = jax.lax.cond(
             _check_nan_in_pytree(optimization.params),
             lambda _: stop_while_loop(
-                "NaN values in parameters " "(returning last non NaN values)"
+                "NaN values in parameters (returning last non NaN values)"
             ),
             continue_while_loop,
             None,
@@ -622,18 +737,18 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
 
 
 def _get_get_batch(
-    obs_batch_sharding: jax.sharding.Sharding,
+    obs_batch_sharding: jax.sharding.Sharding | None,
 ) -> Callable[
     [
-        AnyDataGenerator,
+        AbstractDataGenerator,
         DataGeneratorParameter | None,
-        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+        DataGeneratorObservations | None,
     ],
     tuple[
         AnyBatch,
-        AnyDataGenerator,
+        AbstractDataGenerator,
         DataGeneratorParameter | None,
-        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+        DataGeneratorObservations | None,
     ],
 ]:
     """
jinns/solver/_utils.py CHANGED
@@ -1,9 +1,7 @@
-from jinns.data._DataGenerators import (
-    DataGeneratorODE,
-    CubicMeshPDEStatio,
-    CubicMeshPDENonStatio,
-    DataGeneratorParameter,
-)
+from jinns.data._DataGeneratorODE import DataGeneratorODE
+from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
+from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
+from jinns.data._DataGeneratorParameter import DataGeneratorParameter
 
 
 def _check_batch_size(other_data, main_data, attr_name):
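
The import rewrite above reflects the removal of the monolithic `jinns/data/_DataGenerators.py` (entry 51 in the file list): each data generator now lives in its own module and the batch helpers moved to `jinns.data._utils`. Downstream code importing from the old private module needs the same treatment; a hedged migration sketch, using only the new paths that appear in this diff (the public `jinns.data` namespace may also re-export these names, see `jinns/data/__init__.py`):

# jinns 1.3.0 (module removed in 1.5.0):
# from jinns.data._DataGenerators import (
#     DataGeneratorODE,
#     CubicMeshPDEStatio,
#     CubicMeshPDENonStatio,
#     DataGeneratorParameter,
#     append_obs_batch,
#     append_param_batch,
# )

# jinns 1.5.0: one module per generator, batch helpers in _utils
from jinns.data._DataGeneratorODE import DataGeneratorODE
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
from jinns.data._utils import append_param_batch, append_obs_batch
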
jinns/utils/__init__.py CHANGED
@@ -1 +1,3 @@
 from ._utils import get_grid
+
+__all__ = ["get_grid"]
jinns/utils/_containers.py CHANGED
@@ -6,28 +6,31 @@ from __future__ import (
     annotations,
 )  # https://docs.python.org/3/library/typing.html#constant
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 from jaxtyping import PyTree, Array, Float, Bool
 from optax import OptState
 import equinox as eqx
 
+from jinns.parameters._params import Params
+
 if TYPE_CHECKING:
-    from jinns.utils._types import *
+    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+    from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+    from jinns.utils._types import AnyLoss
 
 
 class DataGeneratorContainer(eqx.Module):
-    data: AnyDataGenerator
+    data: AbstractDataGenerator
     param_data: DataGeneratorParameter | None = None
-    obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
-        None
-    )
+    obs_data: DataGeneratorObservations | None = None
 
 
 class ValidationContainer(eqx.Module):
     loss: AnyLoss | None
     data: DataGeneratorContainer
     hyperparams: PyTree = None
-    loss_values: Float[Array, "n_iter"] | None = None
+    loss_values: Float[Array, " n_iter"] | None = None
 
 
 class OptimizationContainer(eqx.Module):
@@ -45,9 +48,12 @@ class OptimizationExtraContainer(eqx.Module):
 
 
 class LossContainer(eqx.Module):
-    stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
-    train_loss_values: Float[Array, "n_iter"]
+    # PyTree below refers to ODEComponents or PDEStatioComponents or
+    # PDENonStatioComponents
+    stored_loss_terms: PyTree[Float[Array, " n_iter"]]
+    stored_weights_terms: PyTree[Float[Array, " n_iter"]]
+    train_loss_values: Float[Array, " n_iter"]
 
 
 class StoredObjectContainer(eqx.Module):
-    stored_params: list | None
+    stored_params: Params[Array | None]
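
The new `LossContainer.stored_weights_terms` field mirrors `stored_loss_terms`: both are PyTrees (ODEComponents, PDEStatioComponents or PDENonStatioComponents, per the comment above) whose leaves are `(n_iter,)` arrays, and `stored_weights_terms` is None when `loss.update_weight_method` is None. A hedged sketch of how the corresponding `jinns.solve` outputs could be inspected after training; the plotting helper below is illustrative and not part of jinns:

import jax
import matplotlib.pyplot as plt


def plot_histories(stored_loss_terms, stored_weights_terms):
    """Plot per-term loss histories and, when available, loss-weight histories."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(stored_loss_terms)
    for path, values in leaves:  # each leaf is a (n_iter,) array
        plt.plot(values, label="loss" + jax.tree_util.keystr(path))

    if stored_weights_terms is not None:  # None when update_weight_method is None
        leaves, _ = jax.tree_util.tree_flatten_with_path(stored_weights_terms)
        for path, values in leaves:
            plt.plot(values, linestyle="--", label="weight" + jax.tree_util.keystr(path))

    plt.yscale("log")
    plt.legend()
    plt.show()
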