jinns 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/loss/_DynamicLossAbstract.py +9 -40
- jinns/solver/_solve.py +2 -3
- {jinns-0.4.0.dist-info → jinns-0.4.2.dist-info}/METADATA +1 -1
- {jinns-0.4.0.dist-info → jinns-0.4.2.dist-info}/RECORD +7 -7
- {jinns-0.4.0.dist-info → jinns-0.4.2.dist-info}/WHEEL +1 -1
- {jinns-0.4.0.dist-info → jinns-0.4.2.dist-info}/LICENSE +0 -0
- {jinns-0.4.0.dist-info → jinns-0.4.2.dist-info}/top_level.txt +0 -0
|
@@ -29,10 +29,9 @@ class DynamicLoss:
|
|
|
29
29
|
equation solution with a PINN.
|
|
30
30
|
eq_params_heterogeneity
|
|
31
31
|
Default None. A dict with the keys being the same as in eq_params
|
|
32
|
-
and the value being `
|
|
33
|
-
|
|
34
|
-
this
|
|
35
|
-
eq_params_heterogeneity is None this means there is no
|
|
32
|
+
and the value being either None (no heterogeneity) or a function which encodes the spatio-temporal heterogeneity of the parameter. Such a function must be jittable and take three arguments `t`, `x` and `params["eq_params"]`, even if one of them is not used. Therefore, one can introduce spatio-temporal covariates upon which a particular parameter can depend, e.g. in a GLM fashion. The effects of these covariates can themselves be estimated by including them in `eq_params` too.
|
|
33
|
+
A key can be missing, in which case there is no heterogeneity for that parameter (equivalent to None).
|
|
34
|
+
If eq_params_heterogeneity is None this means there is no
|
|
36
35
|
heterogeneity for any parameter.
|
|
37
36
|
"""
|
|
38
37
|
self.Tmax = Tmax
|
|
@@ -49,48 +48,18 @@ class DynamicLoss:
|
|
|
49
48
|
return eq_params
|
|
50
49
|
for k, p in eq_params.items():
|
|
51
50
|
try:
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
51
|
+
if eq_params_heterogeneity[k] is None:
|
|
52
|
+
eq_params_[k] = p
|
|
53
|
+
else:
|
|
54
|
+
eq_params_[k] = eq_params_heterogeneity[k](
|
|
55
|
+
t, x, eq_params # heterogeneity encoded through a function
|
|
56
|
+
)
|
|
55
57
|
except KeyError:
|
|
56
58
|
# we authorize missing eq_params_heterogeneity key
|
|
57
59
|
# as its heterogeneity is None anyway
|
|
58
60
|
eq_params_[k] = p
|
|
59
61
|
return eq_params_
|
|
60
62
|
|
|
61
|
-
def _eval_heterogeneous_array_parameter(self, p, t, x, heterogeneity=None):
|
|
62
|
-
"""
|
|
63
|
-
For time and/or space heterogeneous params defined by an n-dimensional
|
|
64
|
-
array `p` we return the value `p[t, x]` with discretization of the
|
|
65
|
-
collocation point
|
|
66
|
-
|
|
67
|
-
Parameters
|
|
68
|
-
----------
|
|
69
|
-
p
|
|
70
|
-
The parameter
|
|
71
|
-
heterogeneity
|
|
72
|
-
A string. Either `time`, `space`, `both` or None to specify which
|
|
73
|
-
kind of heterogeneity we have. Default is None, is this case we do
|
|
74
|
-
not have heterogeneity.
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
**Note** t is assumed to be normalized in [0, 1] as well as x!
|
|
78
|
-
"""
|
|
79
|
-
if heterogeneity is None:
|
|
80
|
-
return p
|
|
81
|
-
elif heterogeneity == "time":
|
|
82
|
-
return p[(t * len(p)).astype(int)]
|
|
83
|
-
elif heterogeneity == "space":
|
|
84
|
-
coords = (x * jnp.array(p.shape)).astype(int)
|
|
85
|
-
return jnp.take(p, jnp.ravel_multi_index(coords, p.shape, mode="clip"))
|
|
86
|
-
elif heterogeneity == "both":
|
|
87
|
-
coords = jnp.concatenate(
|
|
88
|
-
[(t * len(p))[:, None], x * jnp.array(p.shape)], axis=1
|
|
89
|
-
).astype(int)
|
|
90
|
-
return jnp.take(p, jnp.ravel_multi_index(coords, p.shape, mode="clip"))
|
|
91
|
-
else:
|
|
92
|
-
raise ValueError("Wrong paramater value for parameter `heterogeneity`")
|
|
93
|
-
|
|
94
63
|
def set_stop_gradient(self, params_dict):
|
|
95
64
|
"""
|
|
96
65
|
Set the stop gradient operators in the dynamic loss `evaluate`
|
jinns/solver/_solve.py
CHANGED
|
@@ -196,7 +196,7 @@ def solve(
|
|
|
196
196
|
if carry["param_data"] is not None:
|
|
197
197
|
batch = append_param_batch(batch, carry["param_data"].get_batch())
|
|
198
198
|
carry["params"], carry["opt_state"] = optimizer.update(
|
|
199
|
-
params=carry["params"], state=carry["
|
|
199
|
+
params=carry["params"], state=carry["opt_state"], batch=batch
|
|
200
200
|
)
|
|
201
201
|
|
|
202
202
|
# check if any of the parameters is NaN
|
|
@@ -260,7 +260,6 @@ def solve(
|
|
|
260
260
|
{
|
|
261
261
|
"params": init_params,
|
|
262
262
|
"last_non_nan_params": init_params.copy(),
|
|
263
|
-
"state": opt_state,
|
|
264
263
|
"data": data,
|
|
265
264
|
"curr_seq": curr_seq,
|
|
266
265
|
"seq2seq": seq2seq,
|
|
@@ -281,7 +280,7 @@ def solve(
|
|
|
281
280
|
|
|
282
281
|
params = res["params"]
|
|
283
282
|
last_non_nan_params = res["last_non_nan_params"]
|
|
284
|
-
opt_state = res["
|
|
283
|
+
opt_state = res["opt_state"]
|
|
285
284
|
data = res["data"]
|
|
286
285
|
loss = res["loss"]
|
|
287
286
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.4.0
|
|
3
|
+
Version: 0.4.2
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -3,7 +3,7 @@ jinns/data/_DataGenerators.py,sha256=nIuKtkX4V4ckfT4-g0bjlY7BLkgcok5JbI9OzJn73mA
|
|
|
3
3
|
jinns/data/__init__.py,sha256=S13J59Fxuph4uNJ542fP_Mj8U72ilhb5t_UQ-c1k3nY,232
|
|
4
4
|
jinns/data/_display.py,sha256=Xnfo6_PH1g-ZFpWJcbF6CF6Pp12wJtNQb1W1bADuQrA,6134
|
|
5
5
|
jinns/loss/_DynamicLoss.py,sha256=VyoyWdkoxRPeP2vs4ZZBK_T9xWgwkcDuaFrjUSid3Zo,52975
|
|
6
|
-
jinns/loss/_DynamicLossAbstract.py,sha256=
|
|
6
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=V9NJvMmqnSC06yccu9bkFCWLF3cjyr8ze8qf0LRykjo,7718
|
|
7
7
|
jinns/loss/_LossODE.py,sha256=FHTKQPqLSoMh18j_RYUoR7tUGg_ljd0JKL8xf2HiF5M,17541
|
|
8
8
|
jinns/loss/_LossPDE.py,sha256=8wEEPnkibNOdqCXtcuU5nieRsPoffSyV8wcUMjcUkvg,57145
|
|
9
9
|
jinns/loss/__init__.py,sha256=4JxMHHVMxTMsVZmV8mRSIyMVAEp3QIR8QtZvnwvj96Q,560
|
|
@@ -12,11 +12,11 @@ jinns/loss/_operators.py,sha256=HYGDq3K_K7lztDp6al88Q78F6o-hDRo3ODhw8UMRdC8,5690
|
|
|
12
12
|
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
jinns/solver/_rar.py,sha256=V9y07F6objmP6rPA305dIJ82h7kwP4AWBAktZ68b-38,13894
|
|
14
14
|
jinns/solver/_seq2seq.py,sha256=XNL9e0fBj85Q86XfGDzq9dzxIkPPMwoJF38C8doNYtM,6032
|
|
15
|
-
jinns/solver/_solve.py,sha256=
|
|
15
|
+
jinns/solver/_solve.py,sha256=Yz_asD0ZuYN923E6ysxtPemN-36Zt-BtlSPkTX-BfA8,10047
|
|
16
16
|
jinns/utils/__init__.py,sha256=-jDlwCjyEzWweswKdwLal3OhaUU3FVzK_Ge2S-7KHXs,149
|
|
17
17
|
jinns/utils/_utils.py,sha256=bQm6z_xPKJj9BMCr2tXc44IA8JyGNN-PR5LNRhZ1fD8,20085
|
|
18
|
-
jinns-0.4.
|
|
19
|
-
jinns-0.4.
|
|
20
|
-
jinns-0.4.
|
|
21
|
-
jinns-0.4.
|
|
22
|
-
jinns-0.4.
|
|
18
|
+
jinns-0.4.2.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
19
|
+
jinns-0.4.2.dist-info/METADATA,sha256=maaTIojnCdHhTIPd2kCJa8QGfX36mWAt1DN7q8Pd_3o,1821
|
|
20
|
+
jinns-0.4.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
21
|
+
jinns-0.4.2.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
22
|
+
jinns-0.4.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|