jinns 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. jinns/data/_AbstractDataGenerator.py +1 -1
  2. jinns/data/_Batchs.py +47 -13
  3. jinns/data/_CubicMeshPDENonStatio.py +55 -34
  4. jinns/data/_CubicMeshPDEStatio.py +63 -35
  5. jinns/data/_DataGeneratorODE.py +48 -22
  6. jinns/data/_DataGeneratorObservations.py +75 -32
  7. jinns/data/_DataGeneratorParameter.py +152 -101
  8. jinns/data/__init__.py +2 -1
  9. jinns/data/_utils.py +22 -10
  10. jinns/loss/_DynamicLoss.py +21 -20
  11. jinns/loss/_DynamicLossAbstract.py +51 -36
  12. jinns/loss/_LossODE.py +139 -184
  13. jinns/loss/_LossPDE.py +440 -358
  14. jinns/loss/_abstract_loss.py +60 -25
  15. jinns/loss/_loss_components.py +4 -25
  16. jinns/loss/_loss_weight_updates.py +6 -7
  17. jinns/loss/_loss_weights.py +34 -35
  18. jinns/nn/_abstract_pinn.py +0 -2
  19. jinns/nn/_hyperpinn.py +34 -23
  20. jinns/nn/_mlp.py +5 -4
  21. jinns/nn/_pinn.py +1 -16
  22. jinns/nn/_ppinn.py +5 -16
  23. jinns/nn/_save_load.py +11 -4
  24. jinns/nn/_spinn.py +1 -16
  25. jinns/nn/_spinn_mlp.py +5 -5
  26. jinns/nn/_utils.py +33 -38
  27. jinns/parameters/__init__.py +3 -1
  28. jinns/parameters/_derivative_keys.py +99 -41
  29. jinns/parameters/_params.py +50 -25
  30. jinns/solver/_solve.py +3 -3
  31. jinns/utils/_DictToModuleMeta.py +66 -0
  32. jinns/utils/_ItemizableModule.py +19 -0
  33. jinns/utils/__init__.py +2 -1
  34. jinns/utils/_types.py +25 -15
  35. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
  36. jinns-1.6.0.dist-info/RECORD +57 -0
  37. jinns-1.5.1.dist-info/RECORD +0 -55
  38. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
  39. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
  40. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
  41. {jinns-1.5.1.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/nn/_spinn.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from __future__ import annotations
2
- from typing import Union, Callable, Any, Literal, overload
2
+ from typing import Union, Callable, Any, Literal
3
3
  from dataclasses import InitVar
4
4
  from jaxtyping import PyTree, Float, Array
5
5
  import jax
@@ -8,7 +8,6 @@ import equinox as eqx
8
8
 
9
9
  from jinns.parameters._params import Params
10
10
  from jinns.nn._abstract_pinn import AbstractPINN
11
- from jinns.nn._utils import _PyTree_to_Params
12
11
 
13
12
 
14
13
  class SPINN(AbstractPINN):
@@ -72,17 +71,6 @@ class SPINN(AbstractPINN):
72
71
  eqx_spinn_network, self.filter_spec
73
72
  )
74
73
 
75
- @overload
76
- @_PyTree_to_Params
77
- def __call__(
78
- self,
79
- inputs: Float[Array, " input_dim"],
80
- params: PyTree,
81
- *args,
82
- **kwargs,
83
- ) -> Float[Array, " output_dim"]: ...
84
-
85
- @_PyTree_to_Params
86
74
  def __call__(
87
75
  self,
88
76
  t_x: Float[Array, " batch_size 1+dim"],
@@ -94,10 +82,7 @@ class SPINN(AbstractPINN):
94
82
  Note that that thanks to the decorator, params can also directly be the
95
83
  PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
96
84
  """
97
- # try:
98
85
  spinn = eqx.combine(params.nn_params, self.static)
99
- # except (KeyError, AttributeError, TypeError) as e:
100
- # spinn = eqx.combine(params, self.static)
101
86
  v_model = jax.vmap(spinn)
102
87
  res = v_model(t_x) # type: ignore
103
88
 
jinns/nn/_spinn_mlp.py CHANGED
@@ -8,7 +8,7 @@ from typing import Callable, Literal, Self, Union, Any, TypeGuard
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  import equinox as eqx
11
- from jaxtyping import Key, Array, Float, PyTree
11
+ from jaxtyping import PRNGKeyArray, Array, Float, PyTree
12
12
 
13
13
  from jinns.nn._mlp import MLP
14
14
  from jinns.nn._spinn import SPINN
@@ -20,7 +20,7 @@ class SMLP(eqx.Module):
20
20
 
21
21
  Parameters
22
22
  ----------
23
- key : InitVar[Key]
23
+ key : InitVar[PRNGKeyArray]
24
24
  A jax random key for the layer initializations.
25
25
  d : int
26
26
  The number of dimensions to treat separately, including time `t` if
@@ -42,7 +42,7 @@ class SMLP(eqx.Module):
42
42
  )`.
43
43
  """
44
44
 
45
- key: InitVar[Key] = eqx.field(kw_only=True)
45
+ key: InitVar[PRNGKeyArray] = eqx.field(kw_only=True)
46
46
  eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
47
47
  eqx.field(kw_only=True)
48
48
  )
@@ -74,7 +74,7 @@ class SPINN_MLP(SPINN):
74
74
  @classmethod
75
75
  def create(
76
76
  cls,
77
- key: Key,
77
+ key: PRNGKeyArray,
78
78
  d: int,
79
79
  r: int,
80
80
  eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
@@ -93,7 +93,7 @@ class SPINN_MLP(SPINN):
93
93
 
94
94
  Parameters
95
95
  ----------
96
- key : Key
96
+ key : PRNGKeyArray
97
97
  A JAX random key that will be used to initialize the network parameters
98
98
  d : int
99
99
  The number of dimensions to treat separately.
jinns/nn/_utils.py CHANGED
@@ -1,38 +1,33 @@
1
- from typing import Any, ParamSpec, Callable, Concatenate
2
- from jaxtyping import PyTree, Array
3
- from jinns.parameters._params import Params
4
-
5
-
6
- P = ParamSpec("P")
7
-
8
-
9
- def _PyTree_to_Params(
10
- call_fun: Callable[
11
- Concatenate[Any, Any, PyTree | Params[Array], P],
12
- Any,
13
- ],
14
- ) -> Callable[
15
- Concatenate[Any, Any, PyTree | Params[Array], P],
16
- Any,
17
- ]:
18
- """
19
- Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
20
- authorizes the __call__ with `params` being directly be the
21
- PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
22
-
23
- This generic approach enables to cleanly handle type hints, up to the small
24
- effort required to understand type hints for decorators (ie ParamSpec).
25
- """
26
-
27
- def wrapper(
28
- self: Any,
29
- inputs: Any,
30
- params: PyTree | Params[Array],
31
- *args: P.args,
32
- **kwargs: P.kwargs,
33
- ):
34
- if isinstance(params, PyTree) and not isinstance(params, Params):
35
- params = Params(nn_params=params, eq_params={})
36
- return call_fun(self, inputs, params, *args, **kwargs)
37
-
38
- return wrapper
1
+ # P = ParamSpec("P")
2
+ #
3
+ #
4
+ # def _PyTree_to_Params(
5
+ # call_fun: Callable[
6
+ # Concatenate[Any, Any, PyTree | Params[Array], P],
7
+ # Any,
8
+ # ],
9
+ # ) -> Callable[
10
+ # Concatenate[Any, Any, PyTree | Params[Array], P],
11
+ # Any,
12
+ # ]:
13
+ # """
14
+ # Decorator to be used around __call__ functions of PINNs, SPINNs, etc. It
15
+ # authorizes the __call__ with `params` being directly be the
16
+ # PyTree (SPINN, PINN_MLP, ...) that we get out of `eqx.combine`
17
+ #
18
+ # This generic approach enables to cleanly handle type hints, up to the small
19
+ # effort required to understand type hints for decorators (ie ParamSpec).
20
+ # """
21
+ #
22
+ # def wrapper(
23
+ # self: Any,
24
+ # inputs: Any,
25
+ # params: PyTree | Params[Array],
26
+ # *args: P.args,
27
+ # **kwargs: P.kwargs,
28
+ # ):
29
+ # if isinstance(params, PyTree) and not isinstance(params, Params):
30
+ # params = Params(nn_params=params, eq_params={})
31
+ # return call_fun(self, inputs, params, *args, **kwargs)
32
+ #
33
+ # return wrapper
@@ -1,4 +1,4 @@
1
- from ._params import Params
1
+ from ._params import EqParams, Params, update_eq_params
2
2
  from ._derivative_keys import (
3
3
  DerivativeKeysODE,
4
4
  DerivativeKeysPDEStatio,
@@ -6,8 +6,10 @@ from ._derivative_keys import (
6
6
  )
7
7
 
8
8
  __all__ = [
9
+ "EqParams",
9
10
  "Params",
10
11
  "DerivativeKeysODE",
11
12
  "DerivativeKeysPDEStatio",
12
13
  "DerivativeKeysPDENonStatio",
14
+ "update_eq_params",
13
15
  ]
@@ -19,13 +19,10 @@ def _get_masked_parameters(
19
19
  """
20
20
  # start with a params object with True everywhere. We will update to False
21
21
  # for parameters wrt which we do want not to differentiate the loss
22
- diff_params = jax.tree.map(
23
- lambda x: True,
24
- params,
25
- is_leaf=lambda x: isinstance(x, eqx.Module)
26
- and not isinstance(x, Params), # do not travers nn_params, more
27
- # granularity could be imagined here, in the future
28
- )
22
+ diff_params = Params(
23
+ nn_params=True, eq_params=jax.tree.map(lambda _: True, params.eq_params)
24
+ ) # do not travers nn_params, more
25
+ # granularity could be imagined here, in the future
29
26
  if derivative_mask_str == "both":
30
27
  return diff_params
31
28
  if derivative_mask_str == "eq_params":
@@ -60,7 +57,7 @@ class DerivativeKeysODE(eqx.Module):
60
57
 
61
58
  1. For unspecified loss term, the default is to differentiate with
62
59
  respect to `"nn_params"` only.
63
- 2. No granularity inside `Params.nn_params` is currently supported.
60
+ 2. No granularity inside `Params.nn_params` is currently supported. An easy way to do freeze part of a custom PINN module is to use `jax.lax.stop_gradient` as explained [here](https://docs.kidger.site/equinox/faq/#how-to-mark-arrays-as-non-trainable-like-pytorchs-buffers).
64
61
  3. Note that the main Params object of the problem is mandatory if initialization via `from_str()`.
65
62
 
66
63
  A typical specification is of the form:
@@ -95,38 +92,52 @@ class DerivativeKeysODE(eqx.Module):
95
92
  infer the content of `Params.eq_params`.
96
93
  """
97
94
 
98
- dyn_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
99
- observations: Params[bool] | None = eqx.field(kw_only=True, default=None)
100
- initial_condition: Params[bool] | None = eqx.field(kw_only=True, default=None)
95
+ dyn_loss: Params[bool]
96
+ observations: Params[bool]
97
+ initial_condition: Params[bool]
101
98
 
102
- params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
99
+ params: InitVar[Params[Array] | None]
103
100
 
104
- def __post_init__(self, params: Params[Array] | None = None):
101
+ def __init__(
102
+ self,
103
+ *,
104
+ dyn_loss: Params[bool] | None = None,
105
+ observations: Params[bool] | None = None,
106
+ initial_condition: Params[bool] | None = None,
107
+ params: Params[Array] | None = None,
108
+ ):
109
+ super().__init__()
105
110
  if params is None and (
106
- self.dyn_loss is None
107
- or self.observations is None
108
- or self.initial_condition is None
111
+ dyn_loss is None or observations is None or initial_condition is None
109
112
  ):
110
113
  raise ValueError(
111
114
  "params cannot be None since at least one loss "
112
115
  "term has an undefined derivative key Params PyTree"
113
116
  )
114
- if self.dyn_loss is None:
117
+ if dyn_loss is None:
115
118
  if params is None:
116
119
  raise ValueError("self.dyn_loss is None, hence params should be passed")
117
120
  self.dyn_loss = _get_masked_parameters("nn_params", params)
118
- if self.observations is None:
121
+ else:
122
+ self.dyn_loss = dyn_loss
123
+
124
+ if observations is None:
119
125
  if params is None:
120
126
  raise ValueError(
121
127
  "self.observations is None, hence params should be passed"
122
128
  )
123
129
  self.observations = _get_masked_parameters("nn_params", params)
124
- if self.initial_condition is None:
130
+ else:
131
+ self.observations = observations
132
+
133
+ if initial_condition is None:
125
134
  if params is None:
126
135
  raise ValueError(
127
136
  "self.initial_condition is None, hence params should be passed"
128
137
  )
129
138
  self.initial_condition = _get_masked_parameters("nn_params", params)
139
+ else:
140
+ self.initial_condition = initial_condition
130
141
 
131
142
  @classmethod
132
143
  def from_str(
@@ -216,36 +227,56 @@ class DerivativeKeysPDEStatio(eqx.Module):
216
227
  content of `Params.eq_params`.
217
228
  """
218
229
 
219
- dyn_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
220
- observations: Params[bool] | None = eqx.field(kw_only=True, default=None)
221
- boundary_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
222
- norm_loss: Params[bool] | None = eqx.field(kw_only=True, default=None)
230
+ dyn_loss: Params[bool] = eqx.field(kw_only=True, default=None)
231
+ observations: Params[bool] = eqx.field(kw_only=True, default=None)
232
+ boundary_loss: Params[bool] = eqx.field(kw_only=True, default=None)
233
+ norm_loss: Params[bool] = eqx.field(kw_only=True, default=None)
223
234
 
224
235
  params: InitVar[Params[Array] | None] = eqx.field(kw_only=True, default=None)
225
236
 
226
- def __post_init__(self, params: Params[Array] | None = None):
227
- if self.dyn_loss is None:
237
+ def __init__(
238
+ self,
239
+ *,
240
+ dyn_loss: Params[bool] | None = None,
241
+ observations: Params[bool] | None = None,
242
+ boundary_loss: Params[bool] | None = None,
243
+ norm_loss: Params[bool] | None = None,
244
+ params: Params[Array] | None = None,
245
+ ):
246
+ super().__init__()
247
+ if dyn_loss is None:
228
248
  if params is None:
229
249
  raise ValueError("self.dyn_loss is None, hence params should be passed")
230
250
  self.dyn_loss = _get_masked_parameters("nn_params", params)
231
- if self.observations is None:
251
+ else:
252
+ self.dyn_loss = dyn_loss
253
+
254
+ if observations is None:
232
255
  if params is None:
233
256
  raise ValueError(
234
257
  "self.observations is None, hence params should be passed"
235
258
  )
236
259
  self.observations = _get_masked_parameters("nn_params", params)
237
- if self.boundary_loss is None:
260
+ else:
261
+ self.observations = observations
262
+
263
+ if boundary_loss is None:
238
264
  if params is None:
239
265
  raise ValueError(
240
266
  "self.boundary_loss is None, hence params should be passed"
241
267
  )
242
268
  self.boundary_loss = _get_masked_parameters("nn_params", params)
243
- if self.norm_loss is None:
269
+ else:
270
+ self.boundary_loss = boundary_loss
271
+
272
+ if norm_loss is None:
244
273
  if params is None:
245
274
  raise ValueError(
246
275
  "self.norm_loss is None, hence params should be passed"
247
276
  )
248
277
  self.norm_loss = _get_masked_parameters("nn_params", params)
278
+ else:
279
+ self.norm_loss = norm_loss
249
280
 
250
281
  @classmethod
251
282
  def from_str(
@@ -344,16 +375,33 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
344
375
  content of `Params.eq_params`.
345
376
  """
346
377
 
347
- initial_condition: Params[bool] | None = eqx.field(kw_only=True, default=None)
348
-
349
- def __post_init__(self, params: Params[Array] | None = None):
350
- super().__post_init__(params=params)
351
- if self.initial_condition is None:
378
+ initial_condition: Params[bool] = eqx.field(kw_only=True, default=None)
379
+
380
+ def __init__(
381
+ self,
382
+ *,
383
+ dyn_loss: Params[bool] | None = None,
384
+ observations: Params[bool] | None = None,
385
+ boundary_loss: Params[bool] | None = None,
386
+ norm_loss: Params[bool] | None = None,
387
+ initial_condition: Params[bool] | None = None,
388
+ params: Params[Array] | None = None,
389
+ ):
390
+ super().__init__(
391
+ dyn_loss=dyn_loss,
392
+ observations=observations,
393
+ boundary_loss=boundary_loss,
394
+ norm_loss=norm_loss,
395
+ params=params,
396
+ )
397
+ if initial_condition is None:
352
398
  if params is None:
353
399
  raise ValueError(
354
400
  "self.initial_condition is None, hence params should be passed"
355
401
  )
356
402
  self.initial_condition = _get_masked_parameters("nn_params", params)
403
+ else:
404
+ self.initial_condition = initial_condition
357
405
 
358
406
  @classmethod
359
407
  def from_str(
@@ -432,7 +480,9 @@ class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
432
480
  )
433
481
 
434
482
 
435
- def _set_derivatives(params, derivative_keys):
483
+ def _set_derivatives(
484
+ params: Params[Array], derivative_keys: Params[bool]
485
+ ) -> Params[Array]:
436
486
  """
437
487
  We construct an eqx.Module with the fields of derivative_keys, each field
438
488
  has a copy of the params with appropriate derivatives set
@@ -448,13 +498,21 @@ def _set_derivatives(params, derivative_keys):
448
498
  `Params(nn_params=True | False, eq_params={"alpha":True | False,
449
499
  "beta":True | False})`.
450
500
  """
451
- return jax.tree.map(
452
- lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
453
- params_,
454
- derivative_mask,
455
- is_leaf=lambda x: isinstance(x, eqx.Module)
456
- and not isinstance(x, Params), # do not travers nn_params, more
457
- # granularity could be imagined here, in the future
501
+
502
+ return Params(
503
+ nn_params=jax.lax.cond(
504
+ derivative_mask.nn_params,
505
+ lambda p: p,
506
+ jax.lax.stop_gradient,
507
+ params_.nn_params,
508
+ ),
509
+ eq_params=jax.tree.map(
510
+ lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
511
+ params_.eq_params,
512
+ derivative_mask.eq_params,
513
+ ),
458
514
  )
515
+ # NOTE that currently we do not travers nn_params, more
516
+ # granularity could be imagined here, in the future
459
517
 
460
518
  return _set_derivatives_(params, derivative_keys)
@@ -2,10 +2,12 @@
2
2
  Formalize the data structure for the parameters
3
3
  """
4
4
 
5
+ from dataclasses import fields
5
6
  from typing import Generic, TypeVar
6
- import jax
7
7
  import equinox as eqx
8
- from jaxtyping import Array, PyTree, Float
8
+ from jaxtyping import Array, PyTree
9
+
10
+ from jinns.utils._DictToModuleMeta import DictToModuleMeta
9
11
 
10
12
  T = TypeVar("T") # the generic type for what is in the Params PyTree because we
11
13
  # have possibly Params of Arrays, boolean, ...
@@ -19,6 +21,16 @@ T = TypeVar("T") # the generic type for what is in the Params PyTree because we
19
21
  ### see https://github.com/patrick-kidger/equinox/pull/1043/commits/f88e62ab809140334c2f987ed13eff0d80b8be13
20
22
 
21
23
 
24
+ class EqParams(metaclass=DictToModuleMeta):
25
+ """
26
+ Note that this is exposed to the user for the particular case where the
27
+ user, during its work, wants to change the equation parameters. In this
28
+ case, the user must import EqParams and call `EqParams.clear()`
29
+ """
30
+
31
+ pass
32
+
33
+
22
34
  class Params(eqx.Module, Generic[T]):
23
35
  """
24
36
  The equinox module for the parameters
@@ -28,37 +40,47 @@ class Params(eqx.Module, Generic[T]):
28
40
  nn_params : PyTree[T]
29
41
  A PyTree of the non-static part of the PINN eqx.Module, i.e., the
30
42
  parameters of the PINN
31
- eq_params : dict[str, T]
32
- A dictionary of the equation parameters. Keys are the parameter name,
33
- values are their corresponding value
43
+ eq_params : PyTree[T]
44
+ A PyTree of the equation parameters. For retrocompatibility it us
45
+ provided as a dictionary of the equation parameters where keys are the parameter names, and values are their corresponding values. Internally,
46
+ it will be transformed to a custom instance of `EqParams`.
34
47
  """
35
48
 
36
- nn_params: PyTree[T] = eqx.field(kw_only=True, default=None)
37
- eq_params: dict[str, T] = eqx.field(kw_only=True, default=None)
49
+ nn_params: PyTree[T]
50
+ eq_params: PyTree[T]
51
+
52
+ def __init__(
53
+ self,
54
+ nn_params: PyTree[T] | None = None,
55
+ eq_params: dict[str, T] | None = None,
56
+ ):
57
+ self.nn_params = nn_params
58
+ if isinstance(eq_params, dict):
59
+ self.eq_params = EqParams(eq_params, "EqParams")
60
+ else:
61
+ self.eq_params = eq_params
38
62
 
39
63
 
40
- def _update_eq_params_dict(
64
+ def update_eq_params(
41
65
  params: Params[Array],
42
- param_batch_dict: dict[str, Float[Array, " param_batch_size dim"]],
43
- ) -> Params:
66
+ eq_param_batch: PyTree[Array] | None,
67
+ ) -> Params[Array]:
44
68
  """
45
69
  Update params.eq_params with a batch of eq_params for given key(s)
46
70
  """
47
71
 
48
- # artificially "complete" `param_batch_dict` with None to match `params`
49
- # PyTree structure
50
- param_batch_dict_ = param_batch_dict | {
51
- k: None for k in set(params.eq_params.keys()) - set(param_batch_dict.keys())
52
- }
72
+ if eq_param_batch is None:
73
+ return params
53
74
 
54
- # Replace at non None leafs
75
+ param_names_to_update = tuple(f.name for f in fields(eq_param_batch))
55
76
  params = eqx.tree_at(
56
77
  lambda p: p.eq_params,
57
78
  params,
58
- jax.tree_util.tree_map(
59
- lambda p, q: q if q is not None else p,
79
+ eqx.tree_at(
80
+ lambda pt: tuple(getattr(pt, f) for f in param_names_to_update),
60
81
  params.eq_params,
61
- param_batch_dict_,
82
+ tuple(getattr(eq_param_batch, f) for f in param_names_to_update),
83
+ is_leaf=lambda x: x is None or eqx.is_inexact_array(x),
62
84
  ),
63
85
  )
64
86
 
@@ -66,7 +88,7 @@ def _update_eq_params_dict(
66
88
 
67
89
 
68
90
  def _get_vmap_in_axes_params(
69
- eq_params_batch_dict: dict[str, Array], params: Params[Array]
91
+ eq_param_batch: eqx.Module | None, params: Params[Array]
70
92
  ) -> tuple[Params[int | None] | None]:
71
93
  """
72
94
  Return the input vmap axes when there is batch(es) of parameters to vmap
@@ -77,19 +99,22 @@ def _get_vmap_in_axes_params(
77
99
  Note that we return a Params PyTree with an integer to designate the
78
100
  vmapped axis or None if there is not
79
101
  """
80
- if eq_params_batch_dict is None:
102
+ if eq_param_batch is None:
81
103
  return (None,)
82
104
  # We use pytree indexing of vmapped axes and vmap on axis
83
105
  # 0 of the eq_parameters for which we have a batch
84
106
  # this is for a fine-grained vmaping
85
107
  # scheme over the params
108
+ param_names_to_vmap = tuple(f.name for f in fields(eq_param_batch))
109
+ vmap_axes_dict = {
110
+ k.name: (0 if k.name in param_names_to_vmap else None)
111
+ for k in fields(params.eq_params)
112
+ }
113
+ eq_param_vmap_axes = type(params.eq_params)(**vmap_axes_dict)
86
114
  vmap_in_axes_params = (
87
115
  Params(
88
116
  nn_params=None,
89
- eq_params={
90
- k: (0 if k in eq_params_batch_dict.keys() else None)
91
- for k in params.eq_params.keys()
92
- },
117
+ eq_params=eq_param_vmap_axes,
93
118
  ),
94
119
  )
95
120
  return vmap_in_axes_params
jinns/solver/_solve.py CHANGED
@@ -14,7 +14,7 @@ import optax
14
14
  import jax
15
15
  from jax import jit
16
16
  import jax.numpy as jnp
17
- from jaxtyping import Float, Array, PyTree, Key
17
+ from jaxtyping import Float, Array, PyTree, PRNGKeyArray
18
18
  import equinox as eqx
19
19
  from jinns.solver._rar import init_rar, trigger_rar
20
20
  from jinns.utils._utils import _check_nan_in_pytree
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
47
47
  LossContainer,
48
48
  StoredObjectContainer,
49
49
  Float[Array, " n_iter"] | None,
50
- Key | None,
50
+ PRNGKeyArray | None,
51
51
  ]
52
52
 
53
53
 
@@ -66,7 +66,7 @@ def solve(
66
66
  obs_batch_sharding: jax.sharding.Sharding | None = None,
67
67
  verbose: bool = True,
68
68
  ahead_of_time: bool = True,
69
- key: Key = None,
69
+ key: PRNGKeyArray | None = None,
70
70
  ) -> tuple[
71
71
  Params[Array],
72
72
  Float[Array, " n_iter"],
@@ -0,0 +1,66 @@
1
+ from typing import Any
2
+ import equinox as eqx
3
+
4
+
5
+ class DictToModuleMeta(type):
6
+ """
7
+ A Metaclass based solution to handle the fact that we only
8
+ want one type to be created for EqParams.
9
+ If we were to create a new **class type** (despite same name) each time we
10
+ create a new Params object, nothing would be broadcastable in terms of jax
11
+ tree utils operations and this would be useless. The difficulty comes from
12
+ the fact that we need to instanciate from this same class at different
13
+ moments of the jinns workflow eg: parameter creation, derivative keys
14
+ creations, tracked parameter designation, etc. (ie. each time a Params
15
+ class is instanciated whatever its usage, we need the same EqParams class
16
+ to be instanciated)
17
+
18
+ This is inspired by the Singleton pattern in Python
19
+ (https://stackoverflow.com/a/10362179)
20
+
21
+ Here we need the call of a metaclass because as explained in
22
+ https://stackoverflow.com/a/45536640). To quote from the answer
23
+ Metaclasses implement how the class will behave (not the instance). So when you look at the instance creation:
24
+ `x = Foo()`
25
+ This literally "calls" the class Foo. That's why __call__ of the metaclass
26
+ is invoked before the __new__ and
27
+ __init__ methods of your class initialize the instance.
28
+ Other viewpoint: Metaclasses,as well as classes making use of those
29
+ metaclasses, are created when the lines of code containing
30
+ the class statement body is executed
31
+ """
32
+
33
+ def __init__(self, *args, **kwargs):
34
+ super(DictToModuleMeta, self).__init__(*args, **kwargs)
35
+ self._class = None
36
+
37
+ def __call__(self, d: dict[str, Any], class_name: str | None = None) -> eqx.Module:
38
+ """
39
+ Notably, once the class template is registered (after the first call to
40
+ EqParams()), all calls with different keys in `d` will fail.
41
+ """
42
+ if self._class is None and class_name is not None:
43
+ self._class = type(
44
+ class_name,
45
+ (eqx.Module,),
46
+ {"__annotations__": {k: type(v) for k, v in d.items()}},
47
+ )
48
+ try:
49
+ return self._class(**d) # type: ignore
50
+ except TypeError as _:
51
+ print(
52
+ "DictToModuleMeta has been created with the fields"
53
+ f"{tuple(k for k in self._class.__annotations__.keys())}"
54
+ f" but an instanciation is resquested with fields={tuple(k for k in d.keys())}"
55
+ " which results in an error"
56
+ )
57
+ raise ValueError
58
+
59
+ def clear(cls) -> None:
60
+ """
61
+ The current Metaclass implementation freezes the list of equation parameters inside a Python session;
62
+ only one EqParams annotation can exist at a given time. Use `EqParams.clear()` to reset.
63
+ Also useful for pytest where stuff is not complety reset after tests
64
+ Taken from https://stackoverflow.com/a/50065732
65
+ """
66
+ cls._class = None
@@ -0,0 +1,19 @@
1
+ from dataclasses import fields
2
+ from typing import Any, ItemsView
3
+ import equinox as eqx
4
+
5
+
6
+ class ItemizableModule(eqx.Module):
7
+ def items(self) -> ItemsView[str, Any]:
8
+ """
9
+ For the dataclass to be iterated like a dictionary.
10
+ Practical and retrocompatible with old code when loss components were
11
+ dictionaries
12
+
13
+ About the type hint: https://stackoverflow.com/questions/73022688/type-annotation-for-dict-items
14
+ """
15
+ return {
16
+ field.name: getattr(self, field.name)
17
+ for field in fields(self)
18
+ if getattr(self, field.name) is not None
19
+ }.items()
jinns/utils/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
1
  from ._utils import get_grid
2
+ from ._DictToModuleMeta import DictToModuleMeta
2
3
 
3
- __all__ = ["get_grid"]
4
+ __all__ = ["get_grid", "DictToModuleMeta"]