jinns-1.5.1-py3-none-any.whl → jinns-1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. jinns/data/_AbstractDataGenerator.py +1 -1
  2. jinns/data/_Batchs.py +47 -13
  3. jinns/data/_CubicMeshPDENonStatio.py +55 -34
  4. jinns/data/_CubicMeshPDEStatio.py +63 -35
  5. jinns/data/_DataGeneratorODE.py +48 -22
  6. jinns/data/_DataGeneratorObservations.py +75 -32
  7. jinns/data/_DataGeneratorParameter.py +152 -101
  8. jinns/data/__init__.py +2 -1
  9. jinns/data/_utils.py +22 -10
  10. jinns/loss/_DynamicLoss.py +21 -20
  11. jinns/loss/_DynamicLossAbstract.py +51 -36
  12. jinns/loss/_LossODE.py +139 -184
  13. jinns/loss/_LossPDE.py +440 -358
  14. jinns/loss/_abstract_loss.py +60 -25
  15. jinns/loss/_loss_components.py +4 -25
  16. jinns/loss/_loss_weight_updates.py +6 -7
  17. jinns/loss/_loss_weights.py +34 -35
  18. jinns/nn/_abstract_pinn.py +0 -2
  19. jinns/nn/_hyperpinn.py +34 -23
  20. jinns/nn/_mlp.py +5 -4
  21. jinns/nn/_pinn.py +1 -16
  22. jinns/nn/_ppinn.py +5 -16
  23. jinns/nn/_save_load.py +11 -4
  24. jinns/nn/_spinn.py +1 -16
  25. jinns/nn/_spinn_mlp.py +5 -5
  26. jinns/nn/_utils.py +33 -38
  27. jinns/parameters/__init__.py +3 -1
  28. jinns/parameters/_derivative_keys.py +99 -41
  29. jinns/parameters/_params.py +50 -25
  30. jinns/solver/_solve.py +3 -3
  31. jinns/utils/_DictToModuleMeta.py +66 -0
  32. jinns/utils/_ItemizableModule.py +19 -0
  33. jinns/utils/__init__.py +2 -1
  34. jinns/utils/_types.py +25 -15
  35. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  36. jinns-1.6.0.dist-info/RECORD +57 -0
  37. jinns-1.5.1.dist-info/RECORD +0 -55
  38. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  39. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  40. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  41. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_abstract_loss.py CHANGED
@@ -1,41 +1,77 @@
  from __future__ import annotations
 
  import abc
- from typing import TYPE_CHECKING, Self, Literal, Callable
- from jaxtyping import Array, PyTree, Key
+ from typing import Self, Literal, Callable, TypeVar, Generic, Any
+ from jaxtyping import PRNGKeyArray, Array, PyTree, Float
  import equinox as eqx
  import jax
  import jax.numpy as jnp
  import optax
- from jinns.loss._loss_weights import AbstractLossWeights
  from jinns.parameters._params import Params
  from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
+ from jinns.utils._types import AnyLossComponents, AnyBatch, AnyLossWeights
 
- if TYPE_CHECKING:
-     from jinns.utils._types import AnyLossComponents, AnyBatch
+ L = TypeVar(
+     "L", bound=AnyLossWeights
+ )  # we want to be able to use one of the element of AnyLossWeights
+ # that is https://stackoverflow.com/a/79534258 via `bound`
 
+ B = TypeVar(
+     "B", bound=AnyBatch
+ )  # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)
+ # We then do the same TypeVar to be able to use one of the element of AnyBatch
+ # in the evaluate_by_terms methods of child classes.
+ C = TypeVar(
+     "C", bound=AnyLossComponents[Array | None]
+ )  # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)
 
- class AbstractLoss(eqx.Module):
+ # In the cases above, without the bound, we could not have covariance on
+ # the type because it would break LSP. Note that covariance on the return type
+ # is authorized in LSP hence we do not need the same TypeVar instruction for
+ # the return types of evaluate_by_terms for example!
+
+
+ class AbstractLoss(eqx.Module, Generic[L, B, C]):
      """
      About the call:
      https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
      """
 
-     loss_weights: AbstractLossWeights
+     loss_weights: eqx.AbstractVar[L]
      update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
          eqx.field(kw_only=True, default=None, static=True)
      )
 
-     @abc.abstractmethod
-     def __call__(self, *_, **__) -> Array:
-         pass
+     def __call__(self, *args: Any, **kwargs: Any) -> Any:
+         return self.evaluate(*args, **kwargs)
 
      @abc.abstractmethod
-     def evaluate_by_terms(
-         self, params: Params[Array], batch: AnyBatch
-     ) -> tuple[AnyLossComponents, AnyLossComponents]:
+     def evaluate_by_terms(self, params: Params[Array], batch: B) -> tuple[C, C]:
          pass
 
+     def evaluate(self, params: Params[Array], batch: B) -> tuple[Float[Array, " "], C]:
+         """
+         Evaluate the loss function at a batch of points for given parameters.
+
+         We retrieve the total value itself and a PyTree with loss values for each term
+
+         Parameters
+         ---------
+         params
+             Parameters at which the loss is evaluated
+         batch
+             Composed of a batch of points in the
+             domain, a batch of points in the domain
+             border and an optional additional batch of parameters (eg. for
+             metamodeling) and an optional additional batch of observed
+             inputs/outputs/parameters
+         """
+         loss_terms, _ = self.evaluate_by_terms(params, batch)
+
+         loss_val = self.ponderate_and_sum_loss(loss_terms)
+
+         return loss_val, loss_terms
+
      def get_gradients(
          self, fun: Callable[[Params[Array]], Array], params: Params[Array]
      ) -> tuple[Array, Array]:
@@ -48,7 +84,7 @@ class AbstractLoss(eqx.Module):
          loss_val, grads = value_grad_loss(params)
          return loss_val, grads
 
-     def ponderate_and_sum_loss(self, terms):
+     def ponderate_and_sum_loss(self, terms: C) -> Array:
          """
          Get total loss from individual loss terms and weights
 
@@ -58,19 +94,18 @@ class AbstractLoss(eqx.Module):
              self.loss_weights,
              is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
          )
-         terms = jax.tree.leaves(
+         terms_list = jax.tree.leaves(
              terms, is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None
          )
-         if len(weights) == len(terms):
-             return jnp.sum(jnp.array(weights) * jnp.array(terms))
-         else:
-             raise ValueError(
-                 "The numbers of declared loss weights and "
-                 "declared loss terms do not concord "
-                 f" got {len(weights)} and {len(terms)}"
-             )
+         if len(weights) == len(terms_list):
+             return jnp.sum(jnp.array(weights) * jnp.array(terms_list))
+         raise ValueError(
+             "The numbers of declared loss weights and "
+             "declared loss terms do not concord "
+             f" got {len(weights)} and {len(terms_list)}"
+         )
 
-     def ponderate_and_sum_gradient(self, terms):
+     def ponderate_and_sum_gradient(self, terms: C) -> C:
          """
          Get total gradients from individual loss gradients and weights
          for each parameter
@@ -102,7 +137,7 @@ class AbstractLoss(eqx.Module):
          loss_terms: PyTree,
          stored_loss_terms: PyTree,
          grad_terms: PyTree,
-         key: Key,
+         key: PRNGKeyArray,
      ) -> Self:
          """
          Update the loss weights according to a predefined scheme
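The bound TypeVars are the interesting part of this refactor: binding L, B and C to union aliases lets each concrete loss narrow loss_weights, the batch and the returned components to a single member of the union without violating LSP, while the new concrete evaluate() is composed from the abstract evaluate_by_terms() plus the weighting step. A minimal, self-contained sketch of the same pattern (hypothetical classes, not jinns code):

    import abc
    from typing import Generic, TypeVar

    AnyBatchLike = tuple[float, ...] | dict[str, float]  # stand-in for AnyBatch
    B = TypeVar("B", bound=AnyBatchLike)  # bound to the union, narrowed per subclass


    class LossBase(abc.ABC, Generic[B]):
        weights: dict[str, float]

        @abc.abstractmethod
        def evaluate_by_terms(self, batch: B) -> dict[str, float]: ...

        def evaluate(self, batch: B) -> tuple[float, dict[str, float]]:
            # concrete method built on the abstract one, as in AbstractLoss.evaluate
            terms = self.evaluate_by_terms(batch)
            total = sum(self.weights[k] * v for k, v in terms.items())
            return total, terms


    class TupleLoss(LossBase[tuple[float, ...]]):
        # narrows B to one member of the union; type checkers accept this because the
        # narrowing goes through the generic parameter, not through an override
        weights = {"dyn_loss": 1.0, "observations": 0.5}

        def evaluate_by_terms(self, batch: tuple[float, ...]) -> dict[str, float]:
            return {"dyn_loss": sum(batch), "observations": max(batch)}


    print(TupleLoss().evaluate((0.1, 0.2)))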
jinns/loss/_loss_components.py CHANGED
@@ -1,38 +1,17 @@
  from typing import TypeVar, Generic
- from dataclasses import fields
- import equinox as eqx
+
+ from jinns.utils._ItemizableModule import ItemizableModule
 
  T = TypeVar("T")
 
 
- class XDEComponentsAbstract(eqx.Module, Generic[T]):
-     """
-     Provides a template for ODE components with generic types.
-     One can inherit to specialize and add methods and attributes
-     We do not enforce keyword only to avoid being to verbose (this then can
-     work like a tuple)
-     """
-
-     def items(self):
-         """
-         For the dataclass to be iterated like a dictionary.
-         Practical and retrocompatible with old code when loss components were
-         dictionaries
-         """
-         return {
-             field.name: getattr(self, field.name)
-             for field in fields(self)
-             if getattr(self, field.name) is not None
-         }.items()
-
-
- class ODEComponents(XDEComponentsAbstract[T]):
+ class ODEComponents(ItemizableModule, Generic[T]):
      dyn_loss: T
      initial_condition: T
      observations: T
 
 
- class PDEStatioComponents(XDEComponentsAbstract[T]):
+ class PDEStatioComponents(ItemizableModule, Generic[T]):
      dyn_loss: T
      norm_loss: T
      boundary_loss: T
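The removed items() helper has not disappeared: the components now inherit it from the new jinns/utils/_ItemizableModule.py added in this release (+19 lines). A plausible reconstruction of that shared base, inferred from the method deleted above rather than copied from the wheel:

    from dataclasses import fields

    import equinox as eqx


    class ItemizableModule(eqx.Module):
        """Let an eqx.Module dataclass be iterated like a dict, skipping None fields."""

        def items(self):
            # same behaviour as the items() removed from XDEComponentsAbstract above
            return {
                f.name: getattr(self, f.name)
                for f in fields(self)
                if getattr(self, f.name) is not None
            }.items()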
jinns/loss/_loss_weight_updates.py CHANGED
@@ -4,18 +4,17 @@ A collection of specific weight update schemes in jinns
 
  from __future__ import annotations
  from typing import TYPE_CHECKING
- from jaxtyping import Array, Key
+ from jaxtyping import Array, PRNGKeyArray
  import jax.numpy as jnp
  import jax
  import equinox as eqx
 
  if TYPE_CHECKING:
-     from jinns.loss._loss_weights import AbstractLossWeights
-     from jinns.utils._types import AnyLossComponents
+     from jinns.utils._types import AnyLossComponents, AnyLossWeights
 
 
  def soft_adapt(
-     loss_weights: AbstractLossWeights,
+     loss_weights: AnyLossWeights,
      iteration_nb: int,
      loss_terms: AnyLossComponents,
      stored_loss_terms: AnyLossComponents,
@@ -58,11 +57,11 @@ def soft_adapt(
 
 
  def ReLoBRaLo(
-     loss_weights: AbstractLossWeights,
+     loss_weights: AnyLossWeights,
      iteration_nb: int,
      loss_terms: AnyLossComponents,
      stored_loss_terms: AnyLossComponents,
-     key: Key,
+     key: PRNGKeyArray,
      decay_factor: float = 0.9,
      tau: float = 1,  ## referred to as temperature in the article
      p: float = 0.9,
@@ -146,7 +145,7 @@ def ReLoBRaLo(
 
 
  def lr_annealing(
-     loss_weights: AbstractLossWeights,
+     loss_weights: AnyLossWeights,
      grad_terms: AnyLossComponents,
      decay_factor: float = 0.9,  # 0.9 is the recommended value from the article
      eps: float = 1e-6,
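The Key → PRNGKeyArray swap repeated across these signatures is purely a typing change: jaxtyping.PRNGKeyArray is the annotation that matches what jax.random.key and jax.random.split actually produce. A tiny illustration (the needs_key helper is hypothetical, only the annotation mirrors the new signatures):

    import jax
    from jaxtyping import PRNGKeyArray


    def needs_key(key: PRNGKeyArray) -> PRNGKeyArray:
        # any function annotated like ReLoBRaLo's `key` argument
        return jax.random.split(key, 2)[0]


    key = jax.random.key(0)   # typed key array, satisfies PRNGKeyArray
    subkey = needs_key(key)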
jinns/loss/_loss_weights.py CHANGED
@@ -3,81 +3,80 @@ Formalize the loss weights data structure
  """
 
  from __future__ import annotations
- from dataclasses import fields
 
  from jaxtyping import Array
  import jax.numpy as jnp
  import equinox as eqx
 
+ from jinns.loss._loss_components import (
+     ODEComponents,
+     PDEStatioComponents,
+     PDENonStatioComponents,
+ )
 
- def lw_converter(x):
+
+ def lw_converter(x: Array | None) -> Array | None:
      if x is None:
          return x
      else:
          return jnp.asarray(x)
 
 
- class AbstractLossWeights(eqx.Module):
+ class LossWeightsODE(ODEComponents[Array | None]):
      """
-     An abstract class, currently only useful for type hints
-
-     TODO in the future maybe loss weights could be subclasses of
-     XDEComponentsAbstract?
+     Value given at initialization is converted to a jnp.array orunmodified if None.
+     This means that at initialization, the user can pass a float or int
      """
 
-     def items(self):
-         """
-         For the dataclass to be iterated like a dictionary.
-         Practical and retrocompatible with old code when loss components were
-         dictionaries
-         """
-         return {
-             field.name: getattr(self, field.name)
-             for field in fields(self)
-             if getattr(self, field.name) is not None
-         }.items()
-
-
- class LossWeightsODE(AbstractLossWeights):
-     dyn_loss: Array | float | None = eqx.field(
+     dyn_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     initial_condition: Array | float | None = eqx.field(
+     initial_condition: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     observations: Array | float | None = eqx.field(
+     observations: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
 
 
- class LossWeightsPDEStatio(AbstractLossWeights):
-     dyn_loss: Array | float | None = eqx.field(
+ class LossWeightsPDEStatio(PDEStatioComponents[Array | None]):
+     """
+     Value given at initialization is converted to a jnp.array orunmodified if None.
+     This means that at initialization, the user can pass a float or int
+     """
+
+     dyn_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     norm_loss: Array | float | None = eqx.field(
+     norm_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     boundary_loss: Array | float | None = eqx.field(
+     boundary_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     observations: Array | float | None = eqx.field(
+     observations: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
 
 
- class LossWeightsPDENonStatio(AbstractLossWeights):
-     dyn_loss: Array | float | None = eqx.field(
+ class LossWeightsPDENonStatio(PDENonStatioComponents[Array | None]):
+     """
+     Value given at initialization is converted to a jnp.array orunmodified if None.
+     This means that at initialization, the user can pass a float or int
+     """
+
+     dyn_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     norm_loss: Array | float | None = eqx.field(
+     norm_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     boundary_loss: Array | float | None = eqx.field(
+     boundary_loss: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     observations: Array | float | None = eqx.field(
+     observations: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
-     initial_condition: Array | float | None = eqx.field(
+     initial_condition: Array | None = eqx.field(
          kw_only=True, default=None, converter=lw_converter
      )
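Tightening the annotations from Array | float | None to Array | None does not change how the classes are constructed, because lw_converter still coerces whatever is passed; only the stored value is now consistently an array or None. A short usage sketch (the import path is assumed to be the public jinns.loss namespace):

    import jax.numpy as jnp

    from jinns.loss import LossWeightsODE  # assumed public re-export

    lw = LossWeightsODE(dyn_loss=1.0, initial_condition=10, observations=None)

    assert isinstance(lw.dyn_loss, jnp.ndarray)  # float converted by lw_converter
    assert lw.observations is None               # None passes through unchanged

    for name, value in lw.items():               # items() now comes from ItemizableModule
        print(name, value)                       # only dyn_loss and initial_condition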
jinns/nn/_abstract_pinn.py CHANGED
@@ -3,7 +3,6 @@ from typing import Literal, Any
  from jaxtyping import Array
  import equinox as eqx
 
- from jinns.nn._utils import _PyTree_to_Params
  from jinns.parameters._params import Params
 
 
@@ -17,6 +16,5 @@ class AbstractPINN(eqx.Module):
      eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
 
      @abc.abstractmethod
-     @_PyTree_to_Params
      def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
          pass
jinns/nn/_hyperpinn.py CHANGED
@@ -7,18 +7,17 @@ from __future__ import annotations
 
  import warnings
  from dataclasses import InitVar
- from typing import Callable, Literal, Self, Union, Any, cast, overload
+ from typing import Callable, Literal, Self, Union, Any, cast
  from math import prod
  import jax
  import jax.numpy as jnp
- from jaxtyping import Array, Float, PyTree, Key
+ from jaxtyping import PRNGKeyArray, Array, Float, PyTree
  import equinox as eqx
  import numpy as onp
 
  from jinns.nn._pinn import PINN
  from jinns.nn._mlp import MLP
  from jinns.parameters._params import Params
- from jinns.nn._utils import _PyTree_to_Params
 
 
  def _get_param_nb(
@@ -138,6 +137,32 @@ class HyperPINN(PINN):
              jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
          )
 
+         # For the record. We exhibited that the jnp.split was a serious time
+         # bottleneck. However none of the approaches below improved the speed.
+         # Moreover, this operation is not well implemented by a triton kernel
+         # apparently so such an optim is not an option.
+         # 1)
+         # pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
+         #     jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
+         # )
+         # 2)
+         # pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
+         #     [jax.lax.slice(hyper_output, (s,), (e,)).reshape(r) for s, e, r in
+         #      zip(self.pinn_params_cumsum_start, self.pinn_params_cumsum,
+         #      self.pinn_params_shapes)]
+         # )
+         # 3)
+         # pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
+         #     [hyper_output[s:e].reshape(r) for s, e, r in
+         #      zip(self.pinn_params_cumsum_start, self.pinn_params_cumsum,
+         #      self.pinn_params_shapes)]
+         # )
+         # 4)
+         # pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
+         #     [jax.lax.dynamic_slice(hyper_output, (s,), (size,)) for s, size in
+         #      zip(self.pinn_params_cumsum_start, self.pinn_params_cumsum_size)]
+         # )
+
          return jax.tree.map(
              lambda a, b: a.reshape(b.shape),
              pinn_params_flat,
@@ -145,17 +170,6 @@ class HyperPINN(PINN):
              is_leaf=lambda x: isinstance(x, jnp.ndarray),
          )
 
-     @overload
-     @_PyTree_to_Params
-     def __call__(
-         self,
-         inputs: Float[Array, " input_dim"],
-         params: PyTree,
-         *args,
-         **kwargs,
-     ) -> Float[Array, " output_dim"]: ...
-
-     @_PyTree_to_Params
      def __call__(
          self,
          inputs: Float[Array, " input_dim"],
@@ -175,13 +189,10 @@ class HyperPINN(PINN):
          # DataGenerators)
          inputs = inputs[None]
 
-         # try:
          hyper = eqx.combine(params.nn_params, self.static_hyper)
-         # except (KeyError, AttributeError, TypeError) as e: # give more flexibility
-         #     hyper = eqx.combine(params, self.static_hyper)
 
          eq_params_batch = jnp.concatenate(
-             [params.eq_params[k].flatten() for k in self.hyperparams],
+             [getattr(params.eq_params, k).flatten() for k in self.hyperparams],
              axis=0,
          )
 
@@ -202,12 +213,13 @@ class HyperPINN(PINN):
      @classmethod
      def create(
          cls,
+         *,
          eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
          hyperparams: list[str],
          hypernet_input_size: int,
+         key: PRNGKeyArray | None = None,
          eqx_network: eqx.nn.MLP | MLP | None = None,
          eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
-         key: Key = None,
          eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
          eqx_list_hyper: (
              tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
@@ -359,10 +371,10 @@ class HyperPINN(PINN):
 
          ### Now we finetune the hypernetwork architecture
 
-         key, subkey = jax.random.split(key, 2)
+         subkey1, subkey2 = jax.random.split(key, 2)
          # with warnings.catch_warnings():
          #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
-         eqx_network = MLP(key=subkey, eqx_list=eqx_list)
+         eqx_network = MLP(key=subkey1, eqx_list=eqx_list)
          # quick partitioning to get the params to get the correct number of neurons
          # for the last layer of hyper network
          params_mlp, _ = eqx.partition(eqx_network, eqx.is_inexact_array)
@@ -405,10 +417,9 @@ class HyperPINN(PINN):
                  + eqx_list_hyper[2:]
              ),
          )
-         key, subkey = jax.random.split(key, 2)
          # with warnings.catch_warnings():
          #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
-         eqx_hyper_network = cast(MLP, MLP(key=subkey, eqx_list=eqx_list_hyper))
+         eqx_hyper_network = cast(MLP, MLP(key=subkey2, eqx_list=eqx_list_hyper))
 
          ### End of finetuning the hypernetwork architecture
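The one behavioural hint in this file is the switch from params.eq_params[k] to getattr(params.eq_params, k), which, together with the EqParams comment later in _save_load.py, suggests equation parameters are now held in a module-like container rather than a plain dict. A defensive accessor for user code that has to run against both 1.5.x and 1.6.0 (hypothetical helper, not part of jinns):

    from typing import Any


    def get_eq_param(eq_params: Any, name: str) -> Any:
        """Read an equation parameter whether eq_params is a dict (<=1.5.x) or an object (1.6.0)."""
        if isinstance(eq_params, dict):
            return eq_params[name]
        return getattr(eq_params, name)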
jinns/nn/_mlp.py CHANGED
@@ -9,7 +9,7 @@ from dataclasses import InitVar
  import jax
  import equinox as eqx
  from typing import Protocol
- from jaxtyping import Array, Key, PyTree, Float
+ from jaxtyping import Array, PRNGKeyArray, PyTree, Float
 
  from jinns.parameters._params import Params
  from jinns.nn._pinn import PINN
@@ -33,7 +33,7 @@ class MLP(eqx.Module):
 
      Parameters
      ----------
-     key : InitVar[Key]
+     key : InitVar[PRNGKeyArray]
          A jax random key for the layer initializations.
      eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
          A tuple of tuples of successive equinox modules and activation functions to
@@ -52,7 +52,7 @@ class MLP(eqx.Module):
          )`.
      """
 
-     key: InitVar[Key] = eqx.field(kw_only=True)
+     key: InitVar[PRNGKeyArray] = eqx.field(kw_only=True)
      eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
          eqx.field(kw_only=True)
      )
@@ -94,9 +94,10 @@ class PINN_MLP(PINN):
      @classmethod
      def create(
          cls,
+         *,
          eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         key: PRNGKeyArray | None = None,
          eqx_network: eqx.nn.MLP | MLP | None = None,
-         key: Key = None,
          eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
          input_transform: (
              Callable[
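Because of the inserted `*`, every argument of create() after cls is now keyword-only, and key moved before eqx_network; positional calls written against 1.5.x will raise a TypeError. A sketch of the keyword-only call; the layer list is a made-up toy architecture, and the returned pair is assumed to follow the (network, initial parameters) convention of the jinns documentation examples:

    import equinox as eqx
    import jax
    import jinns

    eqx_list = (
        (eqx.nn.Linear, 1, 32),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 32, 1),
    )

    u, params = jinns.nn.PINN_MLP.create(
        eq_type="ODE",
        key=jax.random.key(0),
        eqx_list=eqx_list,
    )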
jinns/nn/_pinn.py CHANGED
@@ -4,14 +4,13 @@ Implement abstract class for PINN architectures
 
  from __future__ import annotations
 
- from typing import Callable, Union, Any, Literal, overload
+ from typing import Callable, Union, Any, Literal
  from dataclasses import InitVar
  import equinox as eqx
  from jaxtyping import Float, Array, PyTree
  import jax.numpy as jnp
  from jinns.parameters._params import Params
  from jinns.nn._abstract_pinn import AbstractPINN
- from jinns.nn._utils import _PyTree_to_Params
 
 
  class PINN(AbstractPINN):
@@ -157,17 +156,6 @@ class PINN(AbstractPINN):
 
          return network(inputs)
 
-     @overload
-     @_PyTree_to_Params
-     def __call__(
-         self,
-         inputs: Float[Array, " input_dim"],
-         params: PyTree,
-         *args,
-         **kwargs,
-     ) -> Float[Array, " output_dim"]: ...
-
-     @_PyTree_to_Params
      def __call__(
          self,
          inputs: Float[Array, " input_dim"],
@@ -180,9 +168,6 @@ class PINN(AbstractPINN):
          `params` and `self.static` to recreate the callable eqx.Module
          architecture. The rest of the content of this function is dependent on
          the network.
-
-         Note that that thanks to the decorator, params can also directly be the
-         PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
          """
 
          if len(inputs.shape) == 0:
jinns/nn/_ppinn.py CHANGED
@@ -4,18 +4,17 @@ Implements utility function to create PINNs
 
  from __future__ import annotations
 
- from typing import Callable, Literal, Self, cast, overload
+ from typing import Callable, Literal, Self, cast
  from dataclasses import InitVar
  import jax
  import jax.numpy as jnp
  import equinox as eqx
 
- from jaxtyping import Array, Key, Float, PyTree
+ from jaxtyping import Array, Float, PRNGKeyArray
 
  from jinns.parameters._params import Params
  from jinns.nn._pinn import PINN
  from jinns.nn._mlp import MLP
- from jinns.nn._utils import _PyTree_to_Params
 
 
  class PPINN_MLP(PINN):
@@ -85,17 +84,6 @@ class PPINN_MLP(PINN):
          self.init_params = self.init_params + (params,)
          self.static = self.static + (static,)
 
-     @overload
-     @_PyTree_to_Params
-     def __call__(
-         self,
-         inputs: Float[Array, " input_dim"],
-         params: PyTree,
-         *args,
-         **kwargs,
-     ) -> Float[Array, " output_dim"]: ...
-
-     @_PyTree_to_Params
      def __call__(
          self,
          inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
@@ -135,9 +123,10 @@ class PPINN_MLP(PINN):
      @classmethod
      def create(
          cls,
+         *,
+         key: PRNGKeyArray | None = None,
          eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
          eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
-         key: Key = None,
          eqx_list_list: (
              list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
          ) = None,
@@ -225,7 +214,7 @@ class PPINN_MLP(PINN):
 
          eqx_network_list = []
          for eqx_list in eqx_list_list:
-             key, subkey = jax.random.split(key, 2)
+             key, subkey = jax.random.split(key, 2)  # type: ignore
              eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
 
          ppinn = cls(
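The `# type: ignore` is needed because key is now annotated PRNGKeyArray | None while jax.random.split requires a real key at that point. An explicit narrowing that would satisfy the type checker without the ignore, shown only as an alternative pattern and not as what jinns does:

    import jax
    from jaxtyping import PRNGKeyArray


    def split_or_raise(key: PRNGKeyArray | None) -> tuple[PRNGKeyArray, PRNGKeyArray]:
        if key is None:
            raise ValueError("a PRNG key is required when networks are built from eqx_list")
        new_key, subkey = jax.random.split(key, 2)
        return new_key, subkey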
jinns/nn/_save_load.py CHANGED
@@ -3,6 +3,7 @@ Implements save and load functions
  """
 
  from typing import Callable, Literal
+ from dataclasses import fields
  import pickle
  import jax
  import equinox as eqx
@@ -130,8 +131,14 @@ def save_pinn(
      u = eqx.tree_at(lambda m: m.init_params, u, params)
      eqx.tree_serialise_leaves(filename + "-module.eqx", u)
 
+     # The class EqParams is malformed for pickling, hence we pickle it under
+     # its dictionary form
+     eq_params_as_dict = {
+         k.name: getattr(params.eq_params, k.name) for k in fields(params.eq_params)
+     }
+
      with open(filename + "-eq_params.pkl", "wb") as f:
-         pickle.dump(params.eq_params, f)
+         pickle.dump(eq_params_as_dict, f)
 
      kwargs_creation = kwargs_creation.copy()  # avoid side-effect that would be
      # very probably harmless anyway
@@ -187,9 +194,9 @@ def load_pinn(
      try:
          with open(filename + "-eq_params.pkl", "rb") as f:
              eq_params_reloaded = pickle.load(f)
-     except FileNotFoundError:
-         eq_params_reloaded = {}
-         print("No pickle file for equation parameters found!")
+     except FileNotFoundError as e:
+         raise e
+
      kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
      if type_ == "pinn_mlp":
          # next line creates a shallow model, the jax arrays are just shapes and
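The save path now flattens the equation-parameter container to a plain dict with dataclasses.fields before pickling, and the load path no longer swallows a missing pickle file. A standalone sketch of the same dict round trip (helper names are mine; the reconstruction on the load side is an assumption, since that part is not shown in this diff):

    import pickle
    from dataclasses import fields
    from typing import Any


    def eq_params_to_dict(eq_params: Any) -> dict[str, Any]:
        # mirrors the eq_params_as_dict construction added in save_pinn
        return {f.name: getattr(eq_params, f.name) for f in fields(eq_params)}


    def save_eq_params(eq_params: Any, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(eq_params_to_dict(eq_params), f)


    def load_eq_params_dict(path: str) -> dict[str, Any]:
        # unlike 1.5.1, a missing file is no longer replaced by {}: FileNotFoundError propagates
        with open(path, "rb") as f:
            return pickle.load(f)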