jinns 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/validation/_validation.py (new file, per the RECORD below)
@@ -0,0 +1,214 @@
+ """
+ Implements some validation functions and their associated hyperparameters
+ """
+
+ import copy
+ import abc
+ from typing import Union
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, Bool, PyTree, Int
+ import jinns
+ import jinns.data
+ from jinns.loss import LossODE, LossPDENonStatio, LossPDEStatio
+ from jinns.data._DataGenerators import (
+     DataGeneratorODE,
+     CubicMeshPDEStatio,
+     CubicMeshPDENonStatio,
+     DataGeneratorParameter,
+     DataGeneratorObservations,
+     DataGeneratorObservationsMultiPINNs,
+     append_obs_batch,
+     append_param_batch,
+ )
+ import jinns.loss
+
+ # Using eqx Module for the DataClass + Pytree inheritance
+ # Abstract class and abstract/final pattern is used
+ # see : https://docs.kidger.site/equinox/pattern/
+
+
+ class AbstractValidationModule(eqx.Module):
+     """Abstract class defining the interface of any validation module. It must
+     1. have a ``call_every`` attribute;
+     2. implement a ``__call__`` returning ``(AbstractValidationModule, Bool, Array)``.
+     """
+
+     call_every: eqx.AbstractVar[Int]  # Mandatory for all validation modules:
+     # the validation step is performed every ``call_every`` iterations.
+
+     @abc.abstractmethod
+     def __call__(
+         self, params: PyTree
+     ) -> tuple["AbstractValidationModule", Bool, Array]:
+         raise NotImplementedError
+
+
+ class ValidationLoss(AbstractValidationModule):
+     """
+     Implementation of a vanilla validation module returning the PINN loss
+     on a validation set of collocation points. It can serve as a baseline
+     for more sophisticated validation strategies.
+     """
+
+     loss: Union[callable, LossODE, LossPDEStatio, LossPDENonStatio] = eqx.field(
+         converter=copy.deepcopy
+     )
+     validation_data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
+     validation_param_data: Union[DataGeneratorParameter, None] = None
+     validation_obs_data: Union[
+         DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
+     ] = None
+     call_every: Int = 250  # concrete typing
+     early_stopping: Bool = True  # globally control if early stopping happens
+
+     patience: Int = 10
+     best_val_loss: Array = eqx.field(
+         converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf)
+     )
+
+     counter: Array = eqx.field(
+         converter=jnp.asarray, default_factory=lambda: jnp.array(0.0)
+     )
+
+     def __call__(self, params) -> tuple["ValidationLoss", Bool, Array]:
+         # the get_batch() calls below mutate the data generators in place
+         val_batch = self.validation_data.get_batch()
+         if self.validation_param_data is not None:
+             val_batch = append_param_batch(
+                 val_batch, self.validation_param_data.get_batch()
+             )
+         if self.validation_obs_data is not None:
+             val_batch = append_obs_batch(
+                 val_batch, self.validation_obs_data.get_batch()
+             )
+
+         validation_loss_value, _ = self.loss(params, val_batch)
+         (counter, best_val_loss) = jax.lax.cond(
+             validation_loss_value < self.best_val_loss,
+             lambda _: (jnp.array(0.0), validation_loss_value),  # reset
+             lambda operands: (operands[0] + 1, operands[1]),  # increment
+             (self.counter, self.best_val_loss),
+         )
+
+         # use eqx.tree_at to update attributes
+         # (https://github.com/patrick-kidger/equinox/issues/396)
+         new = eqx.tree_at(lambda t: t.counter, self, counter)
+         new = eqx.tree_at(lambda t: t.best_val_loss, new, best_val_loss)
+
+         bool_early_stopping = jax.lax.cond(
+             jnp.logical_and(
+                 jnp.array(self.counter == self.patience),
+                 jnp.array(self.early_stopping),
+             ),
+             lambda _: True,
+             lambda _: False,
+             None,
+         )
+         # return `new` because an eqx.Module cannot be modified in place
+         return (new, bool_early_stopping, validation_loss_value)
+
+
+ if __name__ == "__main__":
+     import jax
+     import jax.numpy as jnp
+     import jax.random as random
+     from jinns.loss import BurgerEquation
+
+     key = random.PRNGKey(1)
+     key, subkey = random.split(key)
+
+     n = 50
+     nb = 2 * 2 * 10
+     nt = 10
+     omega_batch_size = 10
+     omega_border_batch_size = 10
+     temporal_batch_size = 4
+     dim = 1
+     xmin = 0
+     xmax = 1
+     tmin, tmax = 0, 1
+     method = "uniform"
+
+     val_data = jinns.data.CubicMeshPDENonStatio(
+         subkey,
+         n,
+         nb,
+         nt,
+         omega_batch_size,
+         omega_border_batch_size,
+         temporal_batch_size,
+         dim,
+         (xmin,),
+         (xmax,),
+         tmin,
+         tmax,
+         method,
+     )
+
+     eqx_list = [
+         [eqx.nn.Linear, 2, 50],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 50, 50],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 50, 50],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 50, 50],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 50, 50],
+         [jax.nn.tanh],
+         [eqx.nn.Linear, 50, 2],
+     ]
+
+     key, subkey = random.split(key)
+     u = jinns.utils.create_PINN(
+         subkey, eqx_list, "nonstatio_PDE", 2, slice_solution=jnp.s_[:1]
+     )
+     init_nn_params = u.init_params()
+
+     dyn_loss = BurgerEquation()
+     loss_weights = {"dyn_loss": 1, "boundary_loss": 10, "observations": 10}
+
+     key, subkey = random.split(key)
+     loss = jinns.loss.LossPDENonStatio(
+         u=u,
+         loss_weights=loss_weights,
+         dynamic_loss=dyn_loss,
+         norm_key=subkey,
+         norm_borders=(-1, 1),
+     )
+     print(id(loss))
+     validation = ValidationLoss(
+         call_every=250,
+         early_stopping=True,
+         patience=1000,
+         loss=loss,
+         validation_data=val_data,
+         validation_param_data=None,
+     )
+     print(validation.loss is not loss)  # should be True (deepcopy)
+
+     init_params = {"nn_params": init_nn_params, "eq_params": {"nu": 1.0}}
+
+     print(validation.loss is loss)
+     loss.evaluate(init_params, val_data.get_batch())
+     print(loss.norm_key)
+     print("Call validation once")
+     validation, _, _ = validation(init_params)
+     print(validation.loss is loss)
+     print(validation.loss.norm_key == loss.norm_key)
+     print("Create a new pytree from validation and call it once")
+     new_val = eqx.tree_at(lambda t: t.counter, validation, jnp.array(3.0))
+     print(validation.loss is new_val.loss)  # FALSE
+     # test whether attributes have been modified
+     new_val, _, _ = new_val(init_params)
+     print(f"{new_val.loss is loss=}")
+     print(f"{loss.norm_key=}")
+     print(f"{validation.loss.norm_key=}")
+     print(f"{new_val.loss.norm_key=}")
+     print(f"{new_val.loss.norm_key == loss.norm_key=}")
+     print(f"{new_val.loss.norm_key == validation.loss.norm_key=}")
+     print(new_val.counter)
+     print(validation.counter)
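
The abstract/final pattern introduced above makes it possible to plug custom validation criteria into training. The sketch below is not part of the package: it relies only on the interface shown in the diff (a concrete ``call_every`` attribute and a ``__call__`` returning ``(module, Bool, Array)``). The class name ``ValidationL2Error``, the fields ``u``, ``x_val`` and ``u_ref``, the assumed PINN call signature ``u(x, params)``, and the private import path (taken from the RECORD below) are illustrative assumptions, not documented jinns API.

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Int, PyTree

# private path taken from the RECORD below; a public re-export is not assumed
from jinns.validation._validation import AbstractValidationModule


class ValidationL2Error(AbstractValidationModule):
    """Hypothetical validation module: relative L2 error against reference values."""

    u: callable            # the PINN, assumed to be callable as u(x, params)
    x_val: Array           # fixed validation collocation points
    u_ref: Array           # reference solution evaluated at x_val (assumption)
    call_every: Int = 100  # perform the check every 100 optimization iterations

    def __call__(self, params: PyTree) -> tuple["ValidationL2Error", Bool, Array]:
        # evaluate the PINN on the validation points and compare to the reference
        u_pred = jax.vmap(lambda x: self.u(x, params))(self.x_val)
        err = jnp.linalg.norm(u_pred - self.u_ref) / jnp.linalg.norm(self.u_ref)
        # no early-stopping logic in this sketch: the module is returned
        # unchanged and the stop flag is always False
        return self, jnp.array(False), err

Because the module is itself a pytree, any state it carries across validation calls would be updated with eqx.tree_at and returned, exactly as ValidationLoss does with counter and best_val_loss.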
jinns-0.8.8.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: jinns
- Version: 0.8.6
+ Version: 0.8.8
  Summary: Physics Informed Neural Network with JAX
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
jinns-0.8.8.dist-info/RECORD
@@ -1,13 +1,15 @@
- jinns/__init__.py,sha256=Nw5pdlmDhJwco3bXX3YttkeCF8czX_6m0poh8vu0lDQ,113
+ jinns/__init__.py,sha256=T2XlmLbYqcXTumPJL00cJ80W98We5LH8Yg_Lss_exl4,139
  jinns/data/_DataGenerators.py,sha256=N4-U4z3MG46UIzHCbKScv9Z7AN40w1wlLY_VsVNj2sI,62293
  jinns/data/__init__.py,sha256=yBOmoavSD-cABp4XcjQY1zsEVO0mDyIhi2MJ5WNp0l8,326
- jinns/data/_display.py,sha256=6renz4H7kHktutmLY7HM6PmxYH7cBfGHpC7GQa1Fnlk,7778
- jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
+ jinns/data/_display.py,sha256=vlqggDCgVMEwdGBtjVmZaTQORU6imSfDkssn2XCtITI,10392
+ jinns/experimental/__init__.py,sha256=qWbhC7Z8UgLWy0t-zU7RYze6v13-FngiCYXu-2bRVFQ,296
  jinns/experimental/_diffrax_solver.py,sha256=sLT22byqh-6015_fhe1xtMWlFOYcCjzYKET4sLhA9R4,6818
+ jinns/experimental/_sinuspinn.py,sha256=hxSzscwMV2LayWOqenIlT1zqEVVrE5Y8CKf7bHX5XFQ,5016
+ jinns/experimental/_spectralpinn.py,sha256=-4795pa7AYtRNSE-ugan3gHh64mtu2VdrRG5AS_J9Eg,2654
  jinns/loss/_DynamicLoss.py,sha256=L4CVmmF0rTPbHntgqsLLHlnrlQgLHsetUocpJm7ZYag,27461
  jinns/loss/_DynamicLossAbstract.py,sha256=kTQlhLx7SBuH5dIDmYaE79sVHUZt1nUFa8LxPU5IHhM,8504
- jinns/loss/_LossODE.py,sha256=sxpgiDR6mfoREuc-qe0AkirOe5K_5oblaYCnodTNxoI,21912
- jinns/loss/_LossPDE.py,sha256=_yX3R-FrAScTn9_QfVC8PfDYRE4UQ5lnzITUYgNFitA,61766
+ jinns/loss/_LossODE.py,sha256=b9doBHoQwYvlgpqzrNO4dOaTN87LRvjHtHbz9bMoH7E,22119
+ jinns/loss/_LossPDE.py,sha256=purAEtc0e71kv9XnZUT-a7MrkDAkM_3tTI4xJPu6fH4,61629
  jinns/loss/_Losses.py,sha256=XOL3MFiKEd3ndsc78Qnpi1vbgR0B2HaAWOGGW2meDM8,11190
  jinns/loss/__init__.py,sha256=pFNYUxns-NPXBFdqrEVSiXkQLfCtKw-t2trlhvLzpYE,355
  jinns/loss/_boundary_conditions.py,sha256=YfSnLZ25hXqQ5KWAuxOrWSKkf_oBqAc9GQV4z7MjWyQ,17434
@@ -15,8 +17,9 @@ jinns/loss/_operators.py,sha256=zDGJqYqeYH7xd-4dtGX9PS-pf0uSOpUUXGo5SVjIJ4o,1106
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  jinns/solver/_rar.py,sha256=K-0y1-ofOAo1n_Ea3QShSGCGKVYTwiaE_Bz9-DZMJm8,14525
  jinns/solver/_seq2seq.py,sha256=FL-42hTgmVl7O3hHh1ccFVw2bT8bW82hvlDRz971Chk,5620
- jinns/solver/_solve.py,sha256=r4jn6hx7_t-Y2rBWA2npUmWWnDg4iRbgYBHZDNn9tmY,13745
+ jinns/solver/_solve.py,sha256=mGi0zaT_fK_QpBjTxof5Ix4mmfmnPi66CNJ3GQFZuo4,19099
  jinns/utils/__init__.py,sha256=44ms5UR6vMw3Nf6u4RCAzPFs4fom_YbBnH9mfne8m6k,313
+ jinns/utils/_containers.py,sha256=eYD277fO7X4EfX7PUFCCl69r3JBfh1sCfq8LkL5gd6o,1495
  jinns/utils/_hyperpinn.py,sha256=93hbiATdp5W4l1cu9Oe6O2c45o-ZF_z2u6FzNLyjnm4,10878
  jinns/utils/_optim.py,sha256=550kxH75TL30o1iKx1swJyP0KqyUPsJ7-imL1w65Qd0,4444
  jinns/utils/_pinn.py,sha256=mhA4-3PazyQTbWIx9oLaNwL0QDe8ZIBhbiy5J3kwa4I,9471
@@ -24,8 +27,10 @@ jinns/utils/_save_load.py,sha256=qgZ23nUcB8-B5IZ2guuUWC4M7r5Lxd_Ms3staScdyJo,566
  jinns/utils/_spinn.py,sha256=SzOUt1KHtB9QOpghpvitnXN-KEqXUXbvabC5k0TnKEo,7793
  jinns/utils/_utils.py,sha256=8dgvWXX9NT7_7-zltWp0C9tG45ZFNwXxueyxPBb4hjo,6740
  jinns/utils/_utils_uspinn.py,sha256=qcKcOw3zrwWSQyGVj6fD8c9GinHt_U6JWN_k0auTtXM,26039
- jinns-0.8.6.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
- jinns-0.8.6.dist-info/METADATA,sha256=3Ml6PCA-569v9-1FgyPDySX09RQas0zPOVEV_gqy9lk,2482
- jinns-0.8.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- jinns-0.8.6.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
- jinns-0.8.6.dist-info/RECORD,,
+ jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
+ jinns/validation/_validation.py,sha256=KfetbzB0xTNdBcYLwFWjEtP63Tf9wJirlhgqLTJDyy4,6761
+ jinns-0.8.8.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+ jinns-0.8.8.dist-info/METADATA,sha256=oTs2EJMu4Bwo2n9DLsAPSU5edpbgPtwhNXBuW8YjpOc,2482
+ jinns-0.8.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ jinns-0.8.8.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+ jinns-0.8.8.dist-info/RECORD,,