jinns-1.2.0-py3-none-any.whl → jinns-1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/nn/_pinn.py ADDED
@@ -0,0 +1,190 @@
+ """
+ Implement the abstract class for PINN architectures
+ """
+
+ from typing import Literal, Callable, Union, Any
+ from dataclasses import InitVar
+ import equinox as eqx
+ from jaxtyping import Float, Array, PyTree
+ import jax.numpy as jnp
+ from jinns.parameters._params import Params, ParamsDict
+
+
+ class PINN(eqx.Module):
+     r"""
+     Base class for PINN objects. It can be seen as a wrapper around
+     an `eqx.Module` which actually implements the NN architecture, with extra
+     arguments handling the "physics-informed" aspect.
+
+     !!! Note
+         We use the `eqx.partition` and `eqx.combine` strategy of Equinox: a
+         `filter_spec` is applied on the PyTree and splits it into two PyTrees with
+         the same structure: a static one (invisible to JAX transforms such as JIT,
+         grad, etc.) and a dynamic one. By convention, anything not static is
+         considered a parameter in Jinns.
+
+     For compatibility with jinns, we require that a `PINN` architecture:
+
+     1) has an eqx.Module (`eqx_network`) InitVar passed to __post_init__
+     representing the network architecture.
+     2) calls `eqx.partition` in __post_init__ in order to store the
+     static part of the model and the initial parameters.
+     3) has an `eq_type` argument, used for handling internal operations in
+     jinns.
+     4) has a `slice_solution` argument. It is a `jnp.s\_` object which
+     indicates which axis of the PINN output is dedicated to the actual equation
+     solution. The default None means that slice_solution is the whole PINN output.
+     For example, this argument is useful when the PINN is also used to output
+     equation parameters. Note that it must be a slice and not an integer (a
+     preprocessing of the user provided argument takes care of it).
+
+     Parameters
+     ----------
+     slice_solution : slice
+         Default is None. A jnp.s\_ object which indicates which axis of the PINN output is
+         dedicated to the actual equation solution. The default None
+         means that slice_solution covers the whole PINN output. This argument is useful
+         when the PINN is also used to output equation parameters, for example.
+         Note that it must be a slice and not an integer (a preprocessing of the
+         user provided argument takes care of it).
+     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+         A string with three possibilities.
+         "ODE": the PINN is called with one input `t`.
+         "statio_PDE": the PINN is called with one input `x`, `x`
+         can be high dimensional.
+         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+         can be high dimensional.
+         **Note**: the input dimension as given in eqx_list has to match the sum
+         of the dimension of `t` + the dimension of `x`, or the output dimension
+         of the `input_transform` function.
+     input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+         A function that will be called before entering the PINN. Its output(s)
+         must match the PINN inputs (except for the parameters).
+         Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+         and the parameters. Default is no operation.
+     output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+         A function whose arguments are the same input as the PINN, the PINN
+         output and the parameters. This function will be called after exiting the PINN.
+         Default is no operation.
+     eqx_network : eqx.Module
+         The actual neural network instantiated as an eqx.Module.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
+         a trainable parameter. Quoting from the Equinox documentation:
+         a PyTree whose structure should be a prefix of the structure of pytree.
+         Each of its leaves should either be 1) True, in which case the leaf or
+         subtree is kept; 2) False, in which case the leaf or subtree is
+         replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
+
+
+     Raises
+     ------
+     RuntimeError
+         If the parameter value for eq_type is not in `["ODE", "statio_PDE",
+         "nonstatio_PDE"]`.
+     """
+
+     slice_solution: slice = eqx.field(static=True, kw_only=True, default=None)
+     eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
+         static=True, kw_only=True
+     )
+     input_transform: Callable[
+         [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+     ] = eqx.field(static=True, kw_only=True, default=None)
+     output_transform: Callable[
+         [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+         Float[Array, "output_dim"],
+     ] = eqx.field(static=True, kw_only=True, default=None)
+
+     eqx_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
+     filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
+         static=True, kw_only=True, default=eqx.is_inexact_array
+     )
+
+     init_params: PyTree = eqx.field(init=False)
+     static: PyTree = eqx.field(init=False, static=True)
+
+     def __post_init__(self, eqx_network):
+
+         if self.eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
+             raise RuntimeError("Wrong parameter value for eq_type")
+         # saving the static part of the model and initial parameters
+
+         if self.filter_spec is None:
+             self.filter_spec = eqx.is_inexact_array
+
+         self.init_params, self.static = eqx.partition(eqx_network, self.filter_spec)
+
+         if self.input_transform is None:
+             self.input_transform = lambda _in, _params: _in
+
+         if self.output_transform is None:
+             self.output_transform = lambda _in_pinn, _out_pinn, _params: _out_pinn
+
+         if self.slice_solution is None:
+             self.slice_solution = jnp.s_[:]
+
+         if isinstance(self.slice_solution, int):
+             # rewrite it as a slice to ensure that the axis does not disappear
+             # when indexing
+             self.slice_solution = jnp.s_[self.slice_solution : self.slice_solution + 1]
+
+     def eval(self, network, inputs, *args, **kwargs):
+         """How to call your Equinox module `network`. The purpose of this method
+         is to give more flexibility: users should re-implement `eval`
+         when inheriting from `PINN` if they need more control over how the
+         network is evaluated.
+
+         Defaults to `network.__call__(inputs)` but it could be more refined, *e.g.* `network.anymethod(inputs)`.
+
+         Parameters
+         ----------
+         network : eqx.Module
+             Your neural network with the parameters set, usually returned by
+             `eqx.combine(self.static, current_params)`.
+         inputs : Array
+             The inputs, possibly transformed by `self.input_transform` if
+             specified by the user.
+
+         Returns
+         -------
+         Array
+             The network output
+         """
+
+         return network(inputs)
+
+     def __call__(
+         self,
+         inputs: Float[Array, "input_dim"],
+         params: Params | ParamsDict | PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, "output_dim"]:
+         """
+         A proper __call__ implementation performs an eqx.combine here with
+         `params` and `self.static` to recreate the callable eqx.Module
+         architecture. The rest of the content of this function depends on
+         the network.
+         """
+
+         if len(inputs.shape) == 0:
+             # This can happen often when the user directly provides some
+             # collocation points (e.g. for plotting, without using
+             # DataGenerators)
+             inputs = inputs[None]
+
+         try:
+             model = eqx.combine(params.nn_params, self.static)
+         except (KeyError, AttributeError, TypeError) as e:  # give more flexibility
+             model = eqx.combine(params, self.static)
+
+         # evaluate the model
+         res = self.eval(model, self.input_transform(inputs, params), *args, **kwargs)
+
+         res = self.output_transform(inputs, res.squeeze(), params)
+
+         # force (1,) output for non vectorial solution (consistency)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
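
As an editorial aid, here is a minimal sketch of how the partition/combine contract described above is typically exercised. It relies only on the import paths and the `Params(nn_params=..., eq_params=...)` fields visible elsewhere in this diff; the subclass name, layer sizes and random key are illustrative.

import equinox as eqx
import jax
import jax.numpy as jnp

from jinns.nn._pinn import PINN
from jinns.parameters._params import Params


# Hypothetical subclass: wrap a plain eqx.nn.MLP and rely on the base class
# doing eqx.partition in __post_init__ and eqx.combine in __call__.
class MyPINN(PINN):
    def eval(self, network, inputs, *args, **kwargs):
        # Default behaviour, written out to show the hook users can override.
        return network(inputs)


key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=3, key=key)

u = MyPINN(eqx_network=mlp, eq_type="nonstatio_PDE")  # partition happens here
params = Params(nn_params=u.init_params, eq_params={})

t_x = jnp.array([0.1, 0.5])  # (t, x) concatenated, matching in_size=2
out = u(t_x, params)         # combine + eval happen here; output has shape (1,)

Overriding `eval` is optional; it is shown only to illustrate the hook the base class provides.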
jinns/nn/_ppinn.py ADDED
@@ -0,0 +1,202 @@
+ """
+ Implement the parallel PINN (PPINN) architecture with MLP subnetworks
+ """
+
+ from typing import Callable, Literal, Self
+ from dataclasses import InitVar
+ import jax
+ import jax.numpy as jnp
+ import equinox as eqx
+
+ from jaxtyping import Array, Key, PyTree, Float
+
+ from jinns.parameters._params import Params, ParamsDict
+ from jinns.nn._pinn import PINN
+ from jinns.nn._mlp import MLP
+
+
+ class PPINN_MLP(PINN):
+     r"""
+     A PPINN MLP (Parallel PINN with MLPs) object which mimics the PFNN architecture from
+     DeepXDE. This is in fact a PINN MLP that encompasses several PINN MLPs internally.
+
+     Parameters
+     ----------
+     slice_solution : slice
+         A jnp.s\_ object which indicates which axis of the PPINN output is
+         dedicated to the actual equation solution. The default None
+         means that slice_solution covers the whole PPINN output. This argument is useful
+         when the PPINN is also used to output equation parameters, for example.
+         Note that it must be a slice and not an integer (a preprocessing of the
+         user provided argument takes care of it).
+     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+         A string with three possibilities.
+         "ODE": the PPINN is called with one input `t`.
+         "statio_PDE": the PPINN is called with one input `x`, `x`
+         can be high dimensional.
+         "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
+         can be high dimensional.
+         **Note**: the input dimension as given in eqx_list has to match the sum
+         of the dimension of `t` + the dimension of `x`, or the output dimension
+         of the `input_transform` function.
+     input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+         A function that will be called before entering the PPINN. Its output(s)
+         must match the PPINN inputs (except for the parameters).
+         Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
+         and the parameters. Default is no operation.
+     output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+         A function whose arguments are the same input as the PPINN, the PPINN
+         output and the parameters. This function will be called after exiting
+         the PPINN.
+         Default is no operation.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
+         a trainable parameter. Quoting from the Equinox documentation:
+         a PyTree whose structure should be a prefix of the structure of pytree.
+         Each of its leaves should either be 1) True, in which case the leaf or
+         subtree is kept; 2) False, in which case the leaf or subtree is
+         replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
+     eqx_network_list
+         A list of eqx.nn.MLP objects with the same input
+         dimensions. They represent the parallel subnetworks of the PPINN MLP.
+         Their respective outputs are concatenated.
+     """
+
+     eqx_network_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)
+
+     def __post_init__(self, eqx_network, eqx_network_list):
+         super().__post_init__(
+             eqx_network=eqx_network_list[0],  # this is not used since it is
+             # overwritten just below
+         )
+         self.init_params, self.static = (), ()
+         for eqx_network_ in eqx_network_list:
+             params, static = eqx.partition(eqx_network_, self.filter_spec)
+             self.init_params = self.init_params + (params,)
+             self.static = self.static + (static,)
+
+     def __call__(
+         self,
+         inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
+         params: PyTree,
+     ) -> Float[Array, "output_dim"]:
+         """
+         Evaluate the PPINN on some inputs with some params.
+         """
+         if len(inputs.shape) == 0:
+             # This can happen often when the user directly provides some
+             # collocation points (e.g. for plotting, without using
+             # DataGenerators)
+             inputs = inputs[None]
+         transformed_inputs = self.input_transform(inputs, params)
+
+         outs = []
+
+         try:
+             for params_, static in zip(params.nn_params, self.static):
+                 model = eqx.combine(params_, static)
+                 outs += [model(transformed_inputs)]
+         except (KeyError, AttributeError, TypeError) as e:
+             for params_, static in zip(params, self.static):
+                 model = eqx.combine(params_, static)
+                 outs += [model(transformed_inputs)]
+         # Note that below is then a global output transform
+         res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)
+
+         # force (1,) output for non vectorial solution (consistency)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+     @classmethod
+     def create(
+         cls,
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eqx_network_list: list[eqx.nn.MLP] = None,
+         key: Key = None,
+         eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]] = None,
+         input_transform: Callable[
+             [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+         ] = None,
+         output_transform: Callable[
+             [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+             Float[Array, "output_dim"],
+         ] = None,
+         slice_solution: slice = None,
+     ) -> tuple[Self, PyTree]:
+         r"""
+         Utility function to create a parallel PINN neural network for Jinns.
+
+         Parameters
+         ----------
+         eq_type
+             A string with three possibilities.
+             "ODE": the PPINN MLP is called with one input `t`.
+             "statio_PDE": the PPINN MLP is called with one input `x`, `x`
+             can be high dimensional.
+             "nonstatio_PDE": the PPINN MLP is called with two inputs `t` and `x`, `x`
+             can be high dimensional.
+             **Note**: the input dimension as given in eqx_list has to match the sum
+             of the dimension of `t` + the dimension of `x`, or the output dimension
+             of the `input_transform` function.
+         eqx_network_list
+             Default is None. A list of eqx.nn.MLP objects with the same input
+             dimensions. They represent the parallel subnetworks of the PPINN MLP.
+             Their respective outputs are concatenated.
+         key
+             Default is None. Must be provided with `eqx_list_list` if
+             `eqx_network_list` is not provided. A JAX random key that will be used
+             to initialize the network parameters.
+         eqx_list_list
+             Default is None. Must be provided if `eqx_network_list` is not
+             provided. A list of `eqx_list` (see `PINN_MLP.create()`). The input dimension must be the
+             same for each sub-`eqx_list`. The parallel subnetworks can otherwise be
+             different. Their respective outputs are concatenated.
+         input_transform
+             A function that will be called before entering the PPINN MLP. Its output(s)
+             must match the PPINN MLP inputs (except for the parameters).
+             Its inputs are the PPINN MLP inputs (`t` and/or `x` concatenated together)
+             and the parameters. Default is no operation.
+         output_transform
+             This function will be called after exiting
+             the PPINN MLP, i.e., on the concatenated outputs of all parallel networks.
+             Default is no operation.
+         slice_solution
+             A jnp.s\_ object which indicates which axis of the PPINN MLP output is
+             dedicated to the actual equation solution. The default None
+             means that slice_solution covers the whole PPINN MLP output. This argument is
+             useful when the PPINN MLP is also used to output equation parameters, for
+             example. Note that it must be a slice and not an integer (a
+             preprocessing of the user provided argument takes care of it).
+
+
+         Returns
+         -------
+         ppinn
+             A PPINN MLP instance
+         ppinn.init_params
+             An initial set of parameters for the PPINN MLP
+
+         """
+
+         if eqx_network_list is None:
+             if eqx_list_list is None or key is None:
+                 raise ValueError(
+                     "If eqx_network_list is None, then key and eqx_list_list"
+                     " must be provided"
+                 )
+
+             eqx_network_list = []
+             for eqx_list in eqx_list_list:
+                 key, subkey = jax.random.split(key, 2)
+                 eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
+
+         ppinn = cls(
+             eqx_network=None,
+             eqx_network_list=eqx_network_list,
+             slice_solution=slice_solution,
+             eq_type=eq_type,
+             input_transform=input_transform,
+             output_transform=output_transform,
+         )
+         return ppinn, ppinn.init_params
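
A sketch of the intended `PPINN_MLP.create` usage, assuming `eqx_list` entries follow the `tuple[Callable, int, int] | Callable` format of the type hints above (the authoritative format is documented in `PINN_MLP.create`, which is not shown in this diff); the branch sizes are illustrative.

import jax
import jax.numpy as jnp
import equinox as eqx

from jinns.nn._ppinn import PPINN_MLP
from jinns.parameters._params import Params

key = jax.random.PRNGKey(0)

# Two parallel branches with the same input dimension, (t, x) -> one output
# each; layer entries are illustrative and mirror the type hints above.
branch = (
    (eqx.nn.Linear, 2, 32),
    jax.nn.tanh,
    (eqx.nn.Linear, 32, 32),
    jax.nn.tanh,
    (eqx.nn.Linear, 32, 1),
)

u, init_nn_params = PPINN_MLP.create(
    eq_type="nonstatio_PDE",
    key=key,
    eqx_list_list=[branch, branch],
)
params = Params(nn_params=init_nn_params, eq_params={})
out = u(jnp.array([0.0, 0.5]), params)  # concatenated branch outputs, shape (2,)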
@@ -7,14 +7,16 @@ import pickle
  import jax
  import equinox as eqx
 
- from jinns.utils._pinn import create_PINN, PINN
- from jinns.utils._spinn import create_SPINN, SPINN
- from jinns.utils._hyperpinn import create_HYPERPINN, HYPERPINN
+ from jinns.nn._pinn import PINN
+ from jinns.nn._spinn import SPINN
+ from jinns.nn._mlp import PINN_MLP
+ from jinns.nn._spinn_mlp import SPINN_MLP
+ from jinns.nn._hyperpinn import HyperPINN
  from jinns.parameters._params import Params, ParamsDict
 
 
  def function_to_string(
-     eqx_list: tuple[tuple[Callable, int, int] | Callable, ...]
+     eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
  ) -> tuple[tuple[str, int, int] | str, ...]:
      """
      We need this transformation for eqx_list to be pickled
@@ -40,7 +42,7 @@ def function_to_string(
 
 
  def string_to_function(
-     eqx_list_with_string: tuple[tuple[str, int, int] | str, ...]
+     eqx_list_with_string: tuple[tuple[str, int, int] | str, ...],
  ) -> tuple[tuple[Callable, int, int] | Callable, ...]:
      """
      We need this transformation for eqx_list at the loading ("unpickling")
@@ -84,7 +86,7 @@ def string_to_function(
 
  def save_pinn(
      filename: str,
-     u: PINN | HYPERPINN | SPINN,
+     u: PINN | HyperPINN | SPINN,
      params: Params | ParamsDict,
      kwargs_creation,
  ):
@@ -103,7 +105,7 @@ def save_pinn(
      tree_serialise_leaves`).
 
      Equation parameters are saved apart because the initial type of attribute
-     `params` in PINN / HYPERPINN / SPINN is not `Params` nor `ParamsDict`
+     `params` in PINN / HyperPINN / SPINN is not `Params` nor `ParamsDict`
      but `PyTree` as inherited from `eqx.partition`.
      Therefore, if we want to ensure a proper serialization/deserialization:
      - we cannot save a `Params` object at this
@@ -111,7 +113,7 @@ def save_pinn(
      (type `PyTree`) and `Params.eq_params` (type `dict`).
      - in the case of a `ParamsDict` we cannot save `ParamsDict.nn_params` at
      the attribute field `params` because it is not a `PyTree` (as expected in
-     the PINN / HYPERPINN / SPINN signature) but it is still a dictionary.
+     the PINN / HyperPINN / SPINN signature) but it is still a dictionary.
 
      Parameters
      ----------
@@ -126,18 +128,18 @@ def save_pinn(
          the layers list, O/PDE type, etc.
      """
      if isinstance(params, Params):
-         if isinstance(u, HYPERPINN):
-             u = eqx.tree_at(lambda m: m.params_hyper, u, params)
+         if isinstance(u, HyperPINN):
+             u = eqx.tree_at(lambda m: m.init_params_hyper, u, params)
          elif isinstance(u, (PINN, SPINN)):
-             u = eqx.tree_at(lambda m: m.params, u, params)
+             u = eqx.tree_at(lambda m: m.init_params, u, params)
          eqx.tree_serialise_leaves(filename + "-module.eqx", u)
 
      elif isinstance(params, ParamsDict):
          for key, params_ in params.nn_params.items():
-             if isinstance(u, HYPERPINN):
-                 u = eqx.tree_at(lambda m: m.params_hyper, u, params_)
+             if isinstance(u, HyperPINN):
+                 u = eqx.tree_at(lambda m: m.init_params_hyper, u, params_)
              elif isinstance(u, (PINN, SPINN)):
-                 u = eqx.tree_at(lambda m: m.params, u, params_)
+                 u = eqx.tree_at(lambda m: m.init_params, u, params_)
              eqx.tree_serialise_leaves(filename + f"-module_{key}.eqx", u)
 
      else:
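
The `eqx.tree_at` calls above swap the module's `init_params` leaf for the parameters being saved (out of place) right before `eqx.tree_serialise_leaves`. A standalone sketch of that pattern with a toy module (the field name and file name are illustrative):

import equinox as eqx
import jax.numpy as jnp


class Toy(eqx.Module):
    init_params: jnp.ndarray  # stands in for the partitioned parameter PyTree


u = Toy(init_params=jnp.zeros(3))
trained = jnp.arange(3.0)

# Out-of-place update: returns a new module whose init_params leaf is `trained`,
# which is then what eqx.tree_serialise_leaves writes to disk.
u = eqx.tree_at(lambda m: m.init_params, u, trained)
eqx.tree_serialise_leaves("toy-module.eqx", u)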
@@ -167,7 +169,7 @@ def save_pinn(
 
  def load_pinn(
      filename: str,
-     type_: Literal["pinn", "hyperpinn", "spinn"],
+     type_: Literal["pinn_mlp", "hyperpinn", "spinn_mlp"],
      key_list_for_paramsdict: list[str] = None,
  ) -> tuple[eqx.Module, Params | ParamsDict]:
      """
@@ -187,7 +189,7 @@ def load_pinn(
      filename
          Filename (prefix) without extension.
      type_
-         Type of model to load. Must be in ["pinn", "hyperpinn", "spinn"].
+         Type of model to load. Must be in ["pinn_mlp", "hyperpinn", "spinn_mlp"].
      key_list_for_paramsdict
          Pass the name of the keys of the dictionary `ParamsDict.nn_params`. Default is None. In this case, we expect to retrieve a ParamsDict.
 
@@ -207,18 +209,22 @@ def load_pinn(
          eq_params_reloaded = {}
          print("No pickle file for equation parameters found!")
      kwargs_reloaded["eqx_list"] = string_to_function(kwargs_reloaded["eqx_list"])
-     if type_ == "pinn":
+     if type_ == "pinn_mlp":
          # next line creates a shallow model, the jax arrays are just shapes and
          # not populated, this just recreates the correct pytree structure
-         u_reloaded_shallow, _ = eqx.filter_eval_shape(create_PINN, **kwargs_reloaded)
-     elif type_ == "spinn":
-         u_reloaded_shallow, _ = eqx.filter_eval_shape(create_SPINN, **kwargs_reloaded)
+         u_reloaded_shallow, _ = eqx.filter_eval_shape(
+             PINN_MLP.create, **kwargs_reloaded
+         )
+     elif type_ == "spinn_mlp":
+         u_reloaded_shallow, _ = eqx.filter_eval_shape(
+             SPINN_MLP.create, **kwargs_reloaded
+         )
      elif type_ == "hyperpinn":
          kwargs_reloaded["eqx_list_hyper"] = string_to_function(
              kwargs_reloaded["eqx_list_hyper"]
          )
          u_reloaded_shallow, _ = eqx.filter_eval_shape(
-             create_HYPERPINN, **kwargs_reloaded
+             HyperPINN.create, **kwargs_reloaded
          )
      else:
          raise ValueError(f"{type_} is not valid")
@@ -228,13 +234,23 @@ def load_pinn(
          u_reloaded = eqx.tree_deserialise_leaves(
              filename + "-module.eqx", u_reloaded_shallow
          )
-         params = Params(nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded)
+         if isinstance(u_reloaded, HyperPINN):
+             params = Params(
+                 nn_params=u_reloaded.init_params_hyper, eq_params=eq_params_reloaded
+             )
+         elif isinstance(u_reloaded, (PINN, SPINN)):
+             params = Params(
+                 nn_params=u_reloaded.init_params, eq_params=eq_params_reloaded
+             )
      else:
          nn_params_dict = {}
          for key in key_list_for_paramsdict:
              u_reloaded = eqx.tree_deserialise_leaves(
                  filename + f"-module_{key}.eqx", u_reloaded_shallow
              )
-             nn_params_dict[key] = u_reloaded.init_params
+             if isinstance(u_reloaded, HyperPINN):
+                 nn_params_dict[key] = u_reloaded.init_params_hyper
+             elif isinstance(u_reloaded, (PINN, SPINN)):
+                 nn_params_dict[key] = u_reloaded.init_params
          params = ParamsDict(nn_params=nn_params_dict, eq_params=eq_params_reloaded)
      return u_reloaded, params
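
A sketch of the save/load round trip after the renaming, assuming `save_pinn` and `load_pinn` remain publicly importable (the hunks above do not show which module hosts them) and that `PINN_MLP.create` mirrors the `PPINN_MLP.create` signature shown earlier; the file prefix, layer sizes and equation parameter name are illustrative.

import jax
import equinox as eqx

# Hypothetical import path: adjust to wherever jinns exposes these helpers.
from jinns import save_pinn, load_pinn
from jinns.nn._mlp import PINN_MLP
from jinns.parameters._params import Params

key = jax.random.PRNGKey(0)
eqx_list = (
    (eqx.nn.Linear, 2, 32),
    jax.nn.tanh,
    (eqx.nn.Linear, 32, 1),
)
# Keep the creation kwargs around: save_pinn stores them (eqx_list passes through
# function_to_string) so that load_pinn can rebuild the pytree structure.
kwargs_creation = {"key": key, "eqx_list": eqx_list, "eq_type": "nonstatio_PDE"}
u, init_nn_params = PINN_MLP.create(**kwargs_creation)
params = Params(nn_params=init_nn_params, eq_params={"nu": 0.01})  # "nu" is illustrative

save_pinn("my_pinn", u, params, kwargs_creation)
u_reloaded, params_reloaded = load_pinn("my_pinn", type_="pinn_mlp")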
jinns/nn/_spinn.py ADDED
@@ -0,0 +1,106 @@
+ from typing import Union, Callable, Any
+ from dataclasses import InitVar
+ from jaxtyping import PyTree, Float, Array
+ import jax
+ import jax.numpy as jnp
+ import equinox as eqx
+
+ from jinns.parameters._params import Params, ParamsDict
+
+
+ class SPINN(eqx.Module):
+     """
+     A Separable PINN object compatible with the rest of jinns.
+
+     Parameters
+     ----------
+     d : int
+         The number of dimensions to treat separately, including time `t` if
+         used for non-stationary equations.
+     r : int
+         An integer. The dimension of the embedding.
+     eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
+         A string with three possibilities.
+         "ODE": the PINN is called with one input `t`.
+         "statio_PDE": the PINN is called with one input `x`, `x`
+         can be high dimensional.
+         "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
+         can be high dimensional.
+         **Note**: the input dimension as given in eqx_list has to match the sum
+         of the dimension of `t` + the dimension of `x`.
+     m : int
+         The output dimension of the neural network. According to
+         the SPINN article, a total embedding dimension of `r*m` is defined. We
+         then sum groups of `r` embedding dimensions to compute each output.
+         Default is 1.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells Jinns what to consider as
+         a trainable parameter. Quoting from the Equinox documentation:
+         a PyTree whose structure should be a prefix of the structure of pytree.
+         Each of its leaves should either be 1) True, in which case the leaf or
+         subtree is kept; 2) False, in which case the leaf or subtree is
+         replaced with replace; 3) a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
+     eqx_spinn_network : eqx.Module
+         The actual neural network instantiated as an eqx.Module. It should be
+         an architecture taking `d` inputs and returning `d` times an embedding
+         of dimension `r`*`m`. See the Separable PINN paper for more details.
+
+     """
+
+     d: int = eqx.field(static=True, kw_only=True)
+     r: int = eqx.field(static=True, kw_only=True)
+     eq_type: str = eqx.field(static=True, kw_only=True)
+     m: int = eqx.field(static=True, kw_only=True, default=1)
+
+     filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = eqx.field(
+         static=True, kw_only=True, default=None
+     )
+     eqx_spinn_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
+
+     init_params: PyTree = eqx.field(init=False)
+     static: PyTree = eqx.field(init=False, static=True)
+
+     def __post_init__(self, eqx_spinn_network):
+
+         if self.filter_spec is None:
+             self.filter_spec = eqx.is_inexact_array
+
+         self.init_params, self.static = eqx.partition(
+             eqx_spinn_network, self.filter_spec
+         )
+
+     def __call__(
+         self,
+         t_x: Float[Array, "batch_size 1+dim"],
+         params: Params | ParamsDict | PyTree,
+     ) -> Float[Array, "output_dim"]:
+         """
+         Evaluate the SPINN on some inputs with some params.
+         """
+         try:
+             spinn = eqx.combine(params.nn_params, self.static)
+         except (KeyError, AttributeError, TypeError) as e:
+             spinn = eqx.combine(params, self.static)
+         v_model = jax.vmap(spinn)
+         res = v_model(t_x)
+
+         a = ", ".join([f"{chr(97 + d)}z" for d in range(res.shape[1])])
+         b = "".join([f"{chr(97 + d)}" for d in range(res.shape[1])])
+         res = jnp.stack(
+             [
+                 jnp.einsum(
+                     f"{a} -> {b}",
+                     *(
+                         res[:, d, m * self.r : (m + 1) * self.r]
+                         for d in range(res.shape[1])
+                     ),
+                 )
+                 for m in range(self.m)
+             ],
+             axis=-1,
+         )  # compute each output dimension
+
+         # force (1,) output for non vectorial solution (consistency)
+         if len(res.shape) == self.d:
+             return jnp.expand_dims(res, axis=-1)
+         return res
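
To make the einsum above concrete, here is a standalone sketch of the contraction it performs; the shapes are illustrative and mirror what `v_model(t_x)` returns.

import jax
import jax.numpy as jnp

# Illustrative shapes: N collocation points per separated axis, d axes,
# embedding size r, m outputs.
N, d, r, m = 4, 3, 8, 2
key = jax.random.PRNGKey(0)
res = jax.random.normal(key, (N, d, r * m))

a = ", ".join([f"{chr(97 + i)}z" for i in range(d)])  # "az, bz, cz"
b = "".join([f"{chr(97 + i)}" for i in range(d)])     # "abc"
out = jnp.stack(
    [
        jnp.einsum(
            f"{a} -> {b}",
            *(res[:, i, k * r : (k + 1) * r] for i in range(d)),
        )
        for k in range(m)
    ],
    axis=-1,
)
assert out.shape == (N, N, N, m)
# Each entry is a sum over the embedding index z of a product of the d per-axis
# factors: the separable low-rank form u(x_1, ..., x_d) = sum_z prod_i f_{i,z}(x_i).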