jinns: 0.9.0-py3-none-any.whl → 1.1.0-py3-none-any.whl

Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0

jinns/validation/_validation.py
@@ -2,27 +2,24 @@
 Implements some validation functions and their associated hyperparameter
 """
 
-import copy
+from __future__ import (
+    annotations,
+)  # https://docs.python.org/3/library/typing.html#constant
+
 import abc
-from typing import Union
+from typing import TYPE_CHECKING, Union
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import Array, Bool, PyTree, Int
-import jinns
-import jinns.data
-from jinns.loss import LossODE, LossPDENonStatio, LossPDEStatio
+from jaxtyping import Array
+
 from jinns.data._DataGenerators import (
-    DataGeneratorODE,
-    CubicMeshPDEStatio,
-    CubicMeshPDENonStatio,
-    DataGeneratorParameter,
-    DataGeneratorObservations,
-    DataGeneratorObservationsMultiPINNs,
     append_obs_batch,
     append_param_batch,
 )
-import jinns.loss
+
+if TYPE_CHECKING:
+    from jinns.utils._types import *
 
 # Using eqx Module for the DataClass + Pytree inheritance
 # Abstract class and abstract/final pattern is used
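
Note on the new import block: with `from __future__ import annotations`, annotations are no longer evaluated at runtime, so names that only appear in type hints (here the aliases from jinns.utils._types) can be imported under TYPE_CHECKING without triggering circular imports. A minimal sketch of the pattern, with hypothetical module names:

    # a_module.py -- hypothetical illustration of the pattern, not jinns code
    from __future__ import annotations  # hints become strings, evaluated lazily

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen by type checkers only, never executed at runtime,
        # so a_module and b_module cannot form a runtime import cycle
        from b_module import SomeType

    def f(x: SomeType) -> SomeType:  # annotation is not evaluated at runtime
        return x
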
@@ -35,14 +32,16 @@ class AbstractValidationModule(eqx.Module):
     2. implement a ``__call__`` returning ``(AbstractValidationModule, Bool, Array)``
     """
 
-    call_every: eqx.AbstractVar[Int]  # Mandatory for all validation step,
+    call_every: eqx.AbstractVar[int] = eqx.field(
+        kw_only=True
+    )  # Mandatory for all validation step,
     # it tells that the validation step is performed every call_every
     # iterations.
 
     @abc.abstractmethod
     def __call__(
-        self, params: PyTree
-    ) -> tuple["AbstractValidationModule", Bool, Array, Bool]:
+        self, params: Params | ParamsDict
+    ) -> tuple["AbstractValidationModule", bool, Array, bool]:
         raise NotImplementedError
 
 
@@ -53,37 +52,44 @@ class ValidationLoss(AbstractValidationModule):
     for more complicated validation strategy.
     """
 
-    loss: Union[callable, LossODE, LossPDEStatio, LossPDENonStatio] = eqx.field(
-        converter=copy.deepcopy
+    loss: AnyLoss = eqx.field(kw_only=True)  # NOTE that
+    # there used to be a deepcopy here which has been suppressed. 1) No need
+    # because loss are now eqx.Module (immutable) so no risk of in-place
+    # modification. 2) deepcopy is buggy with equinox, InitVar etc. (see issue
+    # #857 on equinox github)
+    validation_data: Union[AnyDataGenerator] = eqx.field(kw_only=True)
+    validation_param_data: Union[DataGeneratorParameter, None] = eqx.field(
+        kw_only=True, default=None
     )
-    validation_data: Union[DataGeneratorODE, CubicMeshPDEStatio, CubicMeshPDENonStatio]
-    validation_param_data: Union[DataGeneratorParameter, None] = None
     validation_obs_data: Union[
         DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
-    ] = None
-    call_every: Int = 250  # concrete typing
-    early_stopping: Bool = True  # globally control if early stopping happens
+    ] = eqx.field(kw_only=True, default=None)
+    call_every: int = eqx.field(kw_only=True, default=250)  # concrete typing
+    early_stopping: bool = eqx.field(
+        kw_only=True, default=True
+    )  # globally control if early stopping happens
 
-    patience: Union[Int] = 10
+    patience: Union[int] = eqx.field(kw_only=True, default=10)
     best_val_loss: Array = eqx.field(
-        converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf)
+        converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf), kw_only=True
    )
 
     counter: Array = eqx.field(
-        converter=jnp.asarray, default_factory=lambda: jnp.array(0.0)
+        converter=jnp.asarray, default_factory=lambda: jnp.array(0.0), kw_only=True
    )
 
-    def __call__(self, params) -> tuple["ValidationLoss", Bool, Array]:
+    def __call__(
+        self, params: AnyParams
+    ) -> tuple["ValidationLoss", bool, float, AnyParams]:
         # do in-place mutation
-        val_batch = self.validation_data.get_batch()
+
+        validation_data, val_batch = self.validation_data.get_batch()
         if self.validation_param_data is not None:
-            val_batch = append_param_batch(
-                val_batch, self.validation_param_data.get_batch()
-            )
+            validation_param_data, param_batch = self.validation_param_data.get_batch()
+            val_batch = append_param_batch(val_batch, param_batch)
         if self.validation_obs_data is not None:
-            val_batch = append_obs_batch(
-                val_batch, self.validation_obs_data.get_batch()
-            )
+            validation_obs_data, obs_batch = self.validation_obs_data.get_batch()
+            val_batch = append_obs_batch(val_batch, obs_batch)
 
         validation_loss_value, _ = self.loss(params, val_batch)
         (counter, best_val_loss, update_best_params) = jax.lax.cond(
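
Note: the behavioural change in this hunk is that get_batch() no longer mutates the data generator in place; it now returns the updated generator together with the batch, and the caller is responsible for keeping the returned generator. A toy analogue of the contract (not jinns' actual generator, which carries more state):

    import equinox as eqx
    import jax

    class ToyGenerator(eqx.Module):
        # stand-in for a jinns data generator: state is threaded functionally
        key: jax.Array

        def get_batch(self) -> tuple["ToyGenerator", jax.Array]:
            new_key, subkey = jax.random.split(self.key)
            batch = jax.random.uniform(subkey, (8,))
            # return a NEW generator carrying the advanced key, plus the batch
            return eqx.tree_at(lambda g: g.key, self, new_key), batch

    gen = ToyGenerator(key=jax.random.PRNGKey(0))
    gen, batch1 = gen.get_batch()  # caller must keep the returned generator,
    gen, batch2 = gen.get_batch()  # otherwise it would resample batch1 forever
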
@@ -93,9 +99,14 @@ class ValidationLoss(AbstractValidationModule):
             (self.counter, self.best_val_loss),
         )
 
-        # use eqx.tree_at to update attributes
-        # (https://github.com/patrick-kidger/equinox/issues/396)
-        new = eqx.tree_at(lambda t: t.counter, self, counter)
+        new = eqx.tree_at(lambda t: t.validation_data, self, validation_data)
+        if self.validation_param_data is not None:
+            new = eqx.tree_at(
+                lambda t: t.validation_param_data, new, validation_param_data
+            )
+        if self.validation_obs_data is not None:
+            new = eqx.tree_at(lambda t: t.validation_obs_data, new, validation_obs_data)
+        new = eqx.tree_at(lambda t: t.counter, new, counter)
         new = eqx.tree_at(lambda t: t.best_val_loss, new, best_val_loss)
 
         bool_early_stopping = jax.lax.cond(
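
Note: since eqx.Module instances are frozen dataclasses registered as PyTrees, attributes cannot be assigned in place; eqx.tree_at returns a copy of the module with only the selected leaves swapped out, which is why the hunk above threads the fresh generators and counter through a chain of tree_at calls. A minimal standalone illustration of the update pattern:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Counter(eqx.Module):
        count: jax.Array

    c0 = Counter(count=jnp.array(0.0))
    # out-of-place update: c0 is untouched, c1 is a new module
    c1 = eqx.tree_at(lambda m: m.count, c0, c0.count + 1)
    print(c0.count, c1.count)  # 0.0 1.0
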
@@ -109,106 +120,3 @@ class ValidationLoss(AbstractValidationModule):
         )
         # return `new` cause no in-place modification of the eqx.Module
         return (new, bool_early_stopping, validation_loss_value, update_best_params)
-
-
-if __name__ == "__main__":
-    import jax
-    import jax.numpy as jnp
-    import jax.random as random
-    from jinns.loss import BurgerEquation
-
-    key = random.PRNGKey(1)
-    key, subkey = random.split(key)
-
-    n = 50
-    nb = 2 * 2 * 10
-    nt = 10
-    omega_batch_size = 10
-    omega_border_batch_size = 10
-    temporal_batch_size = 4
-    dim = 1
-    xmin = 0
-    xmax = 1
-    tmin, tmax = 0, 1
-    method = "uniform"
-
-    val_data = jinns.data.CubicMeshPDENonStatio(
-        subkey,
-        n,
-        nb,
-        nt,
-        omega_batch_size,
-        omega_border_batch_size,
-        temporal_batch_size,
-        dim,
-        (xmin,),
-        (xmax,),
-        tmin,
-        tmax,
-        method,
-    )
-
-    eqx_list = [
-        [eqx.nn.Linear, 2, 50],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 50, 50],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 50, 50],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 50, 50],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 50, 50],
-        [jax.nn.tanh],
-        [eqx.nn.Linear, 50, 2],
-    ]
-
-    key, subkey = random.split(key)
-    u = jinns.utils.create_PINN(
-        subkey, eqx_list, "nonstatio_PDE", 2, slice_solution=jnp.s_[:1]
-    )
-    init_nn_params = u.init_params()
-
-    dyn_loss = BurgerEquation()
-    loss_weights = {"dyn_loss": 1, "boundary_loss": 10, "observations": 10}
-
-    key, subkey = random.split(key)
-    loss = jinns.loss.LossPDENonStatio(
-        u=u,
-        loss_weights=loss_weights,
-        dynamic_loss=dyn_loss,
-        norm_key=subkey,
-        norm_borders=(-1, 1),
-    )
-    print(id(loss))
-    validation = ValidationLoss(
-        call_every=250,
-        early_stopping=True,
-        patience=1000,
-        loss=loss,
-        validation_data=val_data,
-        validation_param_data=None,
-    )
-    print(id(validation.loss) is not id(loss))  # should be True (deepcopy)
-
-    init_params = {"nn_params": init_nn_params, "eq_params": {"nu": 1.0}}
-
-    print(validation.loss is loss)
-    loss.evaluate(init_params, val_data.get_batch())
-    print(loss.norm_key)
-    print("Call validation once")
-    validation, _, _ = validation(init_params)
-    print(validation.loss is loss)
-    print(validation.loss.norm_key == loss.norm_key)
-    print("Crate new pytree from validation and call it once")
-    new_val = eqx.tree_at(lambda t: t.counter, validation, jnp.array(3.0))
-    print(validation.loss is new_val.loss)  # FALSE
-    # test if attribute have been modified
-    new_val, _, _ = new_val(init_params)
-    print(f"{new_val.loss is loss=}")
-    print(f"{loss.norm_key=}")
-    print(f"{validation.loss.norm_key=}")
-    print(f"{new_val.loss.norm_key=}")
-    print(f"{new_val.loss.norm_key == loss.norm_key=}")
-    print(f"{new_val.loss.norm_key == validation.loss.norm_key=}")
-    print(new_val.counter)
-    print(validation.counter)
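
Note: the removed __main__ block was a smoke test of the old deepcopy semantics (the id()/is checks verified that ValidationLoss held its own copy of the loss), which is moot now that the fields are immutable eqx.Module attributes. A hedged sketch of the equivalent 1.x call site, assuming loss, val_data and init_params are built as in the removed snippet (whether those constructors kept the same signatures in 1.1.0 is not shown in this diff); per the new class definition, all fields are keyword-only:

    validation = ValidationLoss(
        loss=loss,
        validation_data=val_data,
        call_every=250,
        early_stopping=True,
        patience=1000,
    )
    # __call__ now returns a 4-tuple with a new module instead of mutating self
    validation, bool_early_stopping, validation_loss_value, update_best_params = (
        validation(init_params)
    )
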

jinns-1.1.0.dist-info/AUTHORS
@@ -0,0 +1,2 @@
+Hugo Gangloff <hugo.gangloff@inrae.fr>
+Nicolas Jouvin <nicolas.jouvin@inrae.fr>

{jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jinns
-Version: 0.9.0
+Version: 1.1.0
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -10,20 +10,21 @@ Project-URL: Documentation, https://mia_jinns.gitlab.io/jinns/index.html
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Development Status :: 4 - Beta
 Classifier: Programming Language :: Python
-Requires-Python: >=3.7
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
+License-File: AUTHORS
 Requires-Dist: numpy
 Requires-Dist: jax
 Requires-Dist: jaxopt
 Requires-Dist: optax
-Requires-Dist: equinox
+Requires-Dist: equinox >0.11.3
 Requires-Dist: jax-tqdm
 Requires-Dist: diffrax
 Requires-Dist: matplotlib
 Provides-Extra: notebook
 Requires-Dist: jupyter ; extra == 'notebook'
-Requires-Dist: matplotlib ; extra == 'notebook'
+Requires-Dist: seaborn ; extra == 'notebook'
 
 jinns
 =====

jinns-1.1.0.dist-info/RECORD
@@ -0,0 +1,39 @@
+jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
+jinns/data/_Batchs.py,sha256=BLxTDiFb6o9M6Irc2_HKmpr8IgA159u_kJIbCBZ490E,926
+jinns/data/_DataGenerators.py,sha256=TuQDPI8NGx4WnCGtUp8v7o7GcnDfVZvL5E34JI-9Lmw,58455
+jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
+jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
+jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
+jinns/loss/_DynamicLoss.py,sha256=WGbAuWnNfsbzUlWEiW_ARd4kI3jmHwdqPjxLC-wCA6s,25753
+jinns/loss/_DynamicLossAbstract.py,sha256=Xyt28Oej_zlhcV3f6cw2vnAKyRJhXBiA63CsdL3PihU,13767
+jinns/loss/_LossODE.py,sha256=ThWPse6Gn5crM3_tzwZCBx-usoD0xWu6y1n0GVl2dpI,23422
+jinns/loss/_LossPDE.py,sha256=R9kNQiaFbFx2eCMdjB7ie7UJ9pJW7PmvHijioNgu-bs,49117
+jinns/loss/__init__.py,sha256=Fm4QAHaVmp0CA7HSwb7KUctwdXnNZ9v5KmTqpeoYPaE,669
+jinns/loss/_boundary_conditions.py,sha256=O0D8eWsFfvNNeO20PQ0rUKBI_MDqaBvqChfXaztZoL4,16679
+jinns/loss/_loss_utils.py,sha256=44J-VF6dxT_o5BcNWFOiLpY40c35YnAxxZkoNtdtcZc,13689
+jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
+jinns/loss/_operators.py,sha256=o-Ljp_9_HXB9Mhm-ANh6ouNw4_PsqLJAha7dFDGl_nQ,10781
+jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
+jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
+jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
+jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
+jinns/plot/_plot.py,sha256=ZGIJdGwEd3NlHRTq_2sOfEH_CtOkvPwdgCMct-nQlJE,11691
+jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+jinns/solver/_rar.py,sha256=BAr0qHNtCQ11XRhxRDed8JZPSWBm4E5FT3VOGQPrm4g,22728
+jinns/solver/_solve.py,sha256=iUkvZ9f-0bxO2eJlNv7wLk5oES69LNbDxzOgBAjbHTg,21106
+jinns/utils/__init__.py,sha256=CNxcb_AYzA2aeDsYwLfIuT1zy8NU7LElU160eHCj2oA,174
+jinns/utils/_containers.py,sha256=_PLIkJHY-jE3nigjuAiYE5USJr7rpXCDsNULRKZtqnU,1213
+jinns/utils/_hyperpinn.py,sha256=7fAy6Xv7CkIeIYCw7u9RB0gh4aITMfxz6JCmJ685KB8,16351
+jinns/utils/_pinn.py,sha256=eiw_i72D7-CxZXXOBpgOgDndwGd5sSnFZNk-Rg-6xy8,12977
+jinns/utils/_save_load.py,sha256=kVDeQrpPf7j7kYUy1gWGCr4_QyaBplRbPSitkWTQnQA,8574
+jinns/utils/_spinn.py,sha256=lesBjXwoj3UUzfecPdh1wBFCY0BlA-q0Wb1pDv0RVYA,9211
+jinns/utils/_types.py,sha256=P_dS0odrHbyalYJ0FjS6q0tkXAGr-4GArsiyJYrB1ho,1878
+jinns/utils/_utils.py,sha256=Ow8xB516E7yHDZatokVJHHFNPDu6fXr9-NmraUXjjyw,1819
+jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
+jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
+jinns-1.1.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
+jinns-1.1.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+jinns-1.1.0.dist-info/METADATA,sha256=3Qk885oguf6S_WPHd9KCFVIWP21nJqX9zFWoS9ZI-T0,2536
+jinns-1.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+jinns-1.1.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+jinns-1.1.0.dist-info/RECORD,,

{jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (70.2.0)
+Generator: setuptools (75.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 

jinns/experimental/_sinuspinn.py (removed)
@@ -1,135 +0,0 @@
-import jax
-import equinox as eqx
-import jax.numpy as jnp
-from jinns.utils._pinn import PINN
-
-
-def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
-    out, in_ = weight.shape
-    stddev = 1e-2
-    return stddev * jax.random.normal(key, shape=(out, in_))
-
-
-class _SinusPINN(eqx.Module):
-    """
-    A specific PINN whose layers are x_sin2x functions whose frequencies are
-    determined by an other network
-    """
-
-    layers_pinn: list
-    layers_aux_nn: list
-
-    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
-        """
-        Parameters
-        ----------
-        key
-            A jax random key
-        list_layers_pinn
-            A list as eqx_list in jinns' PINN utility for the main PINN
-        list_layers_aux_nn
-            A list as eqx_list in jinns' PINN utility for the network which outputs
-            the PINN's activation frequencies
-        """
-        self.layers_pinn = []
-        for l in list_layers_pinn:
-            if len(l) == 1:
-                self.layers_pinn.append(l[0])
-            else:
-                key, subkey = jax.random.split(key, 2)
-                self.layers_pinn.append(l[0](*l[1:], key=subkey))
-        self.layers_aux_nn = []
-        for idx, l in enumerate(list_layers_aux_nn):
-            if len(l) == 1:
-                self.layers_aux_nn.append(l[0])
-            else:
-                key, subkey = jax.random.split(key, 2)
-                linear_layer = l[0](*l[1:], key=subkey)
-                key, subkey = jax.random.split(key, 2)
-                linear_layer = eqx.tree_at(
-                    lambda l: l.weight,
-                    linear_layer,
-                    almost_zero_init(linear_layer.weight, subkey),
-                )
-                if (idx == len(list_layers_aux_nn) - 1) or (
-                    idx == len(list_layers_aux_nn) - 2
-                ):
-                    # for the last layer: almost 0 weights and 0.5 bias
-                    linear_layer = eqx.tree_at(
-                        lambda l: l.bias,
-                        linear_layer,
-                        0.5 * jnp.ones(linear_layer.bias.shape),
-                    )
-                else:
-                    # for the other previous layers:
-                    # almost 0 weight and 0 bias
-                    linear_layer = eqx.tree_at(
-                        lambda l: l.bias,
-                        linear_layer,
-                        jnp.zeros(linear_layer.bias.shape),
-                    )
-                self.layers_aux_nn.append(linear_layer)
-
-        ## init to zero the frequency network except last biases
-        # key, subkey = jax.random.split(key, 2)
-        # _pinn = init_linear_weight(_pinn, almost_zero_init, subkey)
-        # key, subkey = jax.random.split(key, 2)
-        # _pinn = init_linear_bias(_pinn, zero_init, subkey)
-        # print(_pinn)
-        # print(jax.tree_util.tree_leaves(_pinn, is_leaf=lambda
-        # p:not isinstance(p,eqx.nn.Linear))[0].layers_aux_nn[-1].bias)
-        # _pinn = eqx.tree_at(lambda p:_pinn.layers_aux_nn[-1].bias, 0.5 *
-        # jnp.ones(_pinn.layers_aux_nn[-1].bias.shape))
-        # #, is_leaf=lambda
-        # #p:not isinstance(p, eqx.nn.Linear))
-
-    def __call__(self, x):
-        x_ = x.copy()
-        # forward pass in the network which determines the freq
-        for layer in self.layers_aux_nn:
-            x_ = layer(x_)
-        freq_list = jnp.clip(jnp.square(x_), a_min=1e-4, a_max=5)
-        x_ = x.copy()
-        # forward pass through the actual PINN
-        for idx, layer in enumerate(self.layers_pinn):
-            if idx % 2 == 0:
-                # Currently: every two layer we have an activation
-                # requiring a frequency
-                x_ = layer(x_)
-            else:
-                x_ = layer(x_, freq_list[(idx - 1) // 2])
-        return x_
-
-
-class sinusPINN(PINN):
-    """
-    MUST inherit from PINN to pass all the checks
-
-    HOWEVER we dot not bother with reimplementing anything
-    """
-
-    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
-        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
-        key, subkey = jax.random.split(key, 2)
-        _pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)
-
-        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
-
-    def init_params(self):
-        return self.params
-
-    def __call__(self, x, params):
-        try:
-            model = eqx.combine(params["nn_params"], self.static)
-        except (KeyError, TypeError) as e:  # give more flexibility
-            model = eqx.combine(params, self.static)
-        res = model(x)
-        if not res.shape:
-            return jnp.expand_dims(res, axis=-1)
-        return res
-
-
-def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
-    """ """
-    u = sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
-    return u

jinns/experimental/_spectralpinn.py (removed)
@@ -1,87 +0,0 @@
-import jax
-import equinox as eqx
-import jax.numpy as jnp
-from jinns.utils._pinn import PINN
-
-
-def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
-    out, in_ = weight.shape
-    stddev = 1e-2
-    return stddev * jax.random.normal(key, shape=(out, in_))
-
-
-class _SpectralPINN(eqx.Module):
-    """
-    A specific PINN whose acrhitecture is similar to spectral method for simulation of a spatial field
-    (Chilès and Delfiner, 2012) - a single layer with cos() activation function and sum for last layer
-    """
-
-    layers_pinn: list
-    nbands: int
-
-    def __init__(self, key, list_layers_pinn, nbands):
-        """
-        Parameters
-        ----------
-        key
-            A jax random key
-        list_layers_pinn
-            A list as eqx_list in jinns' PINN utility for the main PINN
-        nbands
-            Number of spectral bands (i.e., neurones in the single layer of the PINN)
-        """
-        self.nbands = nbands
-        self.layers_pinn = []
-        for l in list_layers_pinn:
-            if len(l) == 1:
-                self.layers_pinn.append(l[0])
-            else:
-                key, subkey = jax.random.split(key, 2)
-                self.layers_pinn.append(l[0](*l[1:], key=subkey))
-
-    def __call__(self, x):
-        # forward pass through the actual PINN
-        for layer in self.layers_pinn:
-            x = layer(x)
-
-        return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
-
-
-class spectralPINN(PINN):
-    """
-    MUST inherit from PINN to pass all the checks
-
-    HOWEVER we dot not bother with reimplementing anything
-    """
-
-    def __init__(self, key, list_layers_pinn, nbands):
-        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
-        key, subkey = jax.random.split(key, 2)
-        _pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)
-
-        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
-
-    def init_params(self):
-        return self.params
-
-    def __call__(self, x, params):
-        try:
-            model = eqx.combine(params["nn_params"], self.static)
-        except (KeyError, TypeError) as e:  # give more flexibility
-            model = eqx.combine(params, self.static)
-        # model = eqx.tree_at(lambda m:
-        # m.layers_pinn[0].bias,
-        # model,
-        # model.layers_pinn[0].bias % (2 *
-        # jnp.pi)
-        # )
-        res = model(x)
-        if not res.shape:
-            return jnp.expand_dims(res, axis=-1)
-        return res
-
-
-def create_spectralPINN(key, list_layers_pinn, nbands):
-    """ """
-    u = spectralPINN(key, list_layers_pinn, nbands)
-    return u