jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py CHANGED
@@ -3,26 +3,33 @@ This modules implements the main `solve()` function of jinns which
 handles the optimization process
 """

-import copy
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, NamedTuple, Dict, Union
 from functools import partial
 import optax
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jinns.solver._seq2seq import trigger_seq2seq, initialize_seq2seq
+from jaxtyping import Int, Bool, Float, Array
 from jinns.solver._rar import init_rar, trigger_rar
-from jinns.utils._utils import _check_nan_in_pytree, _tracked_parameters
+from jinns.utils._utils import _check_nan_in_pytree
+from jinns.utils._containers import *
 from jinns.data._DataGenerators import (
     DataGeneratorODE,
     CubicMeshPDEStatio,
     CubicMeshPDENonStatio,
-    append_param_batch,
     append_obs_batch,
+    append_param_batch,
 )
-from jinns.utils._containers import *

+if TYPE_CHECKING:
+    from jinns.utils._types import *

-def check_batch_size(other_data, main_data, attr_name):
+
+def _check_batch_size(other_data, main_data, attr_name):
     if (
         (
             isinstance(main_data, DataGeneratorODE)
@@ -48,21 +55,32 @@ def check_batch_size(other_data, main_data, attr_name):


 def solve(
-    n_iter,
-    init_params,
-    data,
-    loss,
-    optimizer,
-    print_loss_every=1000,
-    opt_state=None,
-    seq2seq=None,
-    tracked_params_key_list=None,
-    param_data=None,
-    obs_data=None,
-    validation=None,
-    obs_batch_sharding=None,
-    verbose=True,
-):
+    n_iter: Int,
+    init_params: AnyParams,
+    data: AnyDataGenerator,
+    loss: AnyLoss,
+    optimizer: optax.GradientTransformation,
+    print_loss_every: Int = 1000,
+    opt_state: Union[NamedTuple, None] = None,
+    tracked_params: Params | ParamsDict | None = None,
+    param_data: DataGeneratorParameter | None = None,
+    obs_data: (
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
+    ) = None,
+    validation: AbstractValidationModule | None = None,
+    obs_batch_sharding: jax.sharding.Sharding | None = None,
+    verbose: Bool = True,
+) -> tuple[
+    Params | ParamsDict,
+    Float[Array, "n_iter"],
+    Dict[str, Float[Array, "n_iter"]],
+    AnyDataGenerator,
+    AnyLoss,
+    NamedTuple,
+    AnyParams,
+    Float[Array, "n_iter"],
+    AnyParams,
+]:
     """
     Performs the optimization process via stochastic gradient descent
     algorithm. We minimize the function defined `loss.evaluate()` with
@@ -73,52 +91,39 @@ def solve(
     Parameters
     ----------
     n_iter
-        The number of iterations in the optimization
+        The maximum number of iterations in the optimization.
     init_params
-        The initial dictionary of parameters. Typically, it is a dictionary of
-        dictionaries: `eq_params` and `nn_params``, respectively the
-        differential equation parameters and the neural network parameter
+        The initial jinns.parameters.Params object.
     data
-        A DataGenerator object which implements a `get_batch()`
-        method which returns a 3-tuple with (omega_grid, omega_border, time grid).
-        It must be jittable (e.g. implements via a pytree
-        registration)
+        A DataGenerator object to retrieve batches of collocation points.
     loss
-        A loss object (e.g. a LossODE, SystemLossODE, LossPDEStatio [...]
-        object). It must be jittable (e.g. implements via a pytree
-        registration)
+        The loss function to minimize.
     optimizer
-        An `optax` optimizer (e.g. `optax.adam`).
+        An optax optimizer.
     print_loss_every
-        Integer. Default 100. The rate at which we print the loss value in the
+        Default 1000. The rate at which we print the loss value in the
         gradient step loop.
     opt_state
-        Default None. Provide an optional initial optional state to the
-        optimizer. Not valid for all optimizers.
-    seq2seq
-        Default None. A dictionary with keys 'times_steps'
-        and 'iter_steps' which mush have same length. The first represents
-        the time steps which represents the different time interval upon
-        which we perform the incremental learning. The second represents
-        the number of iteration we perform in each time interval.
-        The seq2seq approach we reimplements is defined in
-        "Characterizing possible failure modes in physics-informed neural
-        networks", A. S. Krishnapriyan, NeurIPS 2021
-    tracked_params_key_list
-        Default None. Otherwise it is a list of list of strings
-        to access a leaf in params. Each selected leaf will be tracked
-        and stored at each iteration and returned by the solve function
+        Provide an optional initial state to the optimizer.
+    tracked_params
+        Default None. An eqx.Module of type Params with non-None values for
+        parameters that needs to be tracked along the iterations.
+        None values in tracked_params will not be traversed. Thus
+        the user can provide something like `tracked_params = jinns.parameters.Params(
+        nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        being a complex data structure.
     param_data
         Default None. A DataGeneratorParameter object which can be used to
         sample equation parameters.
     obs_data
-        Default None. A DataGeneratorObservations object which can be used to
-        sample minibatches of observations
+        Default None. A DataGeneratorObservations or
+        DataGeneratorObservationsMultiPINNs
+        object which can be used to sample minibatches of observations.
     validation
         Default None. Otherwise, a callable ``eqx.Module`` which implements a
-        validation strategy. See documentation of :obj:`~jinns.validation.
+        validation strategy. See documentation of `jinns.validation.
         _validation.AbstractValidationModule` for the general interface, and
-        :obj:`~jinns.validation._validation.ValidationLoss` for a practical
+        `jinns.validation._validation.ValidationLoss` for a practical
         implementation of a vanilla validation stategy on a validation set of
         collocation points.

@@ -132,15 +137,15 @@ def solve(
         Default None. An optional sharding object to constraint the obs_batch.
         Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
         created with sharding_device=SingleDeviceSharding(cpu_device) to avoid
-        loading on GPU huge datasets of observations
-    verbose:
-        Boolean, default True. If False, no std output (loss or cause of
+        loading on GPU huge datasets of observations.
+    verbose
+        Default True. If False, no std output (loss or cause of
         exiting the optimization loop) will be produced.

     Returns
     -------
     params
-        The last non NaN value of the dictionaries of parameters at then end of the
+        The last non NaN value of the params at then end of the
         optimization process
     total_loss_values
         An array of the total loss term along the gradient steps
@@ -154,14 +159,18 @@ def solve(
     opt_state
         The final optimized state
     stored_params
-        A dictionary. At each key an array of the values of the parameters
-        given in tracked_params_key_list is stored
+        A Params objects with the stored values of the desired parameters (as
+        signified in tracked_params argument)
+    validation_crit_values
+        An array containing the validation criterion values of the training
+    best_val_params
+        The best parameters according to the validation criterion
     """
     if param_data is not None:
-        check_batch_size(param_data, data, "param_batch_size")
+        _check_batch_size(param_data, data, "param_batch_size")

     if obs_data is not None:
-        check_batch_size(obs_data, data, "obs_batch_size")
+        _check_batch_size(obs_data, data, "obs_batch_size")

     if opt_state is None:
         opt_state = optimizer.init(init_params)
@@ -172,30 +181,32 @@ def solve(

     # Seq2seq
     curr_seq = 0
-    if seq2seq is not None:
-        assert (
-            data.method == "uniform"
-        ), "data.method must be uniform if using seq2seq learning !"
-        data, opt_state = initialize_seq2seq(loss, data, seq2seq, opt_state)

     train_loss_values = jnp.zeros((n_iter))
     # depending on obs_batch_sharding we will get the simple get_batch or the
     # get_batch with device_put, the latter is not jittable
-    get_batch = get_get_batch(obs_batch_sharding)
+    get_batch = _get_get_batch(obs_batch_sharding)

     # initialize the dict for stored parameter values
     # we need to get a loss_term to init stuff
     batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
     _, loss_terms = loss(init_params, batch_ini)
-    if tracked_params_key_list is None:
-        tracked_params_key_list = []
-    tracked_params = _tracked_parameters(init_params, tracked_params_key_list)
+
+    # initialize parameter tracking
+    if tracked_params is None:
+        tracked_params = jax.tree.map(lambda p: None, init_params)
     stored_params = jax.tree_util.tree_map(
         lambda tracked_param, param: (
-            jnp.zeros((n_iter,) + param.shape) if tracked_param else None
+            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
+            if tracked_param is not None
+            else None
         ),
         tracked_params,
         init_params,
+        is_leaf=lambda x: x is None,  # None values in tracked_params will not
+        # be traversed. Thus the user can provide something like `tracked_params = jinns.parameters.Params(
+        # nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        # being a complex data structure
     )

     # initialize the dict for stored loss values
@@ -208,13 +219,12 @@ def solve(
     )
     optimization = OptimizationContainer(
         params=init_params,
-        last_non_nan_params=init_params.copy(),
+        last_non_nan_params=init_params,
         opt_state=opt_state,
     )
     optimization_extra = OptimizationExtraContainer(
         curr_seq=curr_seq,
-        seq2seq=seq2seq,
-        best_val_params=init_params.copy(),
+        best_val_params=init_params,
     )
     loss_container = LossContainer(
         stored_loss_terms=stored_loss_terms,
@@ -229,7 +239,7 @@ def solve(
     else:
         validation_crit_values = None

-    break_fun = get_break_fun(n_iter, verbose)
+    break_fun = _get_break_fun(n_iter, verbose)

     iteration = 0
     carry = (
@@ -244,7 +254,7 @@ def solve(
         validation_crit_values,
     )

-    def one_iteration(carry):
+    def _one_iteration(carry: main_carry) -> main_carry:
         (
             i,
             loss,
@@ -269,7 +279,7 @@ def solve(
             params,
             opt_state,
             last_non_nan_params,
-        ) = gradient_step(
+        ) = _gradient_step(
             loss,
             optimizer,
             batch,
@@ -280,7 +290,7 @@ def solve(

         # Print train loss value during optimization
         if verbose:
-            print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
+            _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")

         if validation is not None:
             # there is a jax.lax.cond because we do not necesarily call the
@@ -306,7 +316,7 @@ def solve(
             )
             # Print validation loss value during optimization
             if verbose:
-                print_fn(
+                _print_fn(
                     i, validation_criterion, print_loss_every, prefix="[validation] "
                 )
             validation_crit_values = validation_crit_values.at[i].set(
@@ -329,19 +339,8 @@ def solve(
             i, loss, params, data, _rar_step_true, _rar_step_false
         )

-        # Trigger seq2seq
-        loss, params, data, opt_state, curr_seq, seq2seq = trigger_seq2seq(
-            i,
-            loss,
-            params,
-            data,
-            opt_state,
-            optimization_extra.curr_seq,
-            optimization_extra.seq2seq,
-        )
-
         # save loss value and selected parameters
-        stored_params, stored_loss_terms, train_loss_values = store_loss_and_params(
+        stored_params, stored_loss_terms, train_loss_values = _store_loss_and_params(
            i,
            params,
            stored_objects.stored_params,
@@ -359,9 +358,7 @@ def solve(
             i,
             loss,
             OptimizationContainer(params, last_non_nan_params, opt_state),
-            OptimizationExtraContainer(
-                curr_seq, seq2seq, best_val_params, early_stopping
-            ),
+            OptimizationExtraContainer(curr_seq, best_val_params, early_stopping),
             DataGeneratorContainer(data, param_data, obs_data),
             validation,
             LossContainer(stored_loss_terms, train_loss_values),
@@ -374,9 +371,9 @@ def solve(
     # concern obs_batch, but it could lead to more complex scheme in the future
     if obs_batch_sharding is not None:
         while break_fun(carry):
-            carry = one_iteration(carry)
+            carry = _one_iteration(carry)
     else:
-        carry = jax.lax.while_loop(break_fun, one_iteration, carry)
+        carry = jax.lax.while_loop(break_fun, _one_iteration, carry)

     (
         i,
@@ -416,7 +413,21 @@


 @partial(jit, static_argnames=["optimizer"])
-def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
+def _gradient_step(
+    loss: AnyLoss,
+    optimizer: optax.GradientTransformation,
+    batch: AnyBatch,
+    params: AnyParams,
+    opt_state: NamedTuple,
+    last_non_nan_params: AnyParams,
+) -> tuple[
+    AnyLoss,
+    float,
+    Dict[str, float],
+    AnyParams,
+    NamedTuple,
+    AnyParams,
+]:
     """
     optimizer cannot be jit-ted.
     """
@@ -444,7 +455,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params


 @partial(jit, static_argnames=["prefix"])
-def print_fn(i, loss_val, print_loss_every, prefix=""):
+def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
     # note that if the following is not jitted in the main lor loop, it is
     # super slow
     _ = jax.lax.cond(
@@ -460,16 +471,18 @@ def print_fn(i, loss_val, print_loss_every, prefix=""):


 @jit
-def store_loss_and_params(
-    i,
-    params,
-    stored_params,
-    stored_loss_terms,
-    train_loss_values,
-    train_loss_val,
-    loss_terms,
-    tracked_params,
-):
+def _store_loss_and_params(
+    i: Int,
+    params: AnyParams,
+    stored_params: AnyParams,
+    stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
+    train_loss_values: Float[Array, "n_iter"],
+    train_loss_val: float,
+    loss_terms: Dict[str, float],
+    tracked_params: AnyParams,
+) -> tuple[
+    Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
+]:
     stored_params = jax.tree_util.tree_map(
         lambda stored_value, param, tracked_param: (
             None
@@ -496,7 +509,7 @@ def store_loss_and_params(
     return (stored_params, stored_loss_terms, train_loss_values)


-def get_break_fun(n_iter, verbose: str):
+def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
     """
     Wrapper to get the break_fun with appropriate `n_iter`.
     The verbose argument is here to control printing (or not) when exiting
@@ -507,8 +520,8 @@ def get_break_fun(n_iter, verbose: str):
     @jit
     def break_fun(carry: tuple):
         """
-        Function to break from the main optimization loop
-        We the following conditions : maximum number of iterations, NaN
+        Function to break from the main optimization loop whe the following
+        conditions are met : maximum number of iterations, NaN
         appearing in the parameters, and early stopping criterion.
         """

@@ -559,43 +572,57 @@ def get_break_fun(n_iter, verbose: str):
     return break_fun


-def get_get_batch(obs_batch_sharding):
+def _get_get_batch(
+    obs_batch_sharding: jax.sharding.Sharding,
+) -> Callable[
+    [
+        AnyDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+    ],
+    tuple[
+        AnyBatch,
+        AnyDataGenerator,
+        DataGeneratorParameter | None,
+        DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+    ],
+]:
     """
     Return the get_batch function that will be used either the jittable one or
-    the non-jittable one with sharding
+    the non-jittable one with sharding using jax.device.put()
     """

     def get_batch_sharding(data, param_data, obs_data):
         """
         This function is used at each loop but it cannot be jitted because of
         device_put
-
-        Note: return all that's modified or unwanted dirty undefined behaviour
         """
-        batch = data.get_batch()
+        data, batch = data.get_batch()
         if param_data is not None:
-            batch = append_param_batch(batch, param_data.get_batch())
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
         if obs_data is not None:
             # This is the part that motivated the transition from scan to for loop
             # Indeed we need to be transit obs_batch from CPU to GPU when we have
             # huge observations that cannot fit on GPU. Such transfer wasn't meant
             # to be jitted, i.e. in a scan loop
-            obs_batch = jax.device_put(obs_data.get_batch(), obs_batch_sharding)
+            obs_data, obs_batch = obs_data.get_batch()
+            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
             batch = append_obs_batch(batch, obs_batch)
         return batch, data, param_data, obs_data

     @jit
     def get_batch(data, param_data, obs_data):
         """
-        Original get_batch with not sharding
-
-        Note: return all that's modified or unwanted dirty undefined behaviour
+        Original get_batch with no sharding
         """
-        batch = data.get_batch()
+        data, batch = data.get_batch()
         if param_data is not None:
-            batch = append_param_batch(batch, param_data.get_batch())
+            param_data, param_batch = param_data.get_batch()
+            batch = append_param_batch(batch, param_batch)
         if obs_data is not None:
-            batch = append_obs_batch(batch, obs_data.get_batch())
+            obs_data, obs_batch = obs_data.get_batch()
+            batch = append_obs_batch(batch, obs_batch)
         return batch, data, param_data, obs_data

     if obs_batch_sharding is not None:
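The parameter-tracking hunk above replaces `tracked_params_key_list` with a `Params`-shaped mask whose `None` entries are skipped via `is_leaf`. Below is a small self-contained sketch of that `tree_map` pattern; plain dicts stand in for `jinns.parameters.Params` objects, an assumption made here for brevity only.

import jax
import jax.numpy as jnp

n_iter = 5
# stand-ins for init_params: a "complex" nn_params subtree and one equation parameter
init_params = {
    "nn_params": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))},
    "eq_params": {"nu": jnp.array(0.3)},
}
# None marks subtrees that are not tracked; is_leaf stops traversal there
tracked_params = {"nn_params": None, "eq_params": {"nu": True}}

stored_params = jax.tree_util.tree_map(
    lambda tracked, param: (
        jnp.zeros((n_iter,) + jnp.asarray(param).shape) if tracked is not None else None
    ),
    tracked_params,
    init_params,
    is_leaf=lambda x: x is None,
)
# stored_params == {"eq_params": {"nu": array of shape (5,)}, "nn_params": None}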
jinns/utils/__init__.py CHANGED
@@ -1,10 +1,4 @@
-from ._utils import (
-    euler_maruyama,
-    euler_maruyama_density,
-    log_euler_maruyama_density,
-)
-from ._pinn import create_PINN
-from ._spinn import create_SPINN
-from ._hyperpinn import create_HYPERPINN
-from ._optim import alternate_optimizer, delayed_optimizer
+from ._pinn import create_PINN, PINN
+from ._spinn import create_SPINN, SPINN
+from ._hyperpinn import create_HYPERPINN, HYPERPINN
 from ._save_load import save_pinn, load_pinn
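In other words, with jinns 1.1.0 installed the public surface of `jinns.utils` reduces to the imports sketched below; the Euler-Maruyama helpers and the `_optim` re-exports are no longer importable from this module.

from jinns.utils import create_PINN, PINN
from jinns.utils import create_SPINN, SPINN
from jinns.utils import create_HYPERPINN, HYPERPINN
from jinns.utils import save_pinn, load_pinn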
jinns/utils/_containers.py CHANGED
@@ -1,58 +1,51 @@
 """
-NamedTuples definition
+equinox Modules used as containers
 """

-from typing import Union, NamedTuple
-from jaxtyping import PyTree
-from jax.typing import ArrayLike
-import optax
-import jax.numpy as jnp
-from jinns.loss._LossODE import LossODE, SystemLossODE
-from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
-from jinns.data._DataGenerators import (
-    DataGeneratorODE,
-    CubicMeshPDEStatio,
-    CubicMeshPDENonStatio,
-    DataGeneratorParameter,
-    DataGeneratorObservations,
-    DataGeneratorObservationsMultiPINNs,
-)
-
-
-class DataGeneratorContainer(NamedTuple):
-    data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
-    param_data: Union[DataGeneratorParameter, None] = None
-    obs_data: Union[
-        DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
-    ] = None
-
-
-class ValidationContainer(NamedTuple):
-    loss: Union[
-        LossODE, SystemLossODE, LossPDEStatio, LossPDENonStatio, SystemLossPDE, None
-    ]
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
+from typing import TYPE_CHECKING, Dict
+from jaxtyping import PyTree, Array, Float, Bool
+from optax import OptState
+import equinox as eqx
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *
+
+
+class DataGeneratorContainer(eqx.Module):
+    data: AnyDataGenerator
+    param_data: DataGeneratorParameter | None = None
+    obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
+        None
+    )
+
+
+class ValidationContainer(eqx.Module):
+    loss: AnyLoss | None
     data: DataGeneratorContainer
     hyperparams: PyTree = None
-    loss_values: Union[ArrayLike, None] = None
+    loss_values: Float[Array, "n_iter"] | None = None


-class OptimizationContainer(NamedTuple):
-    params: dict
-    last_non_nan_params: dict
-    opt_state: optax.OptState
+class OptimizationContainer(eqx.Module):
+    params: Params
+    last_non_nan_params: Params
+    opt_state: OptState


-class OptimizationExtraContainer(NamedTuple):
+class OptimizationExtraContainer(eqx.Module):
     curr_seq: int
-    seq2seq: Union[dict, None]
-    best_val_params: dict
-    early_stopping: bool = False
+    best_val_params: Params
+    early_stopping: Bool = False


-class LossContainer(NamedTuple):
-    stored_loss_terms: dict
-    train_loss_values: ArrayLike
+class LossContainer(eqx.Module):
+    stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
+    train_loss_values: Float[Array, "n_iter"]


-class StoredObjectContainer(NamedTuple):
-    stored_params: Union[list, None]
+class StoredObjectContainer(eqx.Module):
+    stored_params: list | None
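Switching these containers from NamedTuple to eqx.Module keeps them as jittable pytrees that can be carried through jax.lax.while_loop, but field updates become functional. A toy, self-contained illustration follows (not jinns code; eqx.tree_at is the usual equinox idiom, assumed here rather than shown in this diff).

from typing import Dict

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


class ToyLossContainer(eqx.Module):
    # mirrors LossContainer above, with a made-up loss term name
    stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
    train_loss_values: Float[Array, "n_iter"]


c = ToyLossContainer(
    stored_loss_terms={"dyn_loss": jnp.zeros((10,))},
    train_loss_values=jnp.zeros((10,)),
)
# fields are frozen: update functionally instead of NamedTuple._replace
c2 = eqx.tree_at(
    lambda m: m.train_loss_values, c, c.train_loss_values.at[0].set(1.0)
)
# still a pytree: it can be flattened, jitted over, or used as a while_loop carry
leaves = jax.tree_util.tree_leaves(c2)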