jinns 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +321 -168
  9. jinns/loss/_LossODE.py +292 -309
  10. jinns/loss/_LossPDE.py +625 -1010
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +95 -44
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +94 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/METADATA +4 -4
  33. jinns-1.0.0.dist-info/RECORD +38 -0
  34. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/WHEEL +1 -1
  35. jinns/experimental/_sinuspinn.py +0 -135
  36. jinns/experimental/_spectralpinn.py +0 -87
  37. jinns/solver/_seq2seq.py +0 -157
  38. jinns/utils/_optim.py +0 -147
  39. jinns/utils/_utils_uspinn.py +0 -727
  40. jinns-0.9.0.dist-info/RECORD +0 -36
  41. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/LICENSE +0 -0
  42. {jinns-0.9.0.dist-info → jinns-1.0.0.dist-info}/top_level.txt +0 -0
@@ -2,219 +2,372 @@
2
2
  Implements abstract classes for dynamic losses
3
3
  """
4
4
 
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+
9
+ import equinox as eqx
10
+ from typing import Callable, Dict, TYPE_CHECKING, ClassVar
11
+ from jaxtyping import Float, Array
12
+ from functools import partial
13
+ import abc
14
+
15
+
16
+ # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
17
+ if TYPE_CHECKING:
18
+ from typing import ClassVar as AbstractClassVar
19
+ from jinns.parameters import Params, ParamsDict
20
+ else:
21
+ from equinox import AbstractClassVar
5
22
 
6
- class DynamicLoss:
7
- r"""
8
- Abstract base class for dynamic losses whose aim is to implement the term:
9
23
 
10
- .. math::
24
+ def _decorator_heteregeneous_params(evaluate, eq_type):
25
+
26
+ def wrapper_ode(*args):
27
+ self, t, u, params = args
28
+ _params = eqx.tree_at(
29
+ lambda p: p.eq_params,
30
+ params,
31
+ self._eval_heterogeneous_parameters(
32
+ t, None, u, params, self.eq_params_heterogeneity
33
+ ),
34
+ )
35
+ new_args = args[:-1] + (_params,)
36
+ res = evaluate(*new_args)
37
+ return res
38
+
39
+ def wrapper_pde_statio(*args):
40
+ self, x, u, params = args
41
+ _params = eqx.tree_at(
42
+ lambda p: p.eq_params,
43
+ params,
44
+ self._eval_heterogeneous_parameters(
45
+ None, x, u, params, self.eq_params_heterogeneity
46
+ ),
47
+ )
48
+ new_args = args[:-1] + (_params,)
49
+ res = evaluate(*new_args)
50
+ return res
51
+
52
+ def wrapper_pde_non_statio(*args):
53
+ self, t, x, u, params = args
54
+ _params = eqx.tree_at(
55
+ lambda p: p.eq_params,
56
+ params,
57
+ self._eval_heterogeneous_parameters(
58
+ t, x, u, params, self.eq_params_heterogeneity
59
+ ),
60
+ )
61
+ new_args = args[:-1] + (_params,)
62
+ res = evaluate(*new_args)
63
+ return res
64
+
65
+ if eq_type == "ODE":
66
+ return wrapper_ode
67
+ elif eq_type == "Statio PDE":
68
+ return wrapper_pde_statio
69
+ elif eq_type == "Non-statio PDE":
70
+ return wrapper_pde_non_statio
71
+
72
+
73
+ class DynamicLoss(eqx.Module):
74
+ r"""
75
+ Abstract base class for dynamic losses. Implements the physical term:
76
+ $$
11
77
  \mathcal{N}[u](t, x) = 0
78
+ $$
79
+ for **one** point $t$, $x$ or $(t, x)$, depending on the context.
80
+
81
+ Parameters
82
+ ----------
83
+ Tmax : Float, default=1
84
+ Tmax needs to be given when the PINN time input is normalized in
85
+ [0, 1], ie. we have performed renormalization of the differential
86
+ equation
87
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
88
+ A dict with the same keys as eq_params and the value being either None
89
+ (no heterogeneity) or a function which encodes for the spatio-temporal
90
+ heterogeneity of the parameter.
91
+ Such a function must be jittable and take four arguments `t`, `x`,
92
+ `u` and `params` even if some are not used. Therefore,
93
+ one can introduce spatio-temporal covariates upon which a particular
94
+ parameter can depend, e.g. in a Generalized Linear Model fashion. The
95
+ effect of these covariates can themselves be estimated by being in
96
+ `eq_params` too.
97
+ A value can be missing, in this case there is no heterogeneity (=None).
98
+ Default None, meaning there is no heterogeneity in the equation
99
+ parameters.
12
100
  """
13
101
 
14
- def __init__(self, Tmax=None, eq_params_heterogeneity=None):
15
- """
16
- Parameters
17
- ----------
18
- Tmax
19
- Tmax needs to be given when the PINN time input is normalized in
20
- [0, 1], ie. we have performed renormalization of the differential
21
- equation
22
- eq_params_heterogeneity
23
- Default None. A dict with the keys being the same as in eq_params
24
- and the value being either None (no heterogeneity) or a function
25
- which encodes for the spatio-temporal heterogeneity of the parameter.
26
- Such a function must be jittable and take four arguments `t`, `x`,
27
- `u` and `params` even if one is not used. Therefore,
28
- one can introduce spatio-temporal covariates upon which a particular
29
- parameter can depend, e.g. in a GLM fashion. The effect of these
30
- covariables can themselves be estimated by being in `eq_params` too.
31
- A value can be missing, in this case there is no heterogeneity (=None).
32
- If eq_params_heterogeneity is None this means there is no
33
- heterogeneity for no parameters.
34
- """
35
- self.Tmax = Tmax
36
- self.eq_params_heterogeneity = eq_params_heterogeneity
102
+ _eq_type = AbstractClassVar[str] # class variable denoting the type of
103
+ # differential equation
104
+ Tmax: Float = eqx.field(kw_only=True, default=1)
105
+ eq_params_heterogeneity: Dict[str, Callable | None] = eqx.field(
106
+ kw_only=True, default=None, static=True
107
+ )
37
108
 
38
- @staticmethod
39
- def _eval_heterogeneous_parameters(t, x, u, params, eq_params_heterogeneity=None):
109
+ def _eval_heterogeneous_parameters(
110
+ self,
111
+ t: Float[Array, "1"],
112
+ x: Float[Array, "dim"],
113
+ u: eqx.Module,
114
+ params: Params | ParamsDict,
115
+ eq_params_heterogeneity: Dict[str, Callable | None] = None,
116
+ ) -> Dict[str, float | Float[Array, "parameter_dimension"]]:
40
117
  eq_params_ = {}
41
118
  if eq_params_heterogeneity is None:
42
- return params["eq_params"]
43
- for k, p in params["eq_params"].items():
119
+ return params.eq_params
120
+ for k, p in params.eq_params.items():
44
121
  try:
45
122
  if eq_params_heterogeneity[k] is None:
46
123
  eq_params_[k] = p
47
124
  else:
48
- if t is None:
49
- eq_params_[k] = eq_params_heterogeneity[k](
50
- x, u, params # heterogeneity encoded through a function
51
- )
52
- else:
53
- eq_params_[k] = eq_params_heterogeneity[k](
54
- t, x, u, params # heterogeneity encoded through a function
55
- )
125
+ # heterogeneity encoded through a function whose
126
+ # signature will vary according to _eq_type
127
+ if self._eq_type == "ODE":
128
+ eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
129
+ elif self._eq_type == "Statio PDE":
130
+ eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
131
+ elif self._eq_type == "Non-statio PDE":
132
+ eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
56
133
  except KeyError:
57
134
  # we authorize missing eq_params_heterogeneity key
58
- # is its heterogeneity is None anyway
135
+ # if its heterogeneity is None anyway
59
136
  eq_params_[k] = p
60
137
  return eq_params_
61
138
 
139
+ def _evaluate(
140
+ self,
141
+ t: Float[Array, "1"],
142
+ x: Float[Array, "dim"],
143
+ u: eqx.Module,
144
+ params: Params | ParamsDict,
145
+ ) -> float:
146
+ # Here we handle the various possible signature
147
+ if self._eq_type == "ODE":
148
+ ans = self.equation(t, u, params)
149
+ elif self._eq_type == "Statio PDE":
150
+ ans = self.equation(x, u, params)
151
+ elif self._eq_type == "Non-statio PDE":
152
+ ans = self.equation(t, x, u, params)
153
+ else:
154
+ raise NotImplementedError("the equation type is not handled.")
155
+
156
+ return ans
157
+
158
+ @abc.abstractmethod
159
+ def equation(self, *args, **kwargs):
160
+ # TO IMPLEMENT
161
+ # Point-wise evaluation of the differential equation N[u](.)
162
+ raise NotImplementedError("You should implement your equation.")
163
+
62
164
 
63
165
  class ODE(DynamicLoss):
64
166
  r"""
65
- Abstract base class for ODE dynamic losses
167
+ Abstract base class for ODE dynamic losses. All dynamic loss must subclass
168
+ this class and override the abstract method `equation`.
169
+
170
+ Attributes
171
+ ----------
172
+ Tmax : float, default=1
173
+ Tmax needs to be given when the PINN time input is normalized in
174
+ [0, 1], ie. we have performed renormalization of the differential
175
+ equation
176
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
177
+ Default None. A dict with the keys being the same as in eq_params
178
+ and the value being either None (no heterogeneity) or a function
179
+ which encodes for the spatio-temporal heterogeneity of the parameter.
180
+ Such a function must be jittable and take four arguments `t`, `x`,
181
+ `u` and `params` even if one is not used. Therefore,
182
+ one can introduce spatio-temporal covariates upon which a particular
183
+ parameter can depend, e.g. in a GLM fashion. The effect of these
184
+ covariables can themselves be estimated by being in `eq_params` too.
185
+ Some key can be missing, in this case there is no heterogeneity (=None).
186
+ If eq_params_heterogeneity is None this means there is no
187
+ heterogeneity for no parameters.
66
188
  """
67
189
 
68
- def __init__(self, Tmax=None, eq_params_heterogeneity=None):
69
- """
190
+ _eq_type: ClassVar[str] = "ODE"
191
+
192
+ @partial(_decorator_heteregeneous_params, eq_type="ODE")
193
+ def evaluate(
194
+ self,
195
+ t: Float[Array, "1"],
196
+ u: eqx.Module | Dict[str, eqx.Module],
197
+ params: Params | ParamsDict,
198
+ ) -> float:
199
+ """Here we call DynamicLoss._evaluate with x=None"""
200
+ return self._evaluate(t, None, u, params)
201
+
202
+ @abc.abstractmethod
203
+ def equation(
204
+ self, t: Float[Array, "1"], u: eqx.Module, params: Params | ParamsDict
205
+ ) -> float:
206
+ r"""
207
+ The differential operator defining the ODE.
208
+
209
+ !!! warning
210
+
211
+ This is an abstract method to be implemented by users.
212
+
70
213
  Parameters
71
214
  ----------
72
- Tmax
73
- Tmax needs to be given when the PINN time input is normalized in
74
- [0, 1], ie. we have performed renormalization of the differential
75
- equation
76
- eq_params_heterogeneity
77
- Default None. A dict with the keys being the same as in eq_params
78
- and the value being `time`, `space`, `both` or None which corresponds to
79
- the heterogeneity of a given parameter. A value can be missing, in
80
- this case there is no heterogeneity (=None). If
81
- eq_params_heterogeneity is None this means there is no
82
- heterogeneity for no parameters.
83
- """
84
- super().__init__(Tmax, eq_params_heterogeneity)
215
+ t : Float[Array, "1"]
216
+ A 1-dimensional jnp.array representing the time point.
217
+ u : eqx.Module
218
+ The network with a call signature `u(t, params)`.
219
+ params : Params | ParamsDict
220
+ The equation and neural network parameters $\theta$ and $\nu$.
85
221
 
86
- def eval_heterogeneous_parameters(self, t, u, params, eq_params_heterogeneity=None):
87
- return super()._eval_heterogeneous_parameters(
88
- t, None, u, params, eq_params_heterogeneity
89
- )
222
+ Returns
223
+ -------
224
+ float
225
+ The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t)$ evaluated at point `t`.
90
226
 
91
- @staticmethod
92
- def evaluate_heterogeneous_parameters(evaluate):
93
- """
94
- Decorator which aims to decorate the evaluate methods of Dynamic losses
95
- in order. It calls _eval_heterogeneous_parameters which applies the
96
- user defined rules to obtain spatially / temporally heterogeneous
97
- parameters
227
+ Raises
228
+ ------
229
+ NotImplementedError
230
+ This is an abstract method to be implemented.
98
231
  """
99
-
100
- def wrapper(*args):
101
- self, t, u, params = args
102
- # avoid side effect with in-place modif of param["eq_params"]
103
- # TODO NamedTuple for params and use _replace() see Issue 1
104
- _params = {
105
- "nn_params": params["nn_params"],
106
- "eq_params": self.eval_heterogeneous_parameters(
107
- t, u, params, self.eq_params_heterogeneity
108
- ),
109
- }
110
- new_args = args[:-1] + (_params,)
111
- res = evaluate(*new_args)
112
- return res
113
-
114
- return wrapper
232
+ raise NotImplementedError
115
233
 
116
234
 
117
235
  class PDEStatio(DynamicLoss):
118
236
  r"""
119
- Abstract base class for PDE statio dynamic losses
237
+ Abstract base class for stationnary PDE dynamic losses. All dynamic loss must subclass this class and override the abstract method `equation`.
238
+
239
+ Attributes
240
+ ----------
241
+ Tmax : float, default=1
242
+ Tmax needs to be given when the PINN time input is normalized in
243
+ [0, 1], ie. we have performed renormalization of the differential
244
+ equation
245
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
246
+ Default None. A dict with the keys being the same as in eq_params
247
+ and the value being either None (no heterogeneity) or a function
248
+ which encodes for the spatio-temporal heterogeneity of the parameter.
249
+ Such a function must be jittable and take four arguments `t`, `x`,
250
+ `u` and `params` even if one is not used. Therefore,
251
+ one can introduce spatio-temporal covariates upon which a particular
252
+ parameter can depend, e.g. in a GLM fashion. The effect of these
253
+ covariables can themselves be estimated by being in `eq_params` too.
254
+ Some key can be missing, in this case there is no heterogeneity (=None).
255
+ If eq_params_heterogeneity is None this means there is no
256
+ heterogeneity for no parameters.
120
257
  """
121
258
 
122
- def __init__(self, eq_params_heterogeneity=None):
123
- """
259
+ _eq_type: ClassVar[str] = "Statio PDE"
260
+
261
+ @partial(_decorator_heteregeneous_params, eq_type="Statio PDE")
262
+ def evaluate(
263
+ self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
264
+ ) -> float:
265
+ """Here we call the DynamicLoss._evaluate with t=None"""
266
+ return self._evaluate(None, x, u, params)
267
+
268
+ @abc.abstractmethod
269
+ def equation(
270
+ self, x: Float[Array, "d"], u: eqx.Module, params: Params | ParamsDict
271
+ ) -> float:
272
+ r"""The differential operator defining the stationnary PDE.
273
+
274
+ !!! warning
275
+
276
+ This is an abstract method to be implemented by users.
277
+
124
278
  Parameters
125
279
  ----------
126
- eq_params_heterogeneity
127
- Default None. A dict with the keys being the same as in eq_params
128
- and the value being `time`, `space`, `both` or None which corresponds to
129
- the heterogeneity of a given parameter. A value can be missing, in
130
- this case there is no heterogeneity (=None). If
131
- eq_params_heterogeneity is None this means there is no
132
- heterogeneity for no parameters.
133
- """
134
- super().__init__(eq_params_heterogeneity=eq_params_heterogeneity)
280
+ x : Float[Array, "d"]
281
+ A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
282
+ u : eqx.Module
283
+ The neural network.
284
+ params : Params | ParamsDict
285
+ The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.
135
286
 
136
- def eval_heterogeneous_parameters(self, x, u, params, eq_params_heterogeneity=None):
137
- return super()._eval_heterogeneous_parameters(
138
- None, x, u, params, eq_params_heterogeneity
139
- )
287
+ Returns
288
+ -------
289
+ float
290
+ The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](x)$ evaluated at point `x`.
140
291
 
141
- @staticmethod
142
- def evaluate_heterogeneous_parameters(evaluate):
143
- """
144
- Decorator which aims to decorate the evaluate methods of Dynamic losses
145
- in order. It calls _eval_heterogeneous_parameters which applies the
146
- user defined rules to obtain spatially / temporally heterogeneous
147
- parameters
292
+ Raises
293
+ ------
294
+ NotImplementedError
295
+ This is an abstract method to be implemented.
148
296
  """
149
-
150
- def wrapper(*args):
151
- self, x, u, params = args
152
- # avoid side effect with in-place modif of param["eq_params"]
153
- # TODO NamedTuple for params and use _replace() see Issue 1
154
- _params = {
155
- "nn_params": params["nn_params"],
156
- "eq_params": self.eval_heterogeneous_parameters(
157
- x, u, params, self.eq_params_heterogeneity
158
- ),
159
- }
160
- new_args = args[:-1] + (_params,)
161
- res = evaluate(*new_args)
162
- return res
163
-
164
- return wrapper
297
+ raise NotImplementedError
165
298
 
166
299
 
167
300
  class PDENonStatio(DynamicLoss):
168
- r"""
169
- Abstract base class for PDE Non statio dynamic losses
170
301
  """
302
+ Abstract base class for non-stationnary PDE dynamic losses. All dynamic loss must subclass this class and override the abstract method `equation`.
303
+
304
+ Attributes
305
+ ----------
306
+ Tmax : float, default=1
307
+ Tmax needs to be given when the PINN time input is normalized in
308
+ [0, 1], ie. we have performed renormalization of the differential
309
+ equation
310
+ eq_params_heterogeneity : Dict[str, Callable | None], default=None
311
+ Default None. A dict with the keys being the same as in eq_params
312
+ and the value being either None (no heterogeneity) or a function
313
+ which encodes for the spatio-temporal heterogeneity of the parameter.
314
+ Such a function must be jittable and take four arguments `t`, `x`,
315
+ `u` and `params` even if one is not used. Therefore,
316
+ one can introduce spatio-temporal covariates upon which a particular
317
+ parameter can depend, e.g. in a GLM fashion. The effect of these
318
+ covariables can themselves be estimated by being in `eq_params` too.
319
+ Some key can be missing, in this case there is no heterogeneity (=None).
320
+ If eq_params_heterogeneity is None this means there is no
321
+ heterogeneity for no parameters.
322
+ """
323
+
324
+ _eq_type: ClassVar[str] = "Non-statio PDE"
325
+
326
+ @partial(_decorator_heteregeneous_params, eq_type="Non-statio PDE")
327
+ def evaluate(
328
+ self,
329
+ t: Float[Array, "1"],
330
+ x: Float[Array, "dim"],
331
+ u: eqx.Module,
332
+ params: Params | ParamsDict,
333
+ ) -> float:
334
+ """Here we call the DynamicLoss._evaluate with full arguments"""
335
+ ans = self._evaluate(t, x, u, params)
336
+ return ans
337
+
338
+ @abc.abstractmethod
339
+ def equation(
340
+ self,
341
+ t: Float[Array, "1"],
342
+ x: Float[Array, "dim"],
343
+ u: eqx.Module,
344
+ params: Params | ParamsDict,
345
+ ) -> float:
346
+ r"""The differential operator defining the non-stationnary PDE.
347
+
348
+ !!! warning
349
+
350
+ This is an abstract method to be implemented by users.
171
351
 
172
- def __init__(self, Tmax=None, eq_params_heterogeneity=None):
173
- """
174
352
  Parameters
175
353
  ----------
176
- Tmax
177
- Tmax needs to be given when the PINN time input is normalized in
178
- [0, 1], ie. we have performed renormalization of the differential
179
- equation
180
- eq_params_heterogeneity
181
- Default None. A dict with the keys being the same as in eq_params
182
- and the value being `time`, `space`, `both` or None which corresponds to
183
- the heterogeneity of a given parameter. A value can be missing, in
184
- this case there is no heterogeneity (=None). If
185
- eq_params_heterogeneity is None this means there is no
186
- heterogeneity for no parameters.
187
- """
188
- super().__init__(Tmax, eq_params_heterogeneity)
354
+ t : Float[Array, "1"]
355
+ A 1-dimensional jnp.array representing the time point.
356
+ x : Float[Array, "d"]
357
+ A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
358
+ u : eqx.Module
359
+ The neural network.
360
+ params : Params | ParamsDict
361
+ The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.
362
+ Returns
363
+ -------
364
+ float
365
+ The residual, *i.e.* the differential operator $\mathcal{N}_\theta[u_\nu](t, x)$ evaluated at point `(t, x)`.
189
366
 
190
- def eval_heterogeneous_parameters(
191
- self, t, x, u, params, eq_params_heterogeneity=None
192
- ):
193
- return super()._eval_heterogeneous_parameters(
194
- t, x, u, params, eq_params_heterogeneity
195
- )
196
367
 
197
- @staticmethod
198
- def evaluate_heterogeneous_parameters(evaluate):
199
- """
200
- Decorator which aims to decorate the evaluate methods of Dynamic losses
201
- in order. It calls _eval_heterogeneous_parameters which applies the
202
- user defined rules to obtain spatially / temporally heterogeneous
203
- parameters
368
+ Raises
369
+ ------
370
+ NotImplementedError
371
+ This is an abstract method to be implemented.
204
372
  """
205
-
206
- def wrapper(*args):
207
- self, t, x, u, params = args
208
- # avoid side effect with in-place modif of param["eq_params"]
209
- # TODO NamedTuple for params and use _replace() see Issue 1
210
- _params = {
211
- "nn_params": params["nn_params"],
212
- "eq_params": self.eval_heterogeneous_parameters(
213
- t, x, u, params, self.eq_params_heterogeneity
214
- ),
215
- }
216
- new_args = args[:-1] + (_params,)
217
- res = evaluate(*new_args)
218
- return res
219
-
220
- return wrapper
373
+ raise NotImplementedError