jinns 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/nn/_hyperpinn.py ADDED
@@ -0,0 +1,397 @@
+ """
+ Implements utility functions to create HyperPINNs
+ https://arxiv.org/pdf/2111.01008.pdf
+ """
+
+ import warnings
+ from dataclasses import InitVar
+ from typing import Callable, Literal, Self, Union, Any
+ from math import prod
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Float, PyTree, Key
+ import equinox as eqx
+ import numpy as onp
+
+ from jinns.nn._pinn import PINN
+ from jinns.nn._mlp import MLP
+ from jinns.parameters._params import Params, ParamsDict
+
+
+ def _get_param_nb(
+     params: Params,
+ ) -> tuple[int, list]:
+     """Returns the number of parameters in a Params object as well as
+     the cumulative sum obtained when parsing the object.
+
+     Parameters
+     ----------
+     params :
+         A Params object.
+     """
+     dim_prod_all_arrays = [
+         prod(a.shape)
+         for a in jax.tree.leaves(params, is_leaf=lambda x: isinstance(x, jnp.ndarray))
+     ]
+     return (
+         sum(dim_prod_all_arrays),
+         onp.cumsum(dim_prod_all_arrays).tolist(),
+     )
+
+
+ class HyperPINN(PINN):
+     r"""
+     A HyperPINN object compatible with the rest of jinns.
+     Composed of a PINN and a hypernetwork. A HyperPINN is typically
+     instantiated with `create`.
+
+     Parameters
+     ----------
+     hyperparams : list = eqx.field(static=True)
+         A list of keys from Params.eq_params that will be considered as
+         hyperparameters for metamodeling.
+     hypernet_input_size : int
+         An integer. The input size of the MLP used for the hypernetwork. Must
+         be equal to the size of the flattened concatenation of the arrays of
+         parameters designated by the `hyperparams` argument.
+     slice_solution : slice
+         A jnp.s\_ object which indicates which axis of the PINN output is
+         dedicated to the actual equation solution. Default None
+         means that slice_solution = the whole PINN output. This argument is
+         useful when the PINN is also used to output equation parameters, for
+         example. Note that it must be a slice and not an integer (a
+         preprocessing of the user-provided argument takes care of it).
+     eq_type : str
+         A string with three possibilities.
+         "ODE": the HyperPINN is called with one input `t`.
+         "statio_PDE": the HyperPINN is called with one input `x`; `x`
+         can be high dimensional.
+         "nonstatio_PDE": the HyperPINN is called with two inputs `t` and `x`;
+         `x` can be high dimensional.
+         **Note**: the input dimension as given in eqx_list has to match the
+         dimension of `t` plus the dimension of `x`, or the output dimension
+         of the `input_transform` function.
+     input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
+         A function that will be called before entering the PINN. Its output(s)
+         must match the PINN inputs (except for the parameters).
+         Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
+         and the parameters. Default is no operation.
+     output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
+         A function whose arguments are the same inputs as the PINN, the PINN
+         output and the parameters. This function will be called after exiting
+         the PINN. Default is no operation.
+     mlp : eqx.Module
+         The actual neural network instantiated as an eqx.Module.
+     hyper_mlp : eqx.Module
+         The actual hypernetwork instantiated as an eqx.Module.
+     filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+         Default is `eqx.is_inexact_array`. This tells jinns what to consider
+         as a trainable parameter. Quoting from the equinox documentation:
+         a PyTree whose structure should be a prefix of the structure of pytree.
+         Each of its leaves should either be 1) True, in which case the leaf or
+         subtree is kept; 2) False, in which case the leaf or subtree is
+         replaced with replace; 3) a callable Leaf -> bool, in which case this
+         is evaluated on the leaf or mapped over the subtree, and the leaf kept
+         or replaced as appropriate.
+     """
+
97
+ hyperparams: list[str] = eqx.field(static=True, kw_only=True)
98
+ hypernet_input_size: int = eqx.field(kw_only=True)
99
+
100
+ eqx_hyper_network: InitVar[eqx.Module] = eqx.field(kw_only=True)
101
+
102
+ pinn_params_sum: int = eqx.field(init=False, static=True)
103
+ pinn_params_cumsum: list = eqx.field(init=False, static=True)
104
+
105
+ init_params_hyper: PyTree = eqx.field(init=False)
106
+ static_hyper: PyTree = eqx.field(init=False, static=True)
107
+
108
+ def __post_init__(self, eqx_network, eqx_hyper_network):
109
+ super().__post_init__(
110
+ eqx_network,
111
+ )
112
+ # In addition, we store the PyTree structure of the hypernetwork as well
113
+ self.init_params_hyper, self.static_hyper = eqx.partition(
114
+ eqx_hyper_network, self.filter_spec
115
+ )
116
+ self.pinn_params_sum, self.pinn_params_cumsum = _get_param_nb(self.init_params)
117
+
118
+     def _hyper_to_pinn(self, hyper_output: Float[Array, "output_dim"]) -> PyTree:
+         """
+         Transform the output of the hypernetwork into well-formed
+         parameters for the pinn network (i.e. with the same PyTree structure
+         as `self.init_params`)
+         """
+
+         pinn_params_flat = eqx.tree_at(
+             jax.tree.leaves,  # the is_leaf=eqx.is_array argument for jax.tree.leaves
+             # is not needed in general when working
+             # with eqx.nn.Linear for example: jax.tree.leaves
+             # already returns the arrays of weights and biases only, since the
+             # other stuff (that we do not want to be returned) is marked as
+             # static (in eqx.nn.Linear), hence is not part of the leaves.
+             # Note that custom layers should then be properly designed to pass
+             # through this jax.tree.leaves.
+             self.init_params,
+             jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
+         )
+
+         return jax.tree.map(
+             lambda a, b: a.reshape(b.shape),
+             pinn_params_flat,
+             self.init_params,
+             is_leaf=lambda x: isinstance(x, jnp.ndarray),
+         )
+
+     def __call__(
+         self,
+         inputs: Float[Array, "input_dim"],
+         params: Params | ParamsDict | PyTree,
+         *args,
+         **kwargs,
+     ) -> Float[Array, "output_dim"]:
+         """
+         Evaluate the HyperPINN on some inputs with some params.
+         """
+         if len(inputs.shape) == 0:
+             # This can often happen when the user directly provides some
+             # collocation points (e.g. for plotting, without using
+             # DataGenerators)
+             inputs = inputs[None]
+
+         try:
+             hyper = eqx.combine(params.nn_params, self.static_hyper)
+         except (KeyError, AttributeError, TypeError):  # give more flexibility
+             hyper = eqx.combine(params, self.static_hyper)
+
+         eq_params_batch = jnp.concatenate(
+             [params.eq_params[k].flatten() for k in self.hyperparams], axis=0
+         )
+
+         hyper_output = hyper(eq_params_batch)
+
+         pinn_params = self._hyper_to_pinn(hyper_output)
+
+         pinn = eqx.combine(pinn_params, self.static)
+         res = self.eval(pinn, self.input_transform(inputs, params), *args, **kwargs)
+
+         res = self.output_transform(inputs, res.squeeze(), params)
+
+         # force (1,) output for non-vectorial solution (consistency)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+     @classmethod
+     def create(
+         cls,
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         hyperparams: list[str],
+         hypernet_input_size: int,
+         eqx_network: eqx.nn.MLP = None,
+         eqx_hyper_network: eqx.nn.MLP = None,
+         key: Key = None,
+         eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
+         eqx_list_hyper: tuple[tuple[Callable, int, int] | Callable, ...] = None,
+         input_transform: Callable[
+             [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+         ] = None,
+         output_transform: Callable[
+             [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+             Float[Array, "output_dim"],
+         ] = None,
+         slice_solution: slice = None,
+         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
+     ) -> tuple[Self, PyTree]:
+         r"""
+         Utility function to create a standard HyperPINN neural network with
+         the equinox library.
+
+         Parameters
+         ----------
+         eq_type
+             A string with three possibilities.
+             "ODE": the HyperPINN is called with one input `t`.
+             "statio_PDE": the HyperPINN is called with one input `x`; `x`
+             can be high dimensional.
+             "nonstatio_PDE": the HyperPINN is called with two inputs `t` and
+             `x`; `x` can be high dimensional.
+             **Note**: the input dimension as given in eqx_list has to match
+             the dimension of `t` plus the dimension of `x`, or the output
+             dimension of the `input_transform` function.
+         hyperparams
+             A list of keys from Params.eq_params that will be considered as
+             hyperparameters for metamodeling.
+         hypernet_input_size
+             An integer. The input size of the MLP used for the hypernetwork.
+             Must be equal to the size of the flattened concatenation of the
+             arrays of parameters designated by the `hyperparams` argument.
+         eqx_network
+             Default is None. An eqx.nn.MLP for the base network that will be
+             wrapped inside our HyperPINN object in order to make it easily
+             jinns-compatible.
+         eqx_hyper_network
+             Default is None. An eqx.nn.MLP for the hypernetwork that will be
+             wrapped inside our HyperPINN object in order to make it easily
+             jinns-compatible.
+         key
+             Default is None. Must be provided with `eqx_list` and
+             `eqx_list_hyper` if `eqx_network` or `eqx_hyper_network`
+             is not provided. A JAX random key that will be used to initialize
+             the network parameters.
+         eqx_list
+             Default is None. Must be provided if `eqx_network` or
+             `eqx_hyper_network` is not provided.
+             A tuple of tuples of successive equinox modules and activation
+             functions to describe the base network architecture. The inner
+             tuples must have the eqx module or activation function as first
+             item; the other items represent arguments that could be required
+             (e.g. the size of the layer).
+             The `key` argument need not be given.
+             A typical example is `eqx_list=(
+             (eqx.nn.Linear, 2, 20),
+             jax.nn.tanh,
+             (eqx.nn.Linear, 20, 20),
+             jax.nn.tanh,
+             (eqx.nn.Linear, 20, 20),
+             jax.nn.tanh,
+             (eqx.nn.Linear, 20, 1)
+             )`.
+         eqx_list_hyper
+             Same as eqx_list but for the hypernetwork. Default is None. Must
+             be provided if `eqx_network` or `eqx_hyper_network` is not
+             provided. Note that the number of inputs of the hypernetwork must
+             be equal to the size of the flattened concatenation of the arrays
+             of parameters designated by the `hyperparams` argument, and that
+             its number of outputs must be equal to the number of parameters
+             in the pinn network (both are adjusted automatically by this
+             function).
+         input_transform
+             A function that will be called before entering the PINN. Its
+             output(s) must match the PINN inputs (except for the parameters).
+             Its inputs are the PINN inputs (`t` and/or `x` concatenated
+             together) and the parameters. Default is no operation.
+         output_transform
+             A function whose arguments are the same inputs as the PINN, the
+             PINN output and the parameters. This function will be called
+             after exiting the PINN. Default is no operation.
+         slice_solution
+             A jnp.s\_ object which indicates which axis of the PINN output is
+             dedicated to the actual equation solution. Default None
+             means that slice_solution = the whole PINN output. This argument
+             is useful when the PINN is also used to output equation
+             parameters, for example. Note that it must be a slice and not an
+             integer (a preprocessing of the user-provided argument takes care
+             of it).
+         filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+             Default is None, which leads to `eqx.is_inexact_array` in the
+             class instantiation. This tells jinns what to consider as
+             a trainable parameter. Quoting from the equinox documentation:
+             a PyTree whose structure should be a prefix of the structure of
+             pytree. Each of its leaves should either be 1) True, in which
+             case the leaf or subtree is kept; 2) False, in which case the
+             leaf or subtree is replaced with replace; 3) a callable
+             Leaf -> bool, in which case this is evaluated on the leaf or
+             mapped over the subtree, and the leaf kept or replaced as
+             appropriate.
+
+         Returns
+         -------
+         hyperpinn
+             A HyperPINN instance.
+         hyperpinn.init_params_hyper
+             The initial set of parameters of the hypernetwork.
+         """
+         if eqx_network is None or eqx_hyper_network is None:
+             if eqx_list is None or key is None or eqx_list_hyper is None:
+                 raise ValueError(
+                     "If eqx_network is None or eqx_hyper_network is None, then"
+                     " key, eqx_list and eqx_list_hyper must be provided"
+                 )
+
+             ### Now we finetune the hypernetwork architecture
+
+             key, subkey = jax.random.split(key, 2)
+             # with warnings.catch_warnings():
+             #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
+             eqx_network = MLP(key=subkey, eqx_list=eqx_list)
+             # quick partitioning to get the params and deduce the correct
+             # number of neurons for the last layer of the hypernetwork
+             params_mlp, _ = eqx.partition(eqx_network, eqx.is_inexact_array)
+             pinn_params_sum, _ = _get_param_nb(params_mlp)
+             # the number of parameters of the pinn will be the number of
+             # outputs of the hypernetwork
+             if len(eqx_list_hyper[-1]) > 1:
+                 eqx_list_hyper = eqx_list_hyper[:-1] + (
+                     (eqx_list_hyper[-1][:2] + (pinn_params_sum,)),
+                 )
+             else:
+                 eqx_list_hyper = (
+                     eqx_list_hyper[:-2]
+                     + ((eqx_list_hyper[-2][:2] + (pinn_params_sum,)),)
+                     + eqx_list_hyper[-1]
+                 )
+             if len(eqx_list_hyper[0]) > 1:
+                 eqx_list_hyper = (
+                     (
+                         (eqx_list_hyper[0][0],)
+                         + (hypernet_input_size,)
+                         + (eqx_list_hyper[0][2],)
+                     ),
+                 ) + eqx_list_hyper[1:]
+             else:
+                 eqx_list_hyper = (
+                     eqx_list_hyper[0]
+                     + (
+                         (
+                             (eqx_list_hyper[1][0],)
+                             + (hypernet_input_size,)
+                             + (eqx_list_hyper[1][2],)
+                         ),
+                     )
+                     + eqx_list_hyper[2:]
+                 )
+             key, subkey = jax.random.split(key, 2)
+             # with warnings.catch_warnings():
+             #     warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
+             eqx_hyper_network = MLP(key=subkey, eqx_list=eqx_list_hyper)
+
+             ### End of finetuning the hypernetwork architecture
+
+         with warnings.catch_warnings():
+             # Catch the equinox warning raised because we mark the numbers of
+             # parameters as static while they are jnp.Arrays. This time it is
+             # correct to do so, because they are only used as indices
+             # and will never be modified
+             warnings.filterwarnings(
+                 "ignore", message="A JAX array is being set as static!"
+             )
+             hyperpinn = cls(
+                 eqx_network=eqx_network,
+                 eqx_hyper_network=eqx_hyper_network,
+                 slice_solution=slice_solution,
+                 eq_type=eq_type,
+                 input_transform=input_transform,
+                 output_transform=output_transform,
+                 hyperparams=hyperparams,
+                 hypernet_input_size=hypernet_input_size,
+                 filter_spec=filter_spec,
+             )
+         return hyperpinn, hyperpinn.init_params_hyper
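
Aside on the mechanism above: the hypernetwork emits one flat vector holding every weight of the base PINN; _get_param_nb records each leaf's size and the cumulative offsets, and _hyper_to_pinn splits the flat vector at those offsets and reshapes each chunk to the matching leaf. Below is a minimal, self-contained sketch of that round trip. It uses a toy eqx.nn.MLP as a stand-in for the partitioned PINN parameters; the names (target, flat, chunks, ...) are illustrative and not part of jinns.

    from math import prod

    import numpy as onp
    import jax
    import jax.numpy as jnp
    import equinox as eqx

    key = jax.random.PRNGKey(0)
    target = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=key)
    params, static = eqx.partition(target, eqx.is_inexact_array)

    # what _get_param_nb computes: leaf sizes and their cumulative offsets
    sizes = [prod(a.shape) for a in jax.tree.leaves(params)]
    total, offsets = sum(sizes), onp.cumsum(sizes).tolist()

    # pretend this flat vector is the hypernetwork output
    flat = jax.random.normal(key, (total,))

    # what _hyper_to_pinn does: split at the offsets, then reshape each chunk
    # to the shape of the corresponding leaf of the reference parameters
    chunks = jnp.split(flat, offsets[:-1])
    flat_tree = eqx.tree_at(jax.tree.leaves, params, chunks)
    new_params = jax.tree.map(lambda a, b: a.reshape(b.shape), flat_tree, params)

    # recombine: a network whose weights are entirely hypernetwork-generated
    pinn = eqx.combine(new_params, static)
    print(pinn(jnp.ones((2,))).shape)  # (1,)

This also shows why pinn_params_cumsum can safely be a static field: the offsets only index into the flat vector and are never updated during training.
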
jinns/nn/_mlp.py ADDED
@@ -0,0 +1,192 @@
+ """
+ Implements utility functions to create PINNs
+ """
+
+ from typing import Callable, Literal, Self, Union, Any
+ from dataclasses import InitVar
+ import jax
+ import equinox as eqx
+
+ from jaxtyping import Array, Key, PyTree, Float
+
+ from jinns.parameters._params import Params
+ from jinns.nn._pinn import PINN
+
+
+ class MLP(eqx.Module):
+     """
+     Custom MLP equinox module built from a key and an eqx_list.
+
+     Parameters
+     ----------
+     key : InitVar[Key]
+         A jax random key for the layer initializations.
+     eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
+         A tuple of tuples of successive equinox modules and activation
+         functions to describe the PINN architecture. The inner tuples must
+         have the eqx module or activation function as first item; the other
+         items represent arguments that could be required (e.g. the size of
+         the layer).
+         The `key` argument need not be given.
+         A typical example is `eqx_list=(
+         (eqx.nn.Linear, 2, 20),
+         jax.nn.tanh,
+         (eqx.nn.Linear, 20, 20),
+         jax.nn.tanh,
+         (eqx.nn.Linear, 20, 20),
+         jax.nn.tanh,
+         (eqx.nn.Linear, 20, 1)
+         )`.
+     """
+
41
+ key: InitVar[Key] = eqx.field(kw_only=True)
42
+ eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
43
+ kw_only=True
44
+ )
45
+
46
+ # NOTE that the following should NOT be declared as static otherwise the
47
+ # eqx.partition that we use in the PINN module will misbehave
48
+ layers: list[eqx.Module] = eqx.field(init=False)
49
+
50
+ def __post_init__(self, key, eqx_list):
51
+ self.layers = []
52
+ # nb_keys_required = sum(1 if len(l) > 1 else 0 for l in eqx_list)
53
+ # keys = jax.random.split(key, nb_keys_required)
54
+ # we need a global split
55
+ # before the loop to maintain strict equivalency with eqx.nn.MLP
56
+ # for debugging purpose
57
+ k = 0
58
+ for l in eqx_list:
59
+ if len(l) == 1:
60
+ self.layers.append(l[0])
61
+ else:
62
+ key, subkey = jax.random.split(key, 2) # nb_keys_required)
63
+ self.layers.append(l[0](*l[1:], key=subkey))
64
+ k += 1
65
+
66
+ def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
67
+ for layer in self.layers:
68
+ t = layer(t)
69
+ return t
70
+
71
+
72
+ class PINN_MLP(PINN):
+     """
+     An implementable PINN based on an MLP architecture
+     """
+
+     # Here we could have a more complex __call__ method that redefines the
+     # parent's __call__. But there is no need for that in the simple PINN_MLP
+
+     @classmethod
+     def create(
+         cls,
+         eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
+         eqx_network: eqx.nn.MLP = None,
+         key: Key = None,
+         eqx_list: tuple[tuple[Callable, int, int] | Callable, ...] = None,
+         input_transform: Callable[
+             [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
+         ] = None,
+         output_transform: Callable[
+             [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
+             Float[Array, "output_dim"],
+         ] = None,
+         slice_solution: slice = None,
+         filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
+     ) -> tuple[Self, PyTree]:
+         r"""
+         Instantiate a standard PINN MLP object. The actual NN is either
+         passed as an eqx.nn.MLP (`eqx_network` argument) or constructed as a
+         custom jinns.nn.MLP when `key` and `eqx_list` are provided.
+
+         Parameters
+         ----------
+         eq_type
+             A string with three possibilities.
+             "ODE": the MLP is called with one input `t`.
+             "statio_PDE": the MLP is called with one input `x`; `x`
+             can be high dimensional.
+             "nonstatio_PDE": the MLP is called with two inputs `t` and `x`;
+             `x` can be high dimensional.
+             **Note**: the input dimension as given in eqx_list has to match
+             the dimension of `t` plus the dimension of `x`, or the output
+             dimension of the `input_transform` function.
+         eqx_network
+             Default is None. An eqx.nn.MLP object that will be wrapped inside
+             our PINN_MLP object in order to make it easily jinns-compatible.
+         key
+             Default is None. Must be provided with `eqx_list` if
+             `eqx_network` is not provided. A JAX random key that will be used
+             to initialize the network parameters.
+         eqx_list
+             Default is None. Must be provided if `eqx_network`
+             is not provided. A tuple of tuples of successive equinox modules
+             and activation functions to describe the MLP architecture. The
+             inner tuples must have the eqx module or activation function as
+             first item; the other items represent arguments that could be
+             required (e.g. the size of the layer).
+
+             The `key` argument does not need to be given.
+
+             A typical example is `eqx_list = (
+             (eqx.nn.Linear, input_dim, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, 20),
+             (jax.nn.tanh,),
+             (eqx.nn.Linear, 20, output_dim)
+             )`.
+         input_transform
+             A function that will be called before entering the MLP. Its
+             output(s) must match the MLP inputs (except for the parameters).
+             Its inputs are the MLP inputs (`t` and/or `x` concatenated
+             together) and the parameters. Default is no operation.
+         output_transform
+             A function whose arguments are the same inputs as the MLP, the
+             MLP output and the parameters. This function will be called after
+             exiting the MLP. Default is no operation.
+         slice_solution
+             A jnp.s\_ object which indicates which axis of the MLP output is
+             dedicated to the actual equation solution. Default None
+             means that slice_solution = the whole MLP output. This argument
+             is useful when the MLP is also used to output equation
+             parameters, for example. Note that it must be a slice and not an
+             integer (a preprocessing of the user-provided argument takes care
+             of it).
+         filter_spec : PyTree[Union[bool, Callable[[Any], bool]]]
+             Default is None, which leads to `eqx.is_inexact_array` in the
+             class instantiation. This tells jinns what to consider as
+             a trainable parameter. Quoting from the equinox documentation:
+             a PyTree whose structure should be a prefix of the structure of
+             pytree. Each of its leaves should either be 1) True, in which
+             case the leaf or subtree is kept; 2) False, in which case the
+             leaf or subtree is replaced with replace; 3) a callable
+             Leaf -> bool, in which case this is evaluated on the leaf or
+             mapped over the subtree, and the leaf kept or replaced as
+             appropriate.
+
+         Returns
+         -------
+         mlp
+             A PINN_MLP instance.
+         mlp.init_params
+             The initial set of parameters of the PINN_MLP.
+         """
+         if eqx_network is None:
+             if eqx_list is None or key is None:
+                 raise ValueError(
+                     "If eqx_network is None, then key and eqx_list must be provided"
+                 )
+             eqx_network = MLP(key=key, eqx_list=eqx_list)
+
+         mlp = cls(
+             eqx_network=eqx_network,
+             slice_solution=slice_solution,
+             eq_type=eq_type,
+             input_transform=input_transform,
+             output_transform=output_transform,
+             filter_spec=filter_spec,
+         )
+         return mlp, mlp.init_params
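
For reference, a hedged usage sketch of the factory above. The import path follows the file added in this diff (a public re-export such as jinns.nn.PINN_MLP may also exist, but that is an assumption); the layer sizes and the 2D stationary-PDE setting are illustrative, not defaults.

    import jax
    import equinox as eqx

    from jinns.nn._mlp import PINN_MLP  # module path as added in this diff

    # a PINN u(x) for a stationary PDE with 2-dimensional x and scalar output
    u, init_params = PINN_MLP.create(
        eq_type="statio_PDE",
        key=jax.random.PRNGKey(0),
        eqx_list=(
            (eqx.nn.Linear, 2, 20),   # input size = dim of x
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 1),   # output size = dim of the solution
        ),
    )
    # `u` is the callable network; `init_params` is the trainable PyTree that
    # the optimizer updates and that is passed back to `u` at evaluation time.
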