jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/utils/_spinn.py DELETED
@@ -1,249 +0,0 @@
1
- """
2
- Implements utility function to create Separable PINNs
3
- https://arxiv.org/abs/2211.08761
4
- """
5
-
6
- from dataclasses import InitVar
7
- from typing import Callable, Literal
8
- import jax
9
- import jax.numpy as jnp
10
- import equinox as eqx
11
- from jaxtyping import Key, Array, Float, PyTree
12
-
13
- from jinns.parameters._params import Params, ParamsDict
14
-
15
-
16
class _SPINN(eqx.Module):
    """
    Construct a Separable PINN as proposed in
    Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023

    One independent stack of layers is instantiated per separated dimension.
    At call time, stack number `i` is applied to the single coordinate
    `inputs[i : i + 1]`.

    Parameters
    ----------
    key : InitVar[Key]
        A jax random key for the layer initializations.
    d : int
        The number of dimensions to treat separately, including time `t` if
        used for non-stationary equations.
    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
        A tuple of successive equinox modules and activation functions to
        describe the architecture of each per-dimension network. An entry is
        either a tuple whose first item is the eqx module (or activation
        function) and whose remaining items are the positional arguments it
        requires (eg. the size of the layer), or a bare callable that takes
        no construction argument. A 1-tuple such as `(jax.nn.tanh,)` is
        equivalent to the bare callable `jax.nn.tanh`.
        The `key` argument need not be given.
        Since each network receives a single coordinate, the first layer
        takes 1 input feature. Thus typical example is `eqx_list=
        ((eqx.nn.Linear, 1, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 1)
        )`.
    """

    d: int = eqx.field(static=True, kw_only=True)

    key: InitVar[Key] = eqx.field(kw_only=True)
    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
        kw_only=True
    )

    # NOTE(review): `layers` ends up holding the stack built for the last
    # dimension (the same list as `separated_mlp[-1]`). It is kept so that the
    # set of fields of the module is unchanged -- confirm whether anything
    # still reads it.
    layers: list = eqx.field(init=False)
    separated_mlp: list = eqx.field(init=False)

    def __post_init__(self, key, eqx_list):
        """
        Build `d` independent layer stacks from the `eqx_list` description.

        A fresh subkey is split off `key` for every entry that carries
        construction arguments, so no two layers share their initialization
        key.
        """
        self.separated_mlp = []
        for _ in range(self.d):
            stack = []
            for entry in eqx_list:
                if callable(entry):
                    # Bare activation (e.g. `jax.nn.tanh`): nothing to
                    # instantiate. Calling `len` on it would raise TypeError.
                    stack.append(entry)
                elif len(entry) == 1:
                    # 1-tuple form of an argument-free callable
                    stack.append(entry[0])
                else:
                    key, subkey = jax.random.split(key, 2)
                    stack.append(entry[0](*entry[1:], key=subkey))
            self.layers = stack
            self.separated_mlp.append(stack)

    def __call__(
        self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
    ) -> Float[Array, "d embed_dim*output_dim"]:
        """
        Apply each per-dimension network to its own input coordinate.

        Parameters
        ----------
        inputs
            A single (non batched) input point; coordinate `i` is fed to the
            `i`-th network as the length-1 slice `inputs[i : i + 1]`.

        Returns
        -------
        The `d` network outputs gathered in one array with `jnp.asarray`,
        one row per separated dimension.
        """
        outputs = []
        for dim in range(self.d):
            hidden = inputs[dim : dim + 1]
            for layer in self.separated_mlp[dim]:
                hidden = layer(hidden)
            outputs.append(hidden)
        return jnp.asarray(outputs)
77
-
78
-
79
class SPINN(eqx.Module):
    """
    A SPINN object compatible with the rest of jinns.
    This is typically created with `create_SPINN`.

    At construction the given `spinn_mlp` module is split with
    `eqx.partition(spinn_mlp, eqx.is_inexact_array)` into `params` and
    `static`. At call time both parts are recombined with `eqx.combine`.

    Parameters
    ----------
    d : int
        The number of dimensions to treat separately, including time `t` if
        used for non-stationary equations.
    r : int
        The dimension of the embedding: the width of the slice of features
        that is contracted to produce one output component.
    eq_type : str
        The type of equation; only stored by this class.
    m : int
        The number of output components; the features of each separated
        network are read as `m` consecutive groups of `r` values.
    spinn_mlp : InitVar[eqx.Module]
        The module to wrap (typically a `_SPINN`).
    """

    d: int = eqx.field(static=True, kw_only=True)
    r: int = eqx.field(static=True, kw_only=True)
    eq_type: str = eqx.field(static=True, kw_only=True)
    m: int = eqx.field(static=True, kw_only=True)

    spinn_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)

    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, spinn_mlp):
        """
        Split `spinn_mlp` into its inexact-array leaves (`params`) and
        everything else (`static`).
        """
        partitioned = eqx.partition(spinn_mlp, eqx.is_inexact_array)
        self.params, self.static = partitioned

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters
        """
        return self.params

    def __call__(
        self,
        t_x: Float[Array, "batch_size 1+dim"],
        params: Params | ParamsDict | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the SPINN on some inputs with some params.

        Parameters
        ----------
        t_x
            Batch of input points; the network is vmapped over the leading
            axis.
        params
            Either an object exposing the network parameters as
            `params.nn_params`, or directly the parameter PyTree to combine
            with `self.static`.

        Returns
        -------
        For each of the `m` output components, the outer product over the
        separated dimensions of the corresponding `r`-sized feature group,
        summed over the embedding axis. The `m` results are stacked on a new
        last axis.
        """
        try:
            spinn = eqx.combine(params.nn_params, self.static)
        except (KeyError, AttributeError, TypeError):
            # `params` is not a container with `nn_params`: treat it as the
            # network parameters themselves
            spinn = eqx.combine(params, self.static)
        res = jax.vmap(spinn)(t_x)

        # One einsum letter per separated dimension ('a', 'b', ...); 'z' is
        # the shared embedding axis that gets summed out.
        n_dims = res.shape[1]
        letters = [chr(97 + i) for i in range(n_dims)]
        operands = ", ".join([f"{letter}z" for letter in letters])
        result_axes = "".join(letters)
        subscripts = f"{operands} -> {result_axes}"

        components = []
        for k in range(self.m):
            # features k * r, ..., (k + 1) * r - 1 of every separated network
            group = slice(k * self.r, (k + 1) * self.r)
            factors = [res[:, i, group] for i in range(n_dims)]
            components.append(jnp.einsum(subscripts, *factors))
        res = jnp.stack(components, axis=-1)  # compute each output dimension

        # force (1,) output for non vectorial solution (consistency)
        # NOTE(review): when `res.shape[1] == self.d` the stacked result
        # already has `self.d + 1` axes, so this branch looks unreachable --
        # confirm before relying on it.
        if res.ndim == self.d:
            return jnp.expand_dims(res, axis=-1)
        return res
147
-
148
-
149
def _declared_layer_arg(entry, position):
    """
    Return item `position` of one `eqx_list` entry.

    A bare callable (eg. `jax.nn.tanh`) declares no argument. It is handled
    like a 1-tuple, that is, `IndexError` is raised so that callers can fall
    back on a neighbouring entry.
    """
    if callable(entry):
        raise IndexError("eqx_list entry is a bare callable without arguments")
    return entry[position]


def create_SPINN(
    key: Key,
    d: int,
    r: int,
    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    m: int = 1,
) -> tuple[SPINN, PyTree]:
    """
    Utility function to create a SPINN neural network with the equinox
    library.

    *Note* that a SPINN is not vmapped and expects the
    same batch size for each of its input axis. It directly outputs a solution
    of shape `(batchsize, batchsize)`. See the paper for more details.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network parameters
    d
        The number of dimensions to treat separately.
    r
        An integer. The dimension of the embedding.
    eqx_list
        A tuple of successive equinox modules and activation functions to
        describe the architecture of each separated network. An entry is
        either a tuple with the eqx module or activation function as first
        item, the other items being the arguments that could be required
        (eg. the size of the layer), or a bare callable without arguments
        (a 1-tuple such as `(jax.nn.tanh,)` is accepted as well).
        The `key` argument need not be given.
        The declared input dimension must be 1 and the declared output
        dimension must be `r * m`. Thus typical example is
        `eqx_list=((eqx.nn.Linear, 1, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, 20),
            jax.nn.tanh,
            (eqx.nn.Linear, 20, r * m)
        )`.
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
    m
        The output dimension of the neural network. According to
        the SPINN article, a total embedding dimension of `r*m` is defined. We
        then sum groups of `r` embedding dimensions to compute each output.
        Default is 1.

    Returns
    -------
    spinn
        An instantiated SPINN
    spinn.init_params
        The initial set of parameters of the model

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    ValueError
        If the declared input dimension is not 1, if the declared output
        dimension is not `r * m`, or if `d` is larger than 24
    """

    if eq_type not in ("ODE", "statio_PDE", "nonstatio_PDE"):
        raise RuntimeError("Wrong parameter value for eq_type")

    try:
        # normally we look for 2nd ele of 1st layer
        nb_inputs_declared = _declared_layer_arg(eqx_list[0], 1)
    except IndexError:
        # but we can have, eg, a flatten first layer
        nb_inputs_declared = _declared_layer_arg(eqx_list[1], 1)
    if nb_inputs_declared != 1:
        raise ValueError("Input dim must be set to 1 in SPINN!")

    try:
        # normally we look for 3rd ele of last layer
        nb_outputs_declared = _declared_layer_arg(eqx_list[-1], 2)
    except IndexError:
        # but we can have, eg, a `jnp.exp` last layer
        nb_outputs_declared = _declared_layer_arg(eqx_list[-2], 2)
    if nb_outputs_declared != r * m:
        raise ValueError("Output dim must be set to r * m in SPINN!")

    # `SPINN.__call__` names each separated dimension with one lowercase
    # letter starting at 'a' and keeps 'z' for the embedding axis.
    if d > 24:
        raise ValueError(
            "Too many dimensions, not enough letters available in jnp.einsum"
        )

    spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
    spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)

    return spinn, spinn.init_params
@@ -1,41 +0,0 @@
1
- jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
2
- jinns/data/_Batchs.py,sha256=oc7-N1wEbsEvbe9fjVFKG2OPoZJVEjzPm8uj_icACf4,817
3
- jinns/data/_DataGenerators.py,sha256=BWNF28jM7EN6wyJ3X2X1qjG4sjs6vHLgIVVY6aqhKS8,65099
4
- jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
5
- jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
6
- jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
7
- jinns/loss/_DynamicLoss.py,sha256=x1pJLZ_4P2iHhf5hjRhkWi1v5Q2dAH-v7Gctv4ax73E,25819
8
- jinns/loss/_DynamicLossAbstract.py,sha256=bqmPxyrcvZh_dL74DTpj-TGiFxchvG8qC6KhuGeyOoA,12006
9
- jinns/loss/_LossODE.py,sha256=lTJ3b2EnlYk1eAcldiVyFRh_XQrk83eYRQBuqWSiILg,23418
10
- jinns/loss/_LossPDE.py,sha256=ufECwKdQXFEy8h5fdDA2uFTjiO_sVdqi9nfP7NTyu60,48775
11
- jinns/loss/__init__.py,sha256=PRiJV9fd2GSwaCBVCPyh6pFc6pdA40jfb_T1YvO8ERc,712
12
- jinns/loss/_boundary_conditions.py,sha256=KL_UUajQWLtRFqmBO_lfEcxrUsY-398ySTgpfaHdeYk,12277
13
- jinns/loss/_loss_utils.py,sha256=RXu20VYeAelpuRIAJ67k3h7jdLsu2VIdjzmLVdiMjns,14223
14
- jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
15
- jinns/loss/_operators.py,sha256=Hf5zWUQCGRSm1SS8f6IZLqC8bL7suC-R4ckXNbXqoAI,21790
16
- jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
17
- jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
18
- jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
19
- jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
20
- jinns/plot/_plot.py,sha256=jS0fdI9VHCd6HL4K8ZzDjmMr10EJFqU6a0S4iWpEzFk,11433
21
- jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- jinns/solver/_rar.py,sha256=BCvUdNLicdU1YizHKcSaJLmkUkCybyGmTRjHnjoiQKA,9730
23
- jinns/solver/_solve.py,sha256=TR-RuH2tYIFq6YhbZ-Aa7r8OYy7CstjEYNm2RMI3xck,22641
24
- jinns/solver/_utils.py,sha256=b2zYvwZY_fU0NMNWvUEMvHez9s7hwcxfpGzQlz5F6HA,5762
25
- jinns/utils/__init__.py,sha256=nx7wi4RfgeHr5wLM_c1BAG8UQteTUTK6B0t9enmTn1E,243
26
- jinns/utils/_containers.py,sha256=a7A-iUApnjc1YVc7bdt9tKUvHHPDOKMB9OfdrDZGWN8,1450
27
- jinns/utils/_hyperpinn.py,sha256=hcdHRHXer03I3Twupok5moPR1vRGgPQUCjFHAZMCMdc,16872
28
- jinns/utils/_pinn.py,sha256=8wDliUrDbBUKWllJgllapNX5BLLHMWF7Mz2CaEA-6vk,12667
29
- jinns/utils/_ppinn.py,sha256=K4xuGaT-4XddIxsi1wOYO24G9l1RAY2m4bZqP7C3klI,8690
30
- jinns/utils/_save_load.py,sha256=BngwUjkm5jibctnUCB8klMypYkfE6DKLaoWehwWnTNU,8585
31
- jinns/utils/_spinn.py,sha256=tjly3qNZj9C2JntFKuwrLfFAbVfvy7Z5HPjj0ymfnLk,8337
32
- jinns/utils/_types.py,sha256=EQXxUAipwy7lVk39hYw4NDf_a66OP01w_T3ziPITxpU,1901
33
- jinns/utils/_utils.py,sha256=hoRcJqcTuQi_Ip40oI4EbxW46E1rp2C01_HfuCpwKRM,2932
34
- jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
35
- jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
36
- jinns-1.2.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
37
- jinns-1.2.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
38
- jinns-1.2.0.dist-info/METADATA,sha256=PxuzXslHG4YzOcuwDY3ulo9zJaF95ChjK9y584eMnN0,4662
39
- jinns-1.2.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
40
- jinns-1.2.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
41
- jinns-1.2.0.dist-info/RECORD,,