jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
@@ -2,219 +2,374 @@
2
2
  Implements abstract classes for dynamic losses
3
3
  """
4
4
 
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
5
8
 
6
- class DynamicLoss:
9
+ import equinox as eqx
10
+ from typing import Callable, Dict, TYPE_CHECKING, ClassVar
11
+ from jaxtyping import Float, Array
12
+ from functools import partial
13
+ import abc
14
+
15
+
16
+ # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
17
+ if TYPE_CHECKING:
18
+ from typing import ClassVar as AbstractClassVar
19
+ from jinns.parameters import Params, ParamsDict
20
+ else:
21
+ from equinox import AbstractClassVar
22
+
23
+
24
+ def _decorator_heteregeneous_params(evaluate, eq_type):
25
+
26
+ def wrapper_ode(*args):
27
+ self, t, u, params = args
28
+ _params = eqx.tree_at(
29
+ lambda p: p.eq_params,
30
+ params,
31
+ self._eval_heterogeneous_parameters(
32
+ t, None, u, params, self.eq_params_heterogeneity
33
+ ),
34
+ )
35
+ new_args = args[:-1] + (_params,)
36
+ res = evaluate(*new_args)
37
+ return res
38
+
39
+ def wrapper_pde_statio(*args):
40
+ self, x, u, params = args
41
+ _params = eqx.tree_at(
42
+ lambda p: p.eq_params,
43
+ params,
44
+ self._eval_heterogeneous_parameters(
45
+ None, x, u, params, self.eq_params_heterogeneity
46
+ ),
47
+ )
48
+ new_args = args[:-1] + (_params,)
49
+ res = evaluate(*new_args)
50
+ return res
51
+
52
+ def wrapper_pde_non_statio(*args):
53
+ self, t, x, u, params = args
54
+ _params = eqx.tree_at(
55
+ lambda p: p.eq_params,
56
+ params,
57
+ self._eval_heterogeneous_parameters(
58
+ t, x, u, params, self.eq_params_heterogeneity
59
+ ),
60
+ )
61
+ new_args = args[:-1] + (_params,)
62
+ res = evaluate(*new_args)
63
+ return res
64
+
65
+ if eq_type == "ODE":
66
+ return wrapper_ode
67
+ elif eq_type == "Statio PDE":
68
+ return wrapper_pde_statio
69
+ elif eq_type == "Non-statio PDE":
70
+ return wrapper_pde_non_statio
71
+
72
+
73
+ class DynamicLoss(eqx.Module):
7
74
  r"""
8
- Abstract base class for dynamic losses whose aim is to implement the term:
75
+ Abstract base class for dynamic losses. Implements the physical term:
9
76
 
10
- .. math::
77
+ $$
11
78
  \mathcal{N}[u](t, x) = 0
79
+ $$
80
+
81
+ for **one** point $t$, $x$ or $(t, x)$, depending on the context.
82
+
83
+ Parameters
84
+ ----------
85
+ Tmax : Float, default=1
86
+ Tmax needs to be given when the PINN time input is normalized in
87
+ [0, 1], ie. we have performed renormalization of the differential
88
+ equation
89
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
90
+ A dict with the same keys as eq_params and the value being either None
91
+ (no heterogeneity) or a function which encodes for the spatio-temporal
92
+ heterogeneity of the parameter.
93
+ Such a function must be jittable; its arguments depend on the equation
94
+ type: `(t, u, params)`, `(x, u, params)` or `(t, x, u, params)`. Therefore,
95
+ one can introduce spatio-temporal covariates upon which a particular
96
+ parameter can depend, e.g. in a Generalized Linear Model fashion. The
97
+ effect of these covariates can themselves be estimated by being in
98
+ `eq_params` too.
99
+ A value can be missing, in this case there is no heterogeneity (=None).
100
+ Default None, meaning there is no heterogeneity in the equation
101
+ parameters.
12
102
  """
13
103
 
14
- def __init__(self, Tmax=None, eq_params_heterogeneity=None):
15
- """
16
- Parameters
17
- ----------
18
- Tmax
19
- Tmax needs to be given when the PINN time input is normalized in
20
- [0, 1], ie. we have performed renormalization of the differential
21
- equation
22
- eq_params_heterogeneity
23
- Default None. A dict with the keys being the same as in eq_params
24
- and the value being either None (no heterogeneity) or a function
25
- which encodes for the spatio-temporal heterogeneity of the parameter.
26
- Such a function must be jittable and take four arguments `t`, `x`,
27
- `u` and `params` even if one is not used. Therefore,
28
- one can introduce spatio-temporal covariates upon which a particular
29
- parameter can depend, e.g. in a GLM fashion. The effect of these
30
- covariables can themselves be estimated by being in `eq_params` too.
31
- A value can be missing, in this case there is no heterogeneity (=None).
32
- If eq_params_heterogeneity is None this means there is no
33
- heterogeneity for no parameters.
34
- """
35
- self.Tmax = Tmax
36
- self.eq_params_heterogeneity = eq_params_heterogeneity
104
+ _eq_type = AbstractClassVar[str] # class variable denoting the type of
105
+ # differential equation
106
+ Tmax: Float = eqx.field(kw_only=True, default=1)
107
+ eq_params_heterogeneity: Dict[str, Callable | None] = eqx.field(
108
+ kw_only=True, default=None, static=True
109
+ )
37
110
 
38
- @staticmethod
39
- def _eval_heterogeneous_parameters(t, x, u, params, eq_params_heterogeneity=None):
111
+ def _eval_heterogeneous_parameters(
112
+ self,
113
+ t: Float[Array, "1"],
114
+ x: Float[Array, "dim"],
115
+ u: eqx.Module,
116
+ params: Params | ParamsDict,
117
+ eq_params_heterogeneity: Dict[str, Callable | None] = None,
118
+ ) -> Dict[str, float | Float[Array, "parameter_dimension"]]:
40
119
  eq_params_ = {}
41
120
  if eq_params_heterogeneity is None:
42
- return params["eq_params"]
43
- for k, p in params["eq_params"].items():
121
+ return params.eq_params
122
+ for k, p in params.eq_params.items():
44
123
  try:
45
124
  if eq_params_heterogeneity[k] is None:
46
125
  eq_params_[k] = p
47
126
  else:
48
- if t is None:
49
- eq_params_[k] = eq_params_heterogeneity[k](
50
- x, u, params # heterogeneity encoded through a function
51
- )
52
- else:
53
- eq_params_[k] = eq_params_heterogeneity[k](
54
- t, x, u, params # heterogeneity encoded through a function
55
- )
127
+ # heterogeneity encoded through a function whose
128
+ # signature will vary according to _eq_type
129
+ if self._eq_type == "ODE":
130
+ eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
131
+ elif self._eq_type == "Statio PDE":
132
+ eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
133
+ elif self._eq_type == "Non-statio PDE":
134
+ eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
56
135
  except KeyError:
57
136
  # we authorize missing eq_params_heterogeneity key
58
- # is its heterogeneity is None anyway
137
+ # if its heterogeneity is None anyway
59
138
  eq_params_[k] = p
60
139
  return eq_params_
61
140
 
141
+ def _evaluate(
142
+ self,
143
+ t: Float[Array, "1"],
144
+ x: Float[Array, "dim"],
145
+ u: eqx.Module,
146
+ params: Params | ParamsDict,
147
+ ) -> float:
148
+ # Here we handle the various possible signature
149
+ if self._eq_type == "ODE":
150
+ ans = self.equation(t, u, params)
151
+ elif self._eq_type == "Statio PDE":
152
+ ans = self.equation(x, u, params)
153
+ elif self._eq_type == "Non-statio PDE":
154
+ ans = self.equation(t, x, u, params)
155
+ else:
156
+ raise NotImplementedError("the equation type is not handled.")
157
+
158
+ return ans
159
+
160
+ @abc.abstractmethod
161
+ def equation(self, *args, **kwargs):
162
+ # TO IMPLEMENT
163
+ # Point-wise evaluation of the differential equation N[u](.)
164
+ raise NotImplementedError("You should implement your equation.")
165
+
62
166
 
63
167
  class ODE(DynamicLoss):
64
168
  r"""
65
- Abstract base class for ODE dynamic losses
169
+ Abstract base class for ODE dynamic losses. All ODE dynamic losses must subclass
170
+ this class and override the abstract method `equation`.
171
+
172
+ Attributes
173
+ ----------
174
+ Tmax : float, default=1
175
+ Tmax needs to be given when the PINN time input is normalized in
176
+ [0, 1], ie. we have performed renormalization of the differential
177
+ equation
178
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
179
+ Default None. A dict with the keys being the same as in eq_params
180
+ and the value being either None (no heterogeneity) or a function
181
+ which encodes for the spatio-temporal heterogeneity of the parameter.
182
+ Such a function must be jittable; for an ODE it is called with the
183
+ three arguments `t`, `u` and `params`. Therefore,
184
+ one can introduce spatio-temporal covariates upon which a particular
185
+ parameter can depend, e.g. in a GLM fashion. The effect of these
186
+ covariables can themselves be estimated by being in `eq_params` too.
187
+ Some keys can be missing; in this case there is no heterogeneity (=None).
188
+ If eq_params_heterogeneity is None, there is no heterogeneity
189
+ for any parameter.
66
190
  """
67
191
 
68
- def __init__(self, Tmax=None, eq_params_heterogeneity=None):
69
- """
192
+ _eq_type: ClassVar[str] = "ODE"
193
+
194
+ @partial(_decorator_heteregeneous_params, eq_type="ODE")
195
+ def evaluate(
196
+ self,
197
+ t: Float[Array, "1"],
198
+ u: eqx.Module | Dict[str, eqx.Module],
199
+ params: Params | ParamsDict,
200
+ ) -> float:
201
+ """Here we call DynamicLoss._evaluate with x=None"""
202
+ return self._evaluate(t, None, u, params)
203
+
204
+ @abc.abstractmethod
205
+ def equation(
206
+ self, t: Float[Array, "1"], u: eqx.Module, params: Params | ParamsDict
207
+ ) -> float:
208
+ r"""
209
+ The differential operator defining the ODE.
210
+
211
+ !!! warning
212
+
213
+ This is an abstract method to be implemented by users.
214
+
70
215
  Parameters
71
216
  ----------
72
- Tmax
73
- Tmax needs to be given when the PINN time input is normalized in
74
- [0, 1], ie. we have performed renormalization of the differential
75
- equation
76
- eq_params_heterogeneity
77
- Default None. A dict with the keys being the same as in eq_params
78
- and the value being `time`, `space`, `both` or None which corresponds to
79
- the heterogeneity of a given parameter. A value can be missing, in
80
- this case there is no heterogeneity (=None). If
81
- eq_params_heterogeneity is None this means there is no
82
- heterogeneity for no parameters.
83
- """
84
- super().__init__(Tmax, eq_params_heterogeneity)
217
+ t : Float[Array, "1"]
218
+ A 1-dimensional jnp.array representing the time point.
219
+ u : eqx.Module
220
+ The network with a call signature `u(t, params)`.
221
+ params : Params | ParamsDict
222
+ The equation and neural network parameters $\theta$ and $\nu$.
85
223
 
86
- def eval_heterogeneous_parameters(self, t, u, params, eq_params_heterogeneity=None):
87
- return super()._eval_heterogeneous_parameters(
88
- t, None, u, params, eq_params_heterogeneity
89
- )
224
+ Returns
225
+ -------
226
+ float
227
+ The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t)$ evaluated at point `t`.
90
228
 
91
- @staticmethod
92
- def evaluate_heterogeneous_parameters(evaluate):
93
- """
94
- Decorator which aims to decorate the evaluate methods of Dynamic losses
95
- in order. It calls _eval_heterogeneous_parameters which applies the
96
- user defined rules to obtain spatially / temporally heterogeneous
97
- parameters
229
+ Raises
230
+ ------
231
+ NotImplementedError
232
+ This is an abstract method to be implemented.
98
233
  """
99
-
100
- def wrapper(*args):
101
- self, t, u, params = args
102
- # avoid side effect with in-place modif of param["eq_params"]
103
- # TODO NamedTuple for params and use _replace() see Issue 1
104
- _params = {
105
- "nn_params": params["nn_params"],
106
- "eq_params": self.eval_heterogeneous_parameters(
107
- t, u, params, self.eq_params_heterogeneity
108
- ),
109
- }
110
- new_args = args[:-1] + (_params,)
111
- res = evaluate(*new_args)
112
- return res
113
-
114
- return wrapper
234
+ raise NotImplementedError
115
235
 
116
236
 
117
237
  class PDEStatio(DynamicLoss):
118
238
  r"""
119
- Abstract base class for PDE statio dynamic losses
239
+ Abstract base class for stationary PDE dynamic losses. All stationary PDE dynamic losses must subclass this class and override the abstract method `equation`.
240
+
241
+ Attributes
242
+ ----------
243
+ Tmax : float, default=1
244
+ Tmax needs to be given when the PINN time input is normalized in
245
+ [0, 1], ie. we have performed renormalization of the differential
246
+ equation
247
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
248
+ Default None. A dict with the keys being the same as in eq_params
249
+ and the value being either None (no heterogeneity) or a function
250
+ which encodes for the spatio-temporal heterogeneity of the parameter.
251
+ Such a function must be jittable; for a stationary PDE it is called
252
+ with the three arguments `x`, `u` and `params`. Therefore,
253
+ one can introduce spatio-temporal covariates upon which a particular
254
+ parameter can depend, e.g. in a GLM fashion. The effect of these
255
+ covariables can themselves be estimated by being in `eq_params` too.
256
+ Some keys can be missing; in this case there is no heterogeneity (=None).
257
+ If eq_params_heterogeneity is None, there is no heterogeneity
258
+ for any parameter.
120
259
  """
121
260
 
122
- def __init__(self, eq_params_heterogeneity=None):
123
- """
261
+ _eq_type: ClassVar[str] = "Statio PDE"
262
+
263
+ @partial(_decorator_heteregeneous_params, eq_type="Statio PDE")
264
+ def evaluate(
265
+ self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
266
+ ) -> float:
267
+ """Here we call the DynamicLoss._evaluate with t=None"""
268
+ return self._evaluate(None, x, u, params)
269
+
270
+ @abc.abstractmethod
271
+ def equation(
272
+ self, x: Float[Array, "d"], u: eqx.Module, params: Params | ParamsDict
273
+ ) -> float:
274
+ r"""The differential operator defining the stationary PDE.
275
+
276
+ !!! warning
277
+
278
+ This is an abstract method to be implemented by users.
279
+
124
280
  Parameters
125
281
  ----------
126
- eq_params_heterogeneity
127
- Default None. A dict with the keys being the same as in eq_params
128
- and the value being `time`, `space`, `both` or None which corresponds to
129
- the heterogeneity of a given parameter. A value can be missing, in
130
- this case there is no heterogeneity (=None). If
131
- eq_params_heterogeneity is None this means there is no
132
- heterogeneity for no parameters.
133
- """
134
- super().__init__(eq_params_heterogeneity=eq_params_heterogeneity)
282
+ x : Float[Array, "d"]
283
+ A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
284
+ u : eqx.Module
285
+ The neural network.
286
+ params : Params | ParamsDict
287
+ The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.
135
288
 
136
- def eval_heterogeneous_parameters(self, x, u, params, eq_params_heterogeneity=None):
137
- return super()._eval_heterogeneous_parameters(
138
- None, x, u, params, eq_params_heterogeneity
139
- )
289
+ Returns
290
+ -------
291
+ float
292
+ The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](x)$ evaluated at point `x`.
140
293
 
141
- @staticmethod
142
- def evaluate_heterogeneous_parameters(evaluate):
143
- """
144
- Decorator which aims to decorate the evaluate methods of Dynamic losses
145
- in order. It calls _eval_heterogeneous_parameters which applies the
146
- user defined rules to obtain spatially / temporally heterogeneous
147
- parameters
294
+ Raises
295
+ ------
296
+ NotImplementedError
297
+ This is an abstract method to be implemented.
148
298
  """
149
-
150
- def wrapper(*args):
151
- self, x, u, params = args
152
- # avoid side effect with in-place modif of param["eq_params"]
153
- # TODO NamedTuple for params and use _replace() see Issue 1
154
- _params = {
155
- "nn_params": params["nn_params"],
156
- "eq_params": self.eval_heterogeneous_parameters(
157
- x, u, params, self.eq_params_heterogeneity
158
- ),
159
- }
160
- new_args = args[:-1] + (_params,)
161
- res = evaluate(*new_args)
162
- return res
163
-
164
- return wrapper
299
+ raise NotImplementedError
165
300
 
166
301
 
167
302
  class PDENonStatio(DynamicLoss):
168
- r"""
169
- Abstract base class for PDE Non statio dynamic losses
170
303
  """
304
+ Abstract base class for non-stationary PDE dynamic losses. All non-stationary PDE dynamic losses must subclass this class and override the abstract method `equation`.
305
+
306
+ Attributes
307
+ ----------
308
+ Tmax : float, default=1
309
+ Tmax needs to be given when the PINN time input is normalized in
310
+ [0, 1], ie. we have performed renormalization of the differential
311
+ equation
312
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
313
+ Default None. A dict with the keys being the same as in eq_params
314
+ and the value being either None (no heterogeneity) or a function
315
+ which encodes for the spatio-temporal heterogeneity of the parameter.
316
+ Such a function must be jittable and take four arguments `t`, `x`,
317
+ `u` and `params` even if one is not used. Therefore,
318
+ one can introduce spatio-temporal covariates upon which a particular
319
+ parameter can depend, e.g. in a GLM fashion. The effect of these
320
+ covariables can themselves be estimated by being in `eq_params` too.
321
+ Some keys can be missing; in this case there is no heterogeneity (=None).
322
+ If eq_params_heterogeneity is None, there is no heterogeneity
323
+ for any parameter.
324
+ """
325
+
326
+ _eq_type: ClassVar[str] = "Non-statio PDE"
327
+
328
+ @partial(_decorator_heteregeneous_params, eq_type="Non-statio PDE")
329
+ def evaluate(
330
+ self,
331
+ t: Float[Array, "1"],
332
+ x: Float[Array, "dim"],
333
+ u: eqx.Module,
334
+ params: Params | ParamsDict,
335
+ ) -> float:
336
+ """Here we call the DynamicLoss._evaluate with full arguments"""
337
+ ans = self._evaluate(t, x, u, params)
338
+ return ans
339
+
340
+ @abc.abstractmethod
341
+ def equation(
342
+ self,
343
+ t: Float[Array, "1"],
344
+ x: Float[Array, "dim"],
345
+ u: eqx.Module,
346
+ params: Params | ParamsDict,
347
+ ) -> float:
348
+ r"""The differential operator defining the non-stationary PDE.
349
+
350
+ !!! warning
351
+
352
+ This is an abstract method to be implemented by users.
171
353
 
172
- def __init__(self, Tmax=None, eq_params_heterogeneity=None):
173
- """
174
354
  Parameters
175
355
  ----------
176
- Tmax
177
- Tmax needs to be given when the PINN time input is normalized in
178
- [0, 1], ie. we have performed renormalization of the differential
179
- equation
180
- eq_params_heterogeneity
181
- Default None. A dict with the keys being the same as in eq_params
182
- and the value being `time`, `space`, `both` or None which corresponds to
183
- the heterogeneity of a given parameter. A value can be missing, in
184
- this case there is no heterogeneity (=None). If
185
- eq_params_heterogeneity is None this means there is no
186
- heterogeneity for no parameters.
187
- """
188
- super().__init__(Tmax, eq_params_heterogeneity)
356
+ t : Float[Array, "1"]
357
+ A 1-dimensional jnp.array representing the time point.
358
+ x : Float[Array, "d"]
359
+ A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
360
+ u : eqx.Module
361
+ The neural network.
362
+ params : Params | ParamsDict
363
+ The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.
364
+ Returns
365
+ -------
366
+ float
367
+ The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t, x)$ evaluated at point `(t, x)`.
189
368
 
190
- def eval_heterogeneous_parameters(
191
- self, t, x, u, params, eq_params_heterogeneity=None
192
- ):
193
- return super()._eval_heterogeneous_parameters(
194
- t, x, u, params, eq_params_heterogeneity
195
- )
196
369
 
197
- @staticmethod
198
- def evaluate_heterogeneous_parameters(evaluate):
199
- """
200
- Decorator which aims to decorate the evaluate methods of Dynamic losses
201
- in order. It calls _eval_heterogeneous_parameters which applies the
202
- user defined rules to obtain spatially / temporally heterogeneous
203
- parameters
370
+ Raises
371
+ ------
372
+ NotImplementedError
373
+ This is an abstract method to be implemented.
204
374
  """
205
-
206
- def wrapper(*args):
207
- self, t, x, u, params = args
208
- # avoid side effect with in-place modif of param["eq_params"]
209
- # TODO NamedTuple for params and use _replace() see Issue 1
210
- _params = {
211
- "nn_params": params["nn_params"],
212
- "eq_params": self.eval_heterogeneous_parameters(
213
- t, x, u, params, self.eq_params_heterogeneity
214
- ),
215
- }
216
- new_args = args[:-1] + (_params,)
217
- res = evaluate(*new_args)
218
- return res
219
-
220
- return wrapper
375
+ raise NotImplementedError