jinns 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,10 +29,9 @@ class DynamicLoss:
29
29
  equation solution with as PINN.
30
30
  eq_params_heterogeneity
31
31
  Default None. A dict with the keys being the same as in eq_params
32
- and the value being `time`, `space`, `both` or None which corresponds to
33
- the heterogeneity of a given parameter. A value can be missing, in
34
- this case there is no heterogeneity (=None). If
35
- eq_params_heterogeneity is None this means there is no
32
+ and the value being either None (no heterogeneity) or a function which encodes for the spatio-temporal heterogeneity of the parameter. Such a function must be jittable and take three arguments `t`, `x` and `params["eq_params"]` even if one is not used. Therefore, one can introduce spatio-temporal covariates upon which a particular parameter can depend, e.g. in a GLM fashion. The effect of these covariables can themselves be estimated by being in `eq_params` too.
33
+ A value can be missing, in this case there is no heterogeneity (=None).
34
+ If eq_params_heterogeneity is None this means there is no
36
35
  heterogeneity for no parameters.
37
36
  """
38
37
  self.Tmax = Tmax
@@ -49,48 +48,18 @@ class DynamicLoss:
49
48
  return eq_params
50
49
  for k, p in eq_params.items():
51
50
  try:
52
- eq_params_[k] = self._eval_heterogeneous_array_parameter(
53
- p, t, x, heterogeneity=eq_params_heterogeneity[k]
54
- )
51
+ if eq_params_heterogeneity[k] is None:
52
+ eq_params_[k] = p
53
+ else:
54
+ eq_params_[k] = eq_params_heterogeneity[k](
55
+ t, x, eq_params # heterogeneity encoded through a function
56
+ )
55
57
  except KeyError:
56
58
  # we authorize missing eq_params_heterogeneity key
57
59
  # is its heterogeneity is None anyway
58
60
  eq_params_[k] = p
59
61
  return eq_params_
60
62
 
61
- def _eval_heterogeneous_array_parameter(self, p, t, x, heterogeneity=None):
62
- """
63
- For time and/or space heterogeneous params defined by an n-dimensional
64
- array `p` we return the value `p[t, x]` with discretization of the
65
- collocation point
66
-
67
- Parameters
68
- ----------
69
- p
70
- The parameter
71
- heterogeneity
72
- A string. Either `time`, `space`, `both` or None to specify which
73
- kind of heterogeneity we have. Default is None, is this case we do
74
- not have heterogeneity.
75
-
76
-
77
- **Note** t is assumed to be normalized in [0, 1] as well as x!
78
- """
79
- if heterogeneity is None:
80
- return p
81
- elif heterogeneity == "time":
82
- return p[(t * len(p)).astype(int)]
83
- elif heterogeneity == "space":
84
- coords = (x * jnp.array(p.shape)).astype(int)
85
- return jnp.take(p, jnp.ravel_multi_index(coords, p.shape, mode="clip"))
86
- elif heterogeneity == "both":
87
- coords = jnp.concatenate(
88
- [(t * len(p))[:, None], x * jnp.array(p.shape)], axis=1
89
- ).astype(int)
90
- return jnp.take(p, jnp.ravel_multi_index(coords, p.shape, mode="clip"))
91
- else:
92
- raise ValueError("Wrong paramater value for parameter `heterogeneity`")
93
-
94
63
  def set_stop_gradient(self, params_dict):
95
64
  """
96
65
  Set the stop gradient operators in the dynamic loss `evaluate`
jinns/solver/_solve.py CHANGED
@@ -196,7 +196,7 @@ def solve(
196
196
  if carry["param_data"] is not None:
197
197
  batch = append_param_batch(batch, carry["param_data"].get_batch())
198
198
  carry["params"], carry["opt_state"] = optimizer.update(
199
- params=carry["params"], state=carry["state"], batch=batch
199
+ params=carry["params"], state=carry["opt_state"], batch=batch
200
200
  )
201
201
 
202
202
  # check if any of the parameters is NaN
@@ -260,7 +260,6 @@ def solve(
260
260
  {
261
261
  "params": init_params,
262
262
  "last_non_nan_params": init_params.copy(),
263
- "state": opt_state,
264
263
  "data": data,
265
264
  "curr_seq": curr_seq,
266
265
  "seq2seq": seq2seq,
@@ -281,7 +280,7 @@ def solve(
281
280
 
282
281
  params = res["params"]
283
282
  last_non_nan_params = res["last_non_nan_params"]
284
- opt_state = res["state"]
283
+ opt_state = res["opt_state"]
285
284
  data = res["data"]
286
285
  loss = res["loss"]
287
286
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -3,7 +3,7 @@ jinns/data/_DataGenerators.py,sha256=nIuKtkX4V4ckfT4-g0bjlY7BLkgcok5JbI9OzJn73mA
3
3
  jinns/data/__init__.py,sha256=S13J59Fxuph4uNJ542fP_Mj8U72ilhb5t_UQ-c1k3nY,232
4
4
  jinns/data/_display.py,sha256=Xnfo6_PH1g-ZFpWJcbF6CF6Pp12wJtNQb1W1bADuQrA,6134
5
5
  jinns/loss/_DynamicLoss.py,sha256=VyoyWdkoxRPeP2vs4ZZBK_T9xWgwkcDuaFrjUSid3Zo,52975
6
- jinns/loss/_DynamicLossAbstract.py,sha256=O_dxb7_0yBCxsoa7ornU5PvqOqDPlkOwK55CGgdlGSs,8578
6
+ jinns/loss/_DynamicLossAbstract.py,sha256=V9NJvMmqnSC06yccu9bkFCWLF3cjyr8ze8qf0LRykjo,7718
7
7
  jinns/loss/_LossODE.py,sha256=FHTKQPqLSoMh18j_RYUoR7tUGg_ljd0JKL8xf2HiF5M,17541
8
8
  jinns/loss/_LossPDE.py,sha256=8wEEPnkibNOdqCXtcuU5nieRsPoffSyV8wcUMjcUkvg,57145
9
9
  jinns/loss/__init__.py,sha256=4JxMHHVMxTMsVZmV8mRSIyMVAEp3QIR8QtZvnwvj96Q,560
@@ -12,11 +12,11 @@ jinns/loss/_operators.py,sha256=HYGDq3K_K7lztDp6al88Q78F6o-hDRo3ODhw8UMRdC8,5690
12
12
  jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  jinns/solver/_rar.py,sha256=V9y07F6objmP6rPA305dIJ82h7kwP4AWBAktZ68b-38,13894
14
14
  jinns/solver/_seq2seq.py,sha256=XNL9e0fBj85Q86XfGDzq9dzxIkPPMwoJF38C8doNYtM,6032
15
- jinns/solver/_solve.py,sha256=pCctVnrkfOptGd1wH2w9jDdFiOMmNqgatqaoYtyE_LM,10071
15
+ jinns/solver/_solve.py,sha256=Yz_asD0ZuYN923E6ysxtPemN-36Zt-BtlSPkTX-BfA8,10047
16
16
  jinns/utils/__init__.py,sha256=-jDlwCjyEzWweswKdwLal3OhaUU3FVzK_Ge2S-7KHXs,149
17
17
  jinns/utils/_utils.py,sha256=bQm6z_xPKJj9BMCr2tXc44IA8JyGNN-PR5LNRhZ1fD8,20085
18
- jinns-0.4.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
19
- jinns-0.4.0.dist-info/METADATA,sha256=iJyoQcOqKE4tgof5Xr_gfOgIKWkhD_VmfWGfHj9GJ14,1821
20
- jinns-0.4.0.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
21
- jinns-0.4.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
22
- jinns-0.4.0.dist-info/RECORD,,
18
+ jinns-0.4.2.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
19
+ jinns-0.4.2.dist-info/METADATA,sha256=maaTIojnCdHhTIPd2kCJa8QGfX36mWAt1DN7q8Pd_3o,1821
20
+ jinns-0.4.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
21
+ jinns-0.4.2.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
22
+ jinns-0.4.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.3)
2
+ Generator: bdist_wheel (0.42.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
File without changes