jinns 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +1 -0
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- jinns/loss/_LossODE.py +6 -0
- jinns/loss/_LossPDE.py +18 -18
- jinns/solver/_solve.py +264 -121
- jinns/utils/_containers.py +57 -0
- jinns/validation/__init__.py +1 -0
- jinns/validation/_validation.py +214 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/METADATA +1 -1
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/RECORD +16 -11
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/LICENSE +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/WHEEL +0 -0
- {jinns-0.8.6.dist-info → jinns-0.8.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements some validation functions and their associated hyperparameter
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import abc
|
|
7
|
+
from typing import Union
|
|
8
|
+
import equinox as eqx
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jaxtyping import Array, Bool, PyTree, Int
|
|
12
|
+
import jinns
|
|
13
|
+
import jinns.data
|
|
14
|
+
from jinns.loss import LossODE, LossPDENonStatio, LossPDEStatio
|
|
15
|
+
from jinns.data._DataGenerators import (
|
|
16
|
+
DataGeneratorODE,
|
|
17
|
+
CubicMeshPDEStatio,
|
|
18
|
+
CubicMeshPDENonStatio,
|
|
19
|
+
DataGeneratorParameter,
|
|
20
|
+
DataGeneratorObservations,
|
|
21
|
+
DataGeneratorObservationsMultiPINNs,
|
|
22
|
+
append_obs_batch,
|
|
23
|
+
append_param_batch,
|
|
24
|
+
)
|
|
25
|
+
import jinns.loss
|
|
26
|
+
|
|
27
|
+
# Using eqx Module for the DataClass + Pytree inheritance
|
|
28
|
+
# Abstract class and abstract/final pattern is used
|
|
29
|
+
# see : https://docs.kidger.site/equinox/pattern/
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AbstractValidationModule(eqx.Module):
    """Interface shared by every validation module.

    A concrete subclass is expected to:

    1. provide a ``call_every`` attribute, the number of iterations between
       two validation steps;
    2. implement ``__call__(params)`` returning a triple
       ``(AbstractValidationModule, Bool, Array)``.
    """

    # Required on every concrete validation module: the validation step is
    # performed once every ``call_every`` iterations.
    call_every: eqx.AbstractVar[Int]

    @abc.abstractmethod
    def __call__(
        self, params: PyTree
    ) -> tuple["AbstractValidationModule", Bool, Array]:
        """Run one validation step for ``params``.

        Concrete implementations (e.g. ``ValidationLoss``) return the updated
        module, a boolean flag and an array value.
        """
        raise NotImplementedError()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ValidationLoss(AbstractValidationModule):
    """
    Implementation of a vanilla validation module returning the PINN loss
    on a validation set of collocation points. This can be used as a baseline
    for more complicated validation strategies.

    Attributes
    ----------
    loss
        Loss evaluated on the validation batch. It is deep-copied at
        construction (``converter=copy.deepcopy``) so the validation module
        does not share the loss object passed by the caller.
    validation_data
        Data generator providing the validation collocation points.
    validation_param_data
        Optional parameter generator; its batch is appended to the
        validation batch with ``append_param_batch``.
    validation_obs_data
        Optional observation generator; its batch is appended to the
        validation batch with ``append_obs_batch``.
    call_every
        The validation step is performed every ``call_every`` iterations.
    early_stopping
        Globally controls whether an early-stopping signal can be emitted.
    patience
        Number of consecutive validation calls without improvement of the
        validation loss after which early stopping is signalled.
    best_val_loss
        Best validation loss seen so far (initialised to ``+inf``).
    counter
        Number of consecutive validation calls without improvement.
    """

    loss: Union[callable, LossODE, LossPDEStatio, LossPDENonStatio] = eqx.field(
        converter=copy.deepcopy
    )
    validation_data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
    validation_param_data: Union[DataGeneratorParameter, None] = None
    validation_obs_data: Union[
        DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
    ] = None
    call_every: Int = 250  # concrete typing
    early_stopping: Bool = True  # globally control if early stopping happens

    patience: Int = 10
    best_val_loss: Array = eqx.field(
        converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf)
    )

    counter: Array = eqx.field(
        converter=jnp.asarray, default_factory=lambda: jnp.array(0.0)
    )

    def __call__(self, params) -> tuple["ValidationLoss", Bool, Array]:
        """Evaluate the loss on a fresh validation batch and update the
        early-stopping state.

        Parameters
        ----------
        params
            Parameters passed to ``self.loss``.

        Returns
        -------
        tuple
            ``(new, bool_early_stopping, validation_loss_value)`` where
            ``new`` is a copy of ``self`` with updated ``counter`` and
            ``best_val_loss`` (eqx.Module instances are not modified in
            place), ``bool_early_stopping`` is True when ``early_stopping``
            is enabled and the updated counter has reached ``patience``, and
            ``validation_loss_value`` is the loss on the validation batch.
        """
        # NOTE(review): get_batch() is called directly on the data
        # generators, which presumably advances their internal state in
        # place (the original comment mentions in-place mutation) — confirm.
        val_batch = self.validation_data.get_batch()
        if self.validation_param_data is not None:
            val_batch = append_param_batch(
                val_batch, self.validation_param_data.get_batch()
            )
        if self.validation_obs_data is not None:
            val_batch = append_obs_batch(
                val_batch, self.validation_obs_data.get_batch()
            )

        validation_loss_value, _ = self.loss(params, val_batch)
        (counter, best_val_loss) = jax.lax.cond(
            validation_loss_value < self.best_val_loss,
            lambda _: (jnp.array(0.0), validation_loss_value),  # reset
            lambda operands: (operands[0] + 1, operands[1]),  # increment
            (self.counter, self.best_val_loss),
        )

        # use eqx.tree_at to update attributes
        # (https://github.com/patrick-kidger/equinox/issues/396)
        new = eqx.tree_at(lambda t: t.counter, self, counter)
        new = eqx.tree_at(lambda t: t.best_val_loss, new, best_val_loss)

        # Decide on the *updated* counter: using the stale ``self.counter``
        # delayed the signal by one call and could stop training right after
        # an improvement had reset the counter. ``>=`` keeps signalling if the
        # caller continues past ``patience``.
        bool_early_stopping = jnp.logical_and(
            counter >= self.patience,
            jnp.asarray(self.early_stopping),
        )
        # return `new` cause no in-place modification of the eqx.Module
        return (new, bool_early_stopping, validation_loss_value)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
if __name__ == "__main__":
    # Manual smoke test of ValidationLoss on a 1D non-stationary Burger
    # problem: checks that the loss is deep-copied and that calling the
    # validation module returns an updated copy instead of mutating in place.
    # ``jax`` and ``jnp`` are already imported at module level.
    import jax.random as random
    from jinns.loss import BurgerEquation

    key = random.PRNGKey(1)
    key, subkey = random.split(key)

    n = 50
    nb = 2 * 2 * 10
    nt = 10
    omega_batch_size = 10
    omega_border_batch_size = 10
    temporal_batch_size = 4
    dim = 1
    xmin = 0
    xmax = 1
    tmin, tmax = 0, 1
    method = "uniform"

    val_data = jinns.data.CubicMeshPDENonStatio(
        subkey,
        n,
        nb,
        nt,
        omega_batch_size,
        omega_border_batch_size,
        temporal_batch_size,
        dim,
        (xmin,),
        (xmax,),
        tmin,
        tmax,
        method,
    )

    eqx_list = [
        [eqx.nn.Linear, 2, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 50],
        [jax.nn.tanh],
        [eqx.nn.Linear, 50, 2],
    ]

    key, subkey = random.split(key)
    u = jinns.utils.create_PINN(
        subkey, eqx_list, "nonstatio_PDE", 2, slice_solution=jnp.s_[:1]
    )
    init_nn_params = u.init_params()

    dyn_loss = BurgerEquation()
    loss_weights = {"dyn_loss": 1, "boundary_loss": 10, "observations": 10}

    key, subkey = random.split(key)
    loss = jinns.loss.LossPDENonStatio(
        u=u,
        loss_weights=loss_weights,
        dynamic_loss=dyn_loss,
        norm_key=subkey,
        norm_borders=(-1, 1),
    )
    print(id(loss))
    validation = ValidationLoss(
        call_every=250,
        early_stopping=True,
        patience=1000,
        loss=loss,
        validation_data=val_data,
        validation_param_data=None,
    )
    # Compare the objects themselves: `id(a) is not id(b)` compares two
    # freshly created ints and says nothing about the deepcopy.
    print(validation.loss is not loss)  # should be True (deepcopy)

    init_params = {"nn_params": init_nn_params, "eq_params": {"nu": 1.0}}

    print(validation.loss is loss)
    loss.evaluate(init_params, val_data.get_batch())
    print(loss.norm_key)
    print("Call validation once")
    validation, _, _ = validation(init_params)
    print(validation.loss is loss)
    print(validation.loss.norm_key == loss.norm_key)
    print("Create new pytree from validation and call it once")
    new_val = eqx.tree_at(lambda t: t.counter, validation, jnp.array(3.0))
    print(validation.loss is new_val.loss)  # FALSE
    # test if attribute have been modified
    new_val, _, _ = new_val(init_params)
    print(f"{new_val.loss is loss=}")
    print(f"{loss.norm_key=}")
    print(f"{validation.loss.norm_key=}")
    print(f"{new_val.loss.norm_key=}")
    print(f"{new_val.loss.norm_key == loss.norm_key=}")
    print(f"{new_val.loss.norm_key == validation.loss.norm_key=}")
    print(new_val.counter)
    print(validation.counter)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.6
|
|
3
|
+
Version: 0.8.8
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -1,13 +1,15 @@
|
|
|
1
|
-
jinns/__init__.py,sha256=
|
|
1
|
+
jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
|
|
2
2
|
jinns/data/_DataGenerators.py,sha256=N4-U4z3MG46UIzHCbKScv9Z7AN40w1wlLY_VsVNj2sI,62293
|
|
3
3
|
jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
|
|
4
|
-
jinns/data/_display.py,sha256=
|
|
5
|
-
jinns/experimental/__init__.py,sha256=
|
|
4
|
+
jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
|
|
5
|
+
jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
|
|
6
6
|
jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
|
|
7
|
+
jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5XFQ,5016
|
|
8
|
+
jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
|
|
7
9
|
jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
|
|
8
10
|
jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
|
|
9
|
-
jinns/loss/_LossODE.py,sha256=
|
|
10
|
-
jinns/loss/_LossPDE.py,sha256=
|
|
11
|
+
jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
|
|
12
|
+
jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
|
|
11
13
|
jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
|
|
12
14
|
jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
|
|
13
15
|
jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
|
|
@@ -15,8 +17,9 @@ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,1106
|
|
|
15
17
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
18
|
jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
|
|
17
19
|
jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
|
|
18
|
-
jinns/solver/_solve.py,sha256=
|
|
20
|
+
jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
|
|
19
21
|
jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
|
|
22
|
+
jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
|
|
20
23
|
jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
|
|
21
24
|
jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
|
|
22
25
|
jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
|
|
@@ -24,8 +27,10 @@ jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,566
|
|
|
24
27
|
jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
|
|
25
28
|
jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
|
|
26
29
|
jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
|
|
27
|
-
jinns
|
|
28
|
-
jinns
|
|
29
|
-
jinns-0.8.
|
|
30
|
-
jinns-0.8.
|
|
31
|
-
jinns-0.8.
|
|
30
|
+
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
31
|
+
jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
|
|
32
|
+
jinns-0.8.8.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
33
|
+
jinns-0.8.8.dist-info/METADATA,sha256=oTs2EJMu4Bwo2n9DLsAPSU5edpbgPtwhNXBuW8YjpOc,2482
|
|
34
|
+
jinns-0.8.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
35
|
+
jinns-0.8.8.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
36
|
+
jinns-0.8.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|