jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/utils/_pinn.py DELETED
@@ -1,324 +0,0 @@
1
- """
2
- Implements utility function to create PINNs
3
- """
4
-
5
- from typing import Callable, Literal
6
- from dataclasses import InitVar
7
- import jax
8
- import jax.numpy as jnp
9
- import equinox as eqx
10
-
11
- from jaxtyping import Array, Key, PyTree, Float
12
-
13
- from jinns.parameters._params import Params, ParamsDict
14
-
15
-
16
class _MLP(eqx.Module):
    """
    Class to construct an equinox module from a key and an eqx_list. To be used
    in pair with the function `create_PINN`.

    Parameters
    ----------
    key : InitVar[Key]
        A jax random key for the layer initializations. It is split once per
        layer that needs to be instantiated, so every such layer gets its own
        subkey.
    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
        A tuple of successive layer descriptions describing the PINN
        architecture. Each item is either:

        - a tuple `(module, *args)` whose first item is an equinox module (or
          any callable accepting a `key` keyword) and whose remaining items
          are the positional arguments needed to instantiate it (eg. the size
          of the layer). The `key` argument need not be given, it is
          supplied automatically;
        - a 1-tuple `(f,)` or a bare callable `f` (eg. an activation
          function), which is used as-is without instantiation.

        A typical example is `eqx_list=
            ((eqx.nn.Linear, 2, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 1)
        )`.
    """

    key: InitVar[Key] = eqx.field(kw_only=True)
    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
        kw_only=True
    )

    # NOTE that the following should NOT be declared as static otherwise the
    # eqx.partition that we use in the PINN module will misbehave
    layers: list[eqx.Module] = eqx.field(init=False)

    def __post_init__(self, key, eqx_list):
        self.layers = []
        for layer_spec in eqx_list:
            if not isinstance(layer_spec, (tuple, list)):
                # Bare callable (eg. `jax.nn.tanh`), as shown in the docstring
                # example; calling `len()` on it used to raise a TypeError.
                self.layers.append(layer_spec)
            elif len(layer_spec) == 1:
                # `(f,)`: used as-is, no instantiation and no key consumed.
                self.layers.append(layer_spec[0])
            else:
                # `(module, *args)`: instantiate with a fresh subkey so that
                # each layer is initialized independently.
                key, subkey = jax.random.split(key, 2)
                self.layers.append(layer_spec[0](*layer_spec[1:], key=subkey))

    def __call__(self, t: Float[Array, "input_dim"]) -> Float[Array, "output_dim"]:
        """Apply every layer of `self.layers` to `t`, in order."""
        for layer in self.layers:
            t = layer(t)
        return t
64
-
65
-
66
class PINN(eqx.Module):
    r"""
    A PINN object, i.e., a neural network compatible with the rest of jinns.
    This is typically created with `create_PINN`, which internally creates a
    `_MLP` object. However, a user can also directly create their PINN with
    this class by passing an already instantiated eqx.Module (argument `mlp`)
    that plays the role of the NN.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (`create_PINN`
        defaults it to the whole PINN output and converts a user-provided
        integer into a slice).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting
        the PINN.
    output_slice : slice | None, default=None
        A jnp.s\_[] applied to the transformed output to select the
        dimensions returned by this PINN. See the `shared_pinn_outputs`
        argument of `create_PINN`.
    mlp : eqx.Module
        The actual neural network instantiated as an eqx.Module. It is only
        used at construction time, to be split into `params` and `static`.
    """

    slice_solution: slice = eqx.field(static=True, kw_only=True)
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"] = eqx.field(
        static=True, kw_only=True
    )
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ] = eqx.field(static=True, kw_only=True)
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ] = eqx.field(static=True, kw_only=True)
    output_slice: slice | None = eqx.field(static=True, kw_only=True, default=None)

    mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)

    # Trainable leaves (inexact arrays) of `mlp`, and the rest of its structure.
    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, mlp):
        self.params, self.static = eqx.partition(mlp, eqx.is_inexact_array)

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters
        """
        return self.params

    def __call__(
        self,
        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
        params: Params | ParamsDict | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PINN on some inputs with some params.

        Parameters
        ----------
        inputs
            The PINN inputs. A 0-d array is promoted to shape `(1,)`.
        params
            Either an object exposing the network weights as `.nn_params`
            (eg. `Params`), or the network weights PyTree itself (eg. the
            value of `init_params`).

        Returns
        -------
        The transformed network output, restricted to `output_slice` when it
        is set, and expanded to shape `(1,)` if it would otherwise be 0-d.
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]

        try:
            model = eqx.combine(params.nn_params, self.static)
        except (KeyError, AttributeError, TypeError):
            # give more flexibility: `params` may directly be the weights
            model = eqx.combine(params, self.static)
        res = self.output_transform(
            inputs, model(self.input_transform(inputs, params)).squeeze(), params
        )

        if self.output_slice is not None:
            res = res[self.output_slice]

        ## force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
167
-
168
-
169
def create_PINN(
    key: Key,
    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ]
    | None = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ]
    | None = None,
    shared_pinn_outputs: tuple[slice, ...] | None = None,
    slice_solution: slice | int | None = None,
) -> tuple[PINN | list[PINN], PyTree | list[PyTree]]:
    r"""
    Utility function to create a standard PINN neural network with the equinox
    library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters.
    eqx_list
        A tuple of tuples of successive equinox modules and activation
        functions to describe the PINN architecture. The inner tuples must have
        the eqx module or activation function as first item, other items
        represent arguments that could be required (eg. the size of the layer).

        The `key` argument does not need to be given.

        A typical example is `eqx_list = (
            (eqx.nn.Linear, input_dim, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, 20),
            (jax.nn.tanh,),
            (eqx.nn.Linear, 20, output_dim)
        )`.
        The number of PINN outputs is read as the 3rd item of the last layer
        description that has one (trailing activations are skipped).
    eq_type
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        A function whose arguments are the same input as the PINN, the PINN
        output and the parameters. This function will be called after exiting
        the PINN.
        Default is no operation.
    shared_pinn_outputs
        Default is None, for a standard PINN.
        A tuple of jnp.s\_[] (slices) to determine the different output for each
        network. In this case we return a list of PINNs, one for each output in
        shared_pinn_outputs. This is useful to create PINNs that share the
        same network and same parameters; **the user must then use the same
        parameter set in their manipulation**.
        See the notebook 2D Navier Stokes in pipeflow with metamodel for an
        example using this option.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PINN output. This argument is
        useful when the PINN is also used to output equation parameters for
        example. An integer `i` is accepted and rewritten as the slice
        `jnp.s_[i:i+1]` so that the axis does not disappear when indexing.


    Returns
    -------
    pinn
        A PINN instance or, when `shared_pinn_outputs` is not None,
        a list of PINN instances with the same structure,
        only differing by their final slicing of the network output.
    pinn.init_params
        An initial set of parameters for the PINN or a list of the latter
        when `shared_pinn_outputs` is not None.

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    RuntimeError
        If no layer description in `eqx_list` declares a number of outputs
        (i.e. has a 3rd item).
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    # The output size is the 3rd item of the last layer description that has
    # one; trailing activations such as `(jax.nn.tanh,)` or bare callables
    # are skipped instead of crashing.
    nb_outputs_declared = None
    for layer_spec in reversed(eqx_list):
        if isinstance(layer_spec, (tuple, list)) and len(layer_spec) >= 3:
            nb_outputs_declared = layer_spec[2]
            break
    if nb_outputs_declared is None:
        raise RuntimeError(
            "Could not infer the number of outputs from eqx_list: no layer "
            "description has a 3rd item"
        )

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    # A single network is built; every returned PINN shares it.
    mlp = _MLP(key=key, eqx_list=eqx_list)

    def _make_pinn(output_slice):
        # Build a PINN around the shared `mlp`, differing only by output_slice.
        return PINN(
            mlp=mlp,
            slice_solution=slice_solution,
            eq_type=eq_type,
            input_transform=input_transform,
            output_transform=output_transform,
            output_slice=output_slice,
        )

    if shared_pinn_outputs is not None:
        pinns = [_make_pinn(output_slice) for output_slice in shared_pinn_outputs]
        return pinns, [p.init_params for p in pinns]
    pinn = _make_pinn(None)
    return pinn, pinn.init_params
jinns/utils/_ppinn.py DELETED
@@ -1,227 +0,0 @@
1
- """
2
- Implements utility functions to create PPINNs (parallel PINNs)
3
- """
4
-
5
- from typing import Callable, Literal
6
- from dataclasses import InitVar
7
- import jax
8
- import jax.numpy as jnp
9
- import equinox as eqx
10
-
11
- from jaxtyping import Array, Key, PyTree, Float
12
-
13
- from jinns.parameters._params import Params, ParamsDict
14
-
15
- from jinns.utils._pinn import PINN, _MLP
16
-
17
-
18
class PPINN(PINN):
    r"""
    A PPINN (Parallel PINN) object which mimics the PFNN architecture from
    DeepXDE. This is in fact a PINN that encompasses several PINNs internally:
    the same (transformed) input is fed to every sub-network and their
    outputs are concatenated.

    Parameters
    ----------
    slice_solution : slice
        A jnp.s\_ object which indicates which axis of the PINN output is
        dedicated to the actual equation solution. This argument is useful
        when the PINN is also used to output equation parameters for example.
        Note that it must be a slice and not an integer (`create_PPINN`
        defaults it to the whole output and converts a user-provided integer
        into a slice).
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    input_transform : Callable[[Float[Array, "input_dim"], Params], Float[Array, "output_dim"]]
        A function that will be called before entering the PINN. Its output(s)
        must match the PINN inputs (except for the parameters).
        Its inputs are the PINN inputs (`t` and/or `x` concatenated together)
        and the parameters.
    output_transform : Callable[[Float[Array, "input_dim"], Float[Array, "output_dim"], Params], Float[Array, "output_dim"]]
        A function whose arguments are the same input as the PINN, the
        concatenated outputs of the sub-networks and the parameters. This
        function will be called after exiting the PINN.
    output_slice : slice | None, default=None
        When set, a jnp.s\_[] applied to the transformed output to select the
        returned dimensions.
    mlp : eqx.Module
        Inherited from PINN; unused by PPINN (`create_PPINN` passes None).
    mlp_list : list[eqx.Module]
        The actual neural networks instantiated as eqx.Modules
    """

    slice_solution: slice = eqx.field(static=True, kw_only=True)
    output_slice: slice | None = eqx.field(static=True, kw_only=True, default=None)

    mlp_list: InitVar[list[eqx.Module]] = eqx.field(kw_only=True)

    # One entry per sub-network: trainable leaves and remaining structure.
    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, mlp, mlp_list):
        # The parent initializer partitions the first network only; its
        # result is replaced just below by one partition per sub-network.
        super().__post_init__(
            mlp=mlp_list[0],
        )
        partitions = [eqx.partition(net, eqx.is_inexact_array) for net in mlp_list]
        self.params = tuple(net_params for net_params, _ in partitions)
        self.static = tuple(net_static for _, net_static in partitions)

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters
        """
        return self.params

    def __call__(
        self,
        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
        params: PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the PPINN on some inputs with some params.

        Parameters
        ----------
        inputs
            The PPINN inputs. A 0-d array is promoted to shape `(1,)`.
        params
            Either an object exposing the per-sub-network weights as
            `.nn_params` (eg. `Params`), or those weights directly (eg. the
            value of `init_params`).

        Returns
        -------
        The output transform applied to the concatenation (along axis 0) of
        the sub-network outputs, restricted to `output_slice` when it is set,
        and expanded to shape `(1,)` if it would otherwise be 0-d.
        """
        if len(inputs.shape) == 0:
            # This can happen often when the user directly provides some
            # collocation points (eg for plotting, without using
            # DataGenerators)
            inputs = inputs[None]
        transformed_inputs = self.input_transform(inputs, params)

        try:
            nn_params = params.nn_params
        except AttributeError:
            # As in PINN.__call__, accept the network weights directly
            nn_params = params

        outs = [
            eqx.combine(net_params, net_static)(transformed_inputs)
            for net_params, net_static in zip(nn_params, self.static)
        ]
        # Note that below is then a global output transform
        res = self.output_transform(inputs, jnp.concatenate(outs, axis=0), params)

        # Previously the declared output_slice field was silently ignored.
        if self.output_slice is not None:
            res = res[self.output_slice]

        ## force (1,) output for non vectorial solution (consistency)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
106
-
107
-
108
def create_PPINN(
    key: Key,
    eqx_list_list: list[tuple[tuple[Callable, int, int] | Callable, ...]],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    dim_x: int = 0,
    input_transform: Callable[
        [Float[Array, "input_dim"], Params], Float[Array, "output_dim"]
    ]
    | None = None,
    output_transform: Callable[
        [Float[Array, "input_dim"], Float[Array, "output_dim"], Params],
        Float[Array, "output_dim"],
    ]
    | None = None,
    slice_solution: slice | int | None = None,
) -> tuple[PPINN, PyTree]:
    r"""
    Utility function to create a PPINN (parallel PINN) neural network with the
    equinox library.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network
        parameters. It is split into one independent subkey per sub-network.
    eqx_list_list
        A list of `eqx_list` (see `create_PINN`). The input dimension must be the
        same for each sub-`eqx_list`. Then the parallel subnetworks can be
        different. Their respective outputs are concatenated.
    eq_type
        A string with three possibilities.
        "ODE": the PPINN is called with one input `t`.
        "statio_PDE": the PPINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PPINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
        **Note**: the input dimension as given in eqx_list has to match the sum
        of the dimension of `t` + the dimension of `x` or the output dimension
        after the `input_transform` function.
    dim_x
        An integer. The dimension of `x`. Default `0`.
    input_transform
        A function that will be called before entering the PPINN. Its output(s)
        must match the PPINN inputs (except for the parameters).
        Its inputs are the PPINN inputs (`t` and/or `x` concatenated together)
        and the parameters. Default is no operation.
    output_transform
        This function will be called after exiting
        the PPINN, i.e., on the concatenated outputs of all parallel networks.
        Default is no operation.
    slice_solution
        A jnp.s\_ object which indicates which axis of the PPINN output is
        dedicated to the actual equation solution. Default None
        means that slice_solution = the whole PPINN output. This argument is
        useful when the PPINN is also used to output equation parameters for
        example. An integer `i` is accepted and rewritten as the slice
        `jnp.s_[i:i+1]` so that the axis does not disappear when indexing.


    Returns
    -------
    ppinn
        A PPINN instance
    ppinn.init_params
        An initial set of parameters for the PPINN

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    RuntimeError
        If we have a `dim_x > 0` and `eq_type == "ODE"`
        or if we have a `dim_x = 0` and `eq_type != "ODE"`
    RuntimeError
        If one of the `eqx_list` has no layer description declaring a number
        of outputs (i.e. with a 3rd item).
    """
    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    if eq_type == "ODE" and dim_x != 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    if eq_type != "ODE" and dim_x == 0:
        raise RuntimeError("Wrong parameter combination eq_type and dim_x")

    def _nb_outputs(eqx_list):
        # 3rd item of the last layer description that has one; trailing
        # activations such as `(jax.nn.tanh,)` or bare callables are skipped.
        for layer_spec in reversed(eqx_list):
            if isinstance(layer_spec, (tuple, list)) and len(layer_spec) >= 3:
                return layer_spec[2]
        raise RuntimeError(
            "Could not infer the number of outputs from an eqx_list: no layer "
            "description has a 3rd item"
        )

    nb_outputs_declared = sum(_nb_outputs(eqx_list) for eqx_list in eqx_list_list)

    if slice_solution is None:
        slice_solution = jnp.s_[0:nb_outputs_declared]
    if isinstance(slice_solution, int):
        # rewrite it as a slice to ensure that axis does not disappear when
        # indexing
        slice_solution = jnp.s_[slice_solution : slice_solution + 1]

    if input_transform is None:

        def input_transform(_in, _params):
            return _in

    if output_transform is None:

        def output_transform(_in_pinn, _out_pinn, _params):
            return _out_pinn

    # One independent subkey per sub-network. Reusing `key` for every one of
    # them (as was done before) initializes sub-networks sharing the same
    # architecture with exactly the same weights.
    subkeys = jax.random.split(key, len(eqx_list_list))
    mlp_list = [
        _MLP(key=subkey, eqx_list=eqx_list)
        for subkey, eqx_list in zip(subkeys, eqx_list_list)
    ]

    ppinn = PPINN(
        mlp=None,
        mlp_list=mlp_list,
        slice_solution=slice_solution,
        eq_type=eq_type,
        input_transform=input_transform,
        output_transform=output_transform,
    )
    return ppinn, ppinn.init_params