jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -0
- jinns/data/_Batchs.py +27 -0
- jinns/data/_DataGenerators.py +904 -1203
- jinns/data/__init__.py +4 -8
- jinns/experimental/__init__.py +0 -2
- jinns/experimental/_diffrax_solver.py +5 -5
- jinns/loss/_DynamicLoss.py +282 -305
- jinns/loss/_DynamicLossAbstract.py +322 -167
- jinns/loss/_LossODE.py +324 -322
- jinns/loss/_LossPDE.py +652 -1027
- jinns/loss/__init__.py +21 -5
- jinns/loss/_boundary_conditions.py +87 -41
- jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
- jinns/loss/_loss_weights.py +59 -0
- jinns/loss/_operators.py +78 -72
- jinns/parameters/__init__.py +6 -0
- jinns/parameters/_derivative_keys.py +521 -0
- jinns/parameters/_params.py +115 -0
- jinns/plot/__init__.py +5 -0
- jinns/{data/_display.py → plot/_plot.py} +98 -75
- jinns/solver/_rar.py +183 -39
- jinns/solver/_solve.py +151 -124
- jinns/utils/__init__.py +3 -9
- jinns/utils/_containers.py +37 -44
- jinns/utils/_hyperpinn.py +224 -119
- jinns/utils/_pinn.py +183 -111
- jinns/utils/_save_load.py +121 -56
- jinns/utils/_spinn.py +113 -86
- jinns/utils/_types.py +64 -0
- jinns/utils/_utils.py +6 -160
- jinns/validation/_validation.py +48 -140
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
- jinns-1.1.0.dist-info/RECORD +39 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
- jinns/experimental/_sinuspinn.py +0 -135
- jinns/experimental/_spectralpinn.py +0 -87
- jinns/solver/_seq2seq.py +0 -157
- jinns/utils/_optim.py +0 -147
- jinns/utils/_utils_uspinn.py +0 -727
- jinns-0.9.0.dist-info/RECORD +0 -36
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/validation/_validation.py
CHANGED
|
@@ -2,27 +2,24 @@
|
|
|
2
2
|
Implements some validation functions and their associated hyperparameter
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
import
|
|
5
|
+
from __future__ import (
|
|
6
|
+
annotations,
|
|
7
|
+
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
+
|
|
6
9
|
import abc
|
|
7
|
-
from typing import Union
|
|
10
|
+
from typing import TYPE_CHECKING, Union
|
|
8
11
|
import equinox as eqx
|
|
9
12
|
import jax
|
|
10
13
|
import jax.numpy as jnp
|
|
11
|
-
from jaxtyping import Array
|
|
12
|
-
|
|
13
|
-
import jinns.data
|
|
14
|
-
from jinns.loss import LossODE, LossPDENonStatio, LossPDEStatio
|
|
14
|
+
from jaxtyping import Array
|
|
15
|
+
|
|
15
16
|
from jinns.data._DataGenerators import (
|
|
16
|
-
DataGeneratorODE,
|
|
17
|
-
CubicMeshPDEStatio,
|
|
18
|
-
CubicMeshPDENonStatio,
|
|
19
|
-
DataGeneratorParameter,
|
|
20
|
-
DataGeneratorObservations,
|
|
21
|
-
DataGeneratorObservationsMultiPINNs,
|
|
22
17
|
append_obs_batch,
|
|
23
18
|
append_param_batch,
|
|
24
19
|
)
|
|
25
|
-
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from jinns.utils._types import *
|
|
26
23
|
|
|
27
24
|
# Using eqx Module for the DataClass + Pytree inheritance
|
|
28
25
|
# Abstract class and abstract/final pattern is used
|
|
@@ -35,14 +32,16 @@ class AbstractValidationModule(eqx.Module):
|
|
|
35
32
|
2. implement a ``__call__`` returning ``(AbstractValidationModule, Bool, Array)``
|
|
36
33
|
"""
|
|
37
34
|
|
|
38
|
-
call_every: eqx.AbstractVar[
|
|
35
|
+
call_every: eqx.AbstractVar[int] = eqx.field(
|
|
36
|
+
kw_only=True
|
|
37
|
+
) # Mandatory for all validation step,
|
|
39
38
|
# it tells that the validation step is performed every call_every
|
|
40
39
|
# iterations.
|
|
41
40
|
|
|
42
41
|
@abc.abstractmethod
|
|
43
42
|
def __call__(
|
|
44
|
-
self, params:
|
|
45
|
-
) -> tuple["AbstractValidationModule",
|
|
43
|
+
self, params: Params | ParamsDict
|
|
44
|
+
) -> tuple["AbstractValidationModule", bool, Array, bool]:
|
|
46
45
|
raise NotImplementedError
|
|
47
46
|
|
|
48
47
|
|
|
@@ -53,37 +52,44 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
53
52
|
for more complicated validation strategy.
|
|
54
53
|
"""
|
|
55
54
|
|
|
56
|
-
loss:
|
|
57
|
-
|
|
55
|
+
loss: AnyLoss = eqx.field(kw_only=True) # NOTE that
|
|
56
|
+
# there used to be a deepcopy here which has been suppressed. 1) No need
|
|
57
|
+
# because loss are now eqx.Module (immutable) so no risk of in-place
|
|
58
|
+
# modification. 2) deepcopy is buggy with equinox, InitVar etc. (see issue
|
|
59
|
+
# #857 on equinox github)
|
|
60
|
+
validation_data: Union[AnyDataGenerator] = eqx.field(kw_only=True)
|
|
61
|
+
validation_param_data: Union[DataGeneratorParameter, None] = eqx.field(
|
|
62
|
+
kw_only=True, default=None
|
|
58
63
|
)
|
|
59
|
-
validation_data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
|
|
60
|
-
validation_param_data: Union[DataGeneratorParameter, None] = None
|
|
61
64
|
validation_obs_data: Union[
|
|
62
65
|
DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
|
|
63
|
-
] = None
|
|
64
|
-
call_every:
|
|
65
|
-
early_stopping:
|
|
66
|
+
] = eqx.field(kw_only=True, default=None)
|
|
67
|
+
call_every: int = eqx.field(kw_only=True, default=250) # concrete typing
|
|
68
|
+
early_stopping: bool = eqx.field(
|
|
69
|
+
kw_only=True, default=True
|
|
70
|
+
) # globally control if early stopping happens
|
|
66
71
|
|
|
67
|
-
patience: Union[
|
|
72
|
+
patience: Union[int] = eqx.field(kw_only=True, default=10)
|
|
68
73
|
best_val_loss: Array = eqx.field(
|
|
69
|
-
converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf)
|
|
74
|
+
converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf), kw_only=True
|
|
70
75
|
)
|
|
71
76
|
|
|
72
77
|
counter: Array = eqx.field(
|
|
73
|
-
converter=jnp.asarray, default_factory=lambda: jnp.array(0.0)
|
|
78
|
+
converter=jnp.asarray, default_factory=lambda: jnp.array(0.0), kw_only=True
|
|
74
79
|
)
|
|
75
80
|
|
|
76
|
-
def __call__(
|
|
81
|
+
def __call__(
|
|
82
|
+
self, params: AnyParams
|
|
83
|
+
) -> tuple["ValidationLoss", bool, float, AnyParams]:
|
|
77
84
|
# do in-place mutation
|
|
78
|
-
|
|
85
|
+
|
|
86
|
+
validation_data, val_batch = self.validation_data.get_batch()
|
|
79
87
|
if self.validation_param_data is not None:
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
)
|
|
88
|
+
validation_param_data, param_batch = self.validation_param_data.get_batch()
|
|
89
|
+
val_batch = append_param_batch(val_batch, param_batch)
|
|
83
90
|
if self.validation_obs_data is not None:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
)
|
|
91
|
+
validation_obs_data, obs_batch = self.validation_obs_data.get_batch()
|
|
92
|
+
val_batch = append_obs_batch(val_batch, obs_batch)
|
|
87
93
|
|
|
88
94
|
validation_loss_value, _ = self.loss(params, val_batch)
|
|
89
95
|
(counter, best_val_loss, update_best_params) = jax.lax.cond(
|
|
@@ -93,9 +99,14 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
93
99
|
(self.counter, self.best_val_loss),
|
|
94
100
|
)
|
|
95
101
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
102
|
+
new = eqx.tree_at(lambda t: t.validation_data, self, validation_data)
|
|
103
|
+
if self.validation_param_data is not None:
|
|
104
|
+
new = eqx.tree_at(
|
|
105
|
+
lambda t: t.validation_param_data, new, validation_param_data
|
|
106
|
+
)
|
|
107
|
+
if self.validation_obs_data is not None:
|
|
108
|
+
new = eqx.tree_at(lambda t: t.validation_obs_data, new, validation_obs_data)
|
|
109
|
+
new = eqx.tree_at(lambda t: t.counter, new, counter)
|
|
99
110
|
new = eqx.tree_at(lambda t: t.best_val_loss, new, best_val_loss)
|
|
100
111
|
|
|
101
112
|
bool_early_stopping = jax.lax.cond(
|
|
@@ -109,106 +120,3 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
109
120
|
)
|
|
110
121
|
# return `new` cause no in-place modification of the eqx.Module
|
|
111
122
|
return (new, bool_early_stopping, validation_loss_value, update_best_params)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
if __name__ == "__main__":
|
|
115
|
-
import jax
|
|
116
|
-
import jax.numpy as jnp
|
|
117
|
-
import jax.random as random
|
|
118
|
-
from jinns.loss import BurgerEquation
|
|
119
|
-
|
|
120
|
-
key = random.PRNGKey(1)
|
|
121
|
-
key, subkey = random.split(key)
|
|
122
|
-
|
|
123
|
-
n = 50
|
|
124
|
-
nb = 2 * 2 * 10
|
|
125
|
-
nt = 10
|
|
126
|
-
omega_batch_size = 10
|
|
127
|
-
omega_border_batch_size = 10
|
|
128
|
-
temporal_batch_size = 4
|
|
129
|
-
dim = 1
|
|
130
|
-
xmin = 0
|
|
131
|
-
xmax = 1
|
|
132
|
-
tmin, tmax = 0, 1
|
|
133
|
-
method = "uniform"
|
|
134
|
-
|
|
135
|
-
val_data = jinns.data.CubicMeshPDENonStatio(
|
|
136
|
-
subkey,
|
|
137
|
-
n,
|
|
138
|
-
nb,
|
|
139
|
-
nt,
|
|
140
|
-
omega_batch_size,
|
|
141
|
-
omega_border_batch_size,
|
|
142
|
-
temporal_batch_size,
|
|
143
|
-
dim,
|
|
144
|
-
(xmin,),
|
|
145
|
-
(xmax,),
|
|
146
|
-
tmin,
|
|
147
|
-
tmax,
|
|
148
|
-
method,
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
eqx_list = [
|
|
152
|
-
[eqx.nn.Linear, 2, 50],
|
|
153
|
-
[jax.nn.tanh],
|
|
154
|
-
[eqx.nn.Linear, 50, 50],
|
|
155
|
-
[jax.nn.tanh],
|
|
156
|
-
[eqx.nn.Linear, 50, 50],
|
|
157
|
-
[jax.nn.tanh],
|
|
158
|
-
[eqx.nn.Linear, 50, 50],
|
|
159
|
-
[jax.nn.tanh],
|
|
160
|
-
[eqx.nn.Linear, 50, 50],
|
|
161
|
-
[jax.nn.tanh],
|
|
162
|
-
[eqx.nn.Linear, 50, 2],
|
|
163
|
-
]
|
|
164
|
-
|
|
165
|
-
key, subkey = random.split(key)
|
|
166
|
-
u = jinns.utils.create_PINN(
|
|
167
|
-
subkey, eqx_list, "nonstatio_PDE", 2, slice_solution=jnp.s_[:1]
|
|
168
|
-
)
|
|
169
|
-
init_nn_params = u.init_params()
|
|
170
|
-
|
|
171
|
-
dyn_loss = BurgerEquation()
|
|
172
|
-
loss_weights = {"dyn_loss": 1, "boundary_loss": 10, "observations": 10}
|
|
173
|
-
|
|
174
|
-
key, subkey = random.split(key)
|
|
175
|
-
loss = jinns.loss.LossPDENonStatio(
|
|
176
|
-
u=u,
|
|
177
|
-
loss_weights=loss_weights,
|
|
178
|
-
dynamic_loss=dyn_loss,
|
|
179
|
-
norm_key=subkey,
|
|
180
|
-
norm_borders=(-1, 1),
|
|
181
|
-
)
|
|
182
|
-
print(id(loss))
|
|
183
|
-
validation = ValidationLoss(
|
|
184
|
-
call_every=250,
|
|
185
|
-
early_stopping=True,
|
|
186
|
-
patience=1000,
|
|
187
|
-
loss=loss,
|
|
188
|
-
validation_data=val_data,
|
|
189
|
-
validation_param_data=None,
|
|
190
|
-
)
|
|
191
|
-
print(id(validation.loss) is not id(loss)) # should be True (deepcopy)
|
|
192
|
-
|
|
193
|
-
init_params = {"nn_params": init_nn_params, "eq_params": {"nu": 1.0}}
|
|
194
|
-
|
|
195
|
-
print(validation.loss is loss)
|
|
196
|
-
loss.evaluate(init_params, val_data.get_batch())
|
|
197
|
-
print(loss.norm_key)
|
|
198
|
-
print("Call validation once")
|
|
199
|
-
validation, _, _ = validation(init_params)
|
|
200
|
-
print(validation.loss is loss)
|
|
201
|
-
print(validation.loss.norm_key == loss.norm_key)
|
|
202
|
-
print("Crate new pytree from validation and call it once")
|
|
203
|
-
new_val = eqx.tree_at(lambda t: t.counter, validation, jnp.array(3.0))
|
|
204
|
-
print(validation.loss is new_val.loss) # FALSE
|
|
205
|
-
# test if attribute have been modified
|
|
206
|
-
new_val, _, _ = new_val(init_params)
|
|
207
|
-
print(f"{new_val.loss is loss=}")
|
|
208
|
-
print(f"{loss.norm_key=}")
|
|
209
|
-
print(f"{validation.loss.norm_key=}")
|
|
210
|
-
print(f"{new_val.loss.norm_key=}")
|
|
211
|
-
print(f"{new_val.loss.norm_key == loss.norm_key=}")
|
|
212
|
-
print(f"{new_val.loss.norm_key == validation.loss.norm_key=}")
|
|
213
|
-
print(new_val.counter)
|
|
214
|
-
print(validation.counter)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version:
|
|
3
|
+
Version: 1.1.0
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -10,20 +10,21 @@ Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
|
|
|
10
10
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
11
|
Classifier: Development Status :: 4 - Beta
|
|
12
12
|
Classifier: Programming Language :: Python
|
|
13
|
-
Requires-Python: >=3.
|
|
13
|
+
Requires-Python: >=3.10
|
|
14
14
|
Description-Content-Type: text/markdown
|
|
15
15
|
License-File: LICENSE
|
|
16
|
+
License-File: AUTHORS
|
|
16
17
|
Requires-Dist: numpy
|
|
17
18
|
Requires-Dist: jax
|
|
18
19
|
Requires-Dist: jaxopt
|
|
19
20
|
Requires-Dist: optax
|
|
20
|
-
Requires-Dist: equinox
|
|
21
|
+
Requires-Dist: equinox >0.11.3
|
|
21
22
|
Requires-Dist: jax-tqdm
|
|
22
23
|
Requires-Dist: diffrax
|
|
23
24
|
Requires-Dist: matplotlib
|
|
24
25
|
Provides-Extra: notebook
|
|
25
26
|
Requires-Dist: jupyter ; extra == 'notebook'
|
|
26
|
-
Requires-Dist:
|
|
27
|
+
Requires-Dist: seaborn ; extra == 'notebook'
|
|
27
28
|
|
|
28
29
|
jinns
|
|
29
30
|
=====
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
|
|
2
|
+
jinns/data/_Batchs.py,sha256=BLxTDiFb6o9M6Irc2_HKmpr8IgA159u_kJIbCBZ490E,926
|
|
3
|
+
jinns/data/_DataGenerators.py,sha256=TuQDPI8NGx4WnCGtUp8v7o7GcnDfVZvL5E34JI-9Lmw,58455
|
|
4
|
+
jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
|
|
5
|
+
jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
|
|
6
|
+
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
7
|
+
jinns/loss/_DynamicLoss.py,sha256=WGbAuWnNfsbzUlWEiW_ARd4kI3jmHwdqPjxLC-wCA6s,25753
|
|
8
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=Xyt28Oej_zlhcV3f6cw2vnAKyRJhXBiA63CsdL3PihU,13767
|
|
9
|
+
jinns/loss/_LossODE.py,sha256=ThWPse6Gn5crM3_tzwZCBx-usoD0xWu6y1n0GVl2dpI,23422
|
|
10
|
+
jinns/loss/_LossPDE.py,sha256=R9kNQiaFbFx2eCMdjB7ie7UJ9pJW7PmvHijioNgu-bs,49117
|
|
11
|
+
jinns/loss/__init__.py,sha256=Fm4QAHaVmp0CA7HSwb7KUctwdXnNZ9v5KmTqpeoYPaE,669
|
|
12
|
+
jinns/loss/_boundary_conditions.py,sha256=O0D8eWsFfvNNeO20PQ0rUKBI_MDqaBvqChfXaztZoL4,16679
|
|
13
|
+
jinns/loss/_loss_utils.py,sha256=44J-VF6dxT_o5BcNWFOiLpY40c35YnAxxZkoNtdtcZc,13689
|
|
14
|
+
jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
|
|
15
|
+
jinns/loss/_operators.py,sha256=o-Ljp_9_HXB9Mhm-ANh6ouNw4_PsqLJAha7dFDGl_nQ,10781
|
|
16
|
+
jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
|
|
17
|
+
jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
|
|
18
|
+
jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
|
|
19
|
+
jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
|
|
20
|
+
jinns/plot/_plot.py,sha256=ZGIJdGwEd3NlHRTq_2sOfEH_CtOkvPwdgCMct-nQlJE,11691
|
|
21
|
+
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
+
jinns/solver/_rar.py,sha256=BAr0qHNtCQ11XRhxRDed8JZPSWBm4E5FT3VOGQPrm4g,22728
|
|
23
|
+
jinns/solver/_solve.py,sha256=iUkvZ9f-0bxO2eJlNv7wLk5oES69LNbDxzOgBAjbHTg,21106
|
|
24
|
+
jinns/utils/__init__.py,sha256=CNxcb_AYzA2aeDsYwLfIuT1zy8NU7LElU160eHCj2oA,174
|
|
25
|
+
jinns/utils/_containers.py,sha256=_PLIkJHY-jE3nigjuAiYE5USJr7rpXCDsNULRKZtqnU,1213
|
|
26
|
+
jinns/utils/_hyperpinn.py,sha256=7fAy6Xv7CkIeIYCw7u9RB0gh4aITMfxz6JCmJ685KB8,16351
|
|
27
|
+
jinns/utils/_pinn.py,sha256=eiw_i72D7-CxZXXOBpgOgDndwGd5sSnFZNk-Rg-6xy8,12977
|
|
28
|
+
jinns/utils/_save_load.py,sha256=kVDeQrpPf7j7kYUy1gWGCr4_QyaBplRbPSitkWTQnQA,8574
|
|
29
|
+
jinns/utils/_spinn.py,sha256=lesBjXwoj3UUzfecPdh1wBFCY0BlA-q0Wb1pDv0RVYA,9211
|
|
30
|
+
jinns/utils/_types.py,sha256=P_dS0odrHbyalYJ0FjS6q0tkXAGr-4GArsiyJYrB1ho,1878
|
|
31
|
+
jinns/utils/_utils.py,sha256=Ow8xB516E7yHDZatokVJHHFNPDu6fXr9-NmraUXjjyw,1819
|
|
32
|
+
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
33
|
+
jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
|
|
34
|
+
jinns-1.1.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
35
|
+
jinns-1.1.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
36
|
+
jinns-1.1.0.dist-info/METADATA,sha256=3Qk885oguf6S_WPHd9KCFVIWP21nJqX9zFWoS9ZI-T0,2536
|
|
37
|
+
jinns-1.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
38
|
+
jinns-1.1.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
39
|
+
jinns-1.1.0.dist-info/RECORD,,
|
jinns/experimental/_sinuspinn.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
import jax
|
|
2
|
-
import equinox as eqx
|
|
3
|
-
import jax.numpy as jnp
|
|
4
|
-
from jinns.utils._pinn import PINN
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
|
|
8
|
-
out, in_ = weight.shape
|
|
9
|
-
stddev = 1e-2
|
|
10
|
-
return stddev * jax.random.normal(key, shape=(out, in_))
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class _SinusPINN(eqx.Module):
|
|
14
|
-
"""
|
|
15
|
-
A specific PINN whose layers are x_sin2x functions whose frequencies are
|
|
16
|
-
determined by an other network
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
layers_pinn: list
|
|
20
|
-
layers_aux_nn: list
|
|
21
|
-
|
|
22
|
-
def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
|
|
23
|
-
"""
|
|
24
|
-
Parameters
|
|
25
|
-
----------
|
|
26
|
-
key
|
|
27
|
-
A jax random key
|
|
28
|
-
list_layers_pinn
|
|
29
|
-
A list as eqx_list in jinns' PINN utility for the main PINN
|
|
30
|
-
list_layers_aux_nn
|
|
31
|
-
A list as eqx_list in jinns' PINN utility for the network which outputs
|
|
32
|
-
the PINN's activation frequencies
|
|
33
|
-
"""
|
|
34
|
-
self.layers_pinn = []
|
|
35
|
-
for l in list_layers_pinn:
|
|
36
|
-
if len(l) == 1:
|
|
37
|
-
self.layers_pinn.append(l[0])
|
|
38
|
-
else:
|
|
39
|
-
key, subkey = jax.random.split(key, 2)
|
|
40
|
-
self.layers_pinn.append(l[0](*l[1:], key=subkey))
|
|
41
|
-
self.layers_aux_nn = []
|
|
42
|
-
for idx, l in enumerate(list_layers_aux_nn):
|
|
43
|
-
if len(l) == 1:
|
|
44
|
-
self.layers_aux_nn.append(l[0])
|
|
45
|
-
else:
|
|
46
|
-
key, subkey = jax.random.split(key, 2)
|
|
47
|
-
linear_layer = l[0](*l[1:], key=subkey)
|
|
48
|
-
key, subkey = jax.random.split(key, 2)
|
|
49
|
-
linear_layer = eqx.tree_at(
|
|
50
|
-
lambda l: l.weight,
|
|
51
|
-
linear_layer,
|
|
52
|
-
almost_zero_init(linear_layer.weight, subkey),
|
|
53
|
-
)
|
|
54
|
-
if (idx == len(list_layers_aux_nn) - 1) or (
|
|
55
|
-
idx == len(list_layers_aux_nn) - 2
|
|
56
|
-
):
|
|
57
|
-
# for the last layer: almost 0 weights and 0.5 bias
|
|
58
|
-
linear_layer = eqx.tree_at(
|
|
59
|
-
lambda l: l.bias,
|
|
60
|
-
linear_layer,
|
|
61
|
-
0.5 * jnp.ones(linear_layer.bias.shape),
|
|
62
|
-
)
|
|
63
|
-
else:
|
|
64
|
-
# for the other previous layers:
|
|
65
|
-
# almost 0 weight and 0 bias
|
|
66
|
-
linear_layer = eqx.tree_at(
|
|
67
|
-
lambda l: l.bias,
|
|
68
|
-
linear_layer,
|
|
69
|
-
jnp.zeros(linear_layer.bias.shape),
|
|
70
|
-
)
|
|
71
|
-
self.layers_aux_nn.append(linear_layer)
|
|
72
|
-
|
|
73
|
-
## init to zero the frequency network except last biases
|
|
74
|
-
# key, subkey = jax.random.split(key, 2)
|
|
75
|
-
# _pinn = init_linear_weight(_pinn, almost_zero_init, subkey)
|
|
76
|
-
# key, subkey = jax.random.split(key, 2)
|
|
77
|
-
# _pinn = init_linear_bias(_pinn, zero_init, subkey)
|
|
78
|
-
# print(_pinn)
|
|
79
|
-
# print(jax.tree_util.tree_leaves(_pinn, is_leaf=lambda
|
|
80
|
-
# p:not isinstance(p,eqx.nn.Linear))[0].layers_aux_nn[-1].bias)
|
|
81
|
-
# _pinn = eqx.tree_at(lambda p:_pinn.layers_aux_nn[-1].bias, 0.5 *
|
|
82
|
-
# jnp.ones(_pinn.layers_aux_nn[-1].bias.shape))
|
|
83
|
-
# #, is_leaf=lambda
|
|
84
|
-
# #p:not isinstance(p, eqx.nn.Linear))
|
|
85
|
-
|
|
86
|
-
def __call__(self, x):
|
|
87
|
-
x_ = x.copy()
|
|
88
|
-
# forward pass in the network which determines the freq
|
|
89
|
-
for layer in self.layers_aux_nn:
|
|
90
|
-
x_ = layer(x_)
|
|
91
|
-
freq_list = jnp.clip(jnp.square(x_), a_min=1e-4, a_max=5)
|
|
92
|
-
x_ = x.copy()
|
|
93
|
-
# forward pass through the actual PINN
|
|
94
|
-
for idx, layer in enumerate(self.layers_pinn):
|
|
95
|
-
if idx % 2 == 0:
|
|
96
|
-
# Currently: every two layer we have an activation
|
|
97
|
-
# requiring a frequency
|
|
98
|
-
x_ = layer(x_)
|
|
99
|
-
else:
|
|
100
|
-
x_ = layer(x_, freq_list[(idx - 1) // 2])
|
|
101
|
-
return x_
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
class sinusPINN(PINN):
|
|
105
|
-
"""
|
|
106
|
-
MUST inherit from PINN to pass all the checks
|
|
107
|
-
|
|
108
|
-
HOWEVER we dot not bother with reimplementing anything
|
|
109
|
-
"""
|
|
110
|
-
|
|
111
|
-
def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
|
|
112
|
-
super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
|
|
113
|
-
key, subkey = jax.random.split(key, 2)
|
|
114
|
-
_pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)
|
|
115
|
-
|
|
116
|
-
self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
|
|
117
|
-
|
|
118
|
-
def init_params(self):
|
|
119
|
-
return self.params
|
|
120
|
-
|
|
121
|
-
def __call__(self, x, params):
|
|
122
|
-
try:
|
|
123
|
-
model = eqx.combine(params["nn_params"], self.static)
|
|
124
|
-
except (KeyError, TypeError) as e: # give more flexibility
|
|
125
|
-
model = eqx.combine(params, self.static)
|
|
126
|
-
res = model(x)
|
|
127
|
-
if not res.shape:
|
|
128
|
-
return jnp.expand_dims(res, axis=-1)
|
|
129
|
-
return res
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
|
|
133
|
-
""" """
|
|
134
|
-
u = sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
|
|
135
|
-
return u
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
import jax
|
|
2
|
-
import equinox as eqx
|
|
3
|
-
import jax.numpy as jnp
|
|
4
|
-
from jinns.utils._pinn import PINN
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
|
|
8
|
-
out, in_ = weight.shape
|
|
9
|
-
stddev = 1e-2
|
|
10
|
-
return stddev * jax.random.normal(key, shape=(out, in_))
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class _SpectralPINN(eqx.Module):
|
|
14
|
-
"""
|
|
15
|
-
A specific PINN whose acrhitecture is similar to spectral method for simulation of a spatial field
|
|
16
|
-
(Chilès and Delfiner, 2012) - a single layer with cos() activation function and sum for last layer
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
layers_pinn: list
|
|
20
|
-
nbands: int
|
|
21
|
-
|
|
22
|
-
def __init__(self, key, list_layers_pinn, nbands):
|
|
23
|
-
"""
|
|
24
|
-
Parameters
|
|
25
|
-
----------
|
|
26
|
-
key
|
|
27
|
-
A jax random key
|
|
28
|
-
list_layers_pinn
|
|
29
|
-
A list as eqx_list in jinns' PINN utility for the main PINN
|
|
30
|
-
nbands
|
|
31
|
-
Number of spectral bands (i.e., neurones in the single layer of the PINN)
|
|
32
|
-
"""
|
|
33
|
-
self.nbands = nbands
|
|
34
|
-
self.layers_pinn = []
|
|
35
|
-
for l in list_layers_pinn:
|
|
36
|
-
if len(l) == 1:
|
|
37
|
-
self.layers_pinn.append(l[0])
|
|
38
|
-
else:
|
|
39
|
-
key, subkey = jax.random.split(key, 2)
|
|
40
|
-
self.layers_pinn.append(l[0](*l[1:], key=subkey))
|
|
41
|
-
|
|
42
|
-
def __call__(self, x):
|
|
43
|
-
# forward pass through the actual PINN
|
|
44
|
-
for layer in self.layers_pinn:
|
|
45
|
-
x = layer(x)
|
|
46
|
-
|
|
47
|
-
return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class spectralPINN(PINN):
|
|
51
|
-
"""
|
|
52
|
-
MUST inherit from PINN to pass all the checks
|
|
53
|
-
|
|
54
|
-
HOWEVER we dot not bother with reimplementing anything
|
|
55
|
-
"""
|
|
56
|
-
|
|
57
|
-
def __init__(self, key, list_layers_pinn, nbands):
|
|
58
|
-
super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
|
|
59
|
-
key, subkey = jax.random.split(key, 2)
|
|
60
|
-
_pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)
|
|
61
|
-
|
|
62
|
-
self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
|
|
63
|
-
|
|
64
|
-
def init_params(self):
|
|
65
|
-
return self.params
|
|
66
|
-
|
|
67
|
-
def __call__(self, x, params):
|
|
68
|
-
try:
|
|
69
|
-
model = eqx.combine(params["nn_params"], self.static)
|
|
70
|
-
except (KeyError, TypeError) as e: # give more flexibility
|
|
71
|
-
model = eqx.combine(params, self.static)
|
|
72
|
-
# model = eqx.tree_at(lambda m:
|
|
73
|
-
# m.layers_pinn[0].bias,
|
|
74
|
-
# model,
|
|
75
|
-
# model.layers_pinn[0].bias % (2 *
|
|
76
|
-
# jnp.pi)
|
|
77
|
-
# )
|
|
78
|
-
res = model(x)
|
|
79
|
-
if not res.shape:
|
|
80
|
-
return jnp.expand_dims(res, axis=-1)
|
|
81
|
-
return res
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def create_spectralPINN(key, list_layers_pinn, nbands):
|
|
85
|
-
""" """
|
|
86
|
-
u = spectralPINN(key, list_layers_pinn, nbands)
|
|
87
|
-
return u
|