jinns-1.0.0-py3-none-any.whl → jinns-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,50 +2,468 @@
  Formalize the data structure for the derivative keys
  """
 
- from dataclasses import fields
+ from functools import partial
+ from dataclasses import fields, InitVar
  from typing import Literal
  import jax
  import equinox as eqx
 
- from jinns.parameters._params import Params
+ from jinns.parameters._params import Params, ParamsDict
+
+
+ def _get_masked_parameters(
+     derivative_mask_str: str, params: Params | ParamsDict
+ ) -> Params | ParamsDict:
+     """
+     Creates the Params object with True values where we want to differentiate
+     """
+     if isinstance(params, Params):
+         # start with a params object with True everywhere. We will update to
+         # False for the parameters wrt which we do not want to differentiate
+         # the loss
+         diff_params = jax.tree.map(
+             lambda x: True,
+             params,
+             is_leaf=lambda x: isinstance(x, eqx.Module)
+             and not isinstance(x, Params),  # do not traverse nn_params, more
+             # granularity could be imagined here, in the future
+         )
+         if derivative_mask_str == "both":
+             return diff_params
+         if derivative_mask_str == "eq_params":
+             return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
+         if derivative_mask_str == "nn_params":
+             return eqx.tree_at(
+                 lambda p: p.eq_params,
+                 diff_params,
+                 jax.tree.map(lambda x: False, params.eq_params),
+             )
+         raise ValueError(
+             "Bad value for DerivativeKeys. Got "
+             f'{derivative_mask_str}, expected "both", "nn_params" or '
+             '"eq_params"'
+         )
+     elif isinstance(params, ParamsDict):
+         # do not traverse nn_params, more
+         # granularity could be imagined here, in the future
+         diff_params = ParamsDict(
+             nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
+         )
+         if derivative_mask_str == "both":
+             return diff_params
+         if derivative_mask_str == "eq_params":
+             return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
+         if derivative_mask_str == "nn_params":
+             return eqx.tree_at(
+                 lambda p: p.eq_params,
+                 diff_params,
+                 jax.tree.map(lambda x: False, params.eq_params),
+             )
+         raise ValueError(
+             "Bad value for DerivativeKeys. Got "
+             f'{derivative_mask_str}, expected "both", "nn_params" or '
+             '"eq_params"'
+         )
+     else:
+         raise ValueError(
+             f"Bad value for params. Got {type(params)}, expected Params "
+             "or ParamsDict"
+         )
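A minimal sketch of what the new helper produces, with `_get_masked_parameters` as defined in the hunk above. The setup is hypothetical: an equinox network plus two equation parameters named `alpha` and `beta` (names invented for the example; the `Params` import path follows the module header):

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from jinns.parameters._params import Params

# hypothetical problem parameters: an equinox module for the network plus
# two scalar equation parameters
params = Params(
    nn_params=eqx.nn.Linear(1, 1, key=jax.random.PRNGKey(0)),
    eq_params={"alpha": jnp.array(1.0), "beta": jnp.array(0.5)},
)

mask = _get_masked_parameters("eq_params", params)
# mask.nn_params is False and mask.eq_params == {"alpha": True, "beta": True};
# "nn_params" gives the opposite mask, and "both" keeps True everywhere
```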
 
 
  class DerivativeKeysODE(eqx.Module):
-     # we use static = True because all fields are string, hence should be
-     # invisible by JAX transforms (JIT, etc.)
-     dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
-     )
-     observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
+     """
+     A class that specifies with respect to which parameter(s) each term of
+     the loss is differentiated. For example, you can specify that the
+     [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both
+     with respect to the neural network parameters *and* the equation
+     parameters, or only some of them.
+
+     To do so, the user can either use strings or a `Params` object whose
+     PyTree structure matches the parameters of the problem at hand, with
+     booleans indicating whether the gradient is to be taken or not.
+     Internally, a `jax.lax.stop_gradient()` is appropriately set on each
+     `False` node when computing each loss term.
+
+     !!! note
+
+         1. For each unspecified loss term, the default is to differentiate
+            with respect to `"nn_params"` only.
+         2. No granularity inside `Params.nn_params` is currently supported.
+         3. The main Params or ParamsDict object of the problem is mandatory
+            when initializing via `from_str()`.
+
+     A typical specification is of the form:
+     ```python
+     Params(
+         nn_params=True | False,
+         eq_params={
+             "alpha": True | False,
+             "beta": True | False,
+             ...
+         }
      )
-     initial_condition: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
+     ```
+
+     Parameters
+     ----------
+     dyn_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         dynamic loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     observations : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         observation loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     initial_condition : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         initial condition loss. To do so, the fields of `Params` contain
+         True (if differentiation) or False (if no differentiation).
+     params : InitVar[Params | ParamsDict], default=None
+         The main Params object of the problem. It is required
+         if some terms are unspecified (None). This is because jinns cannot
+         infer the content of `Params.eq_params`.
+     """
+
+     dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+     observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+     initial_condition: Params | ParamsDict | None = eqx.field(
+         kw_only=True, default=None
      )
 
+     params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
+
+     def __post_init__(self, params=None):
+         if self.dyn_loss is None:
+             try:
+                 self.dyn_loss = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.dyn_loss is None, hence params should be passed"
+                 )
+         if self.observations is None:
+             try:
+                 self.observations = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.observations is None, hence params should be passed"
+                 )
+         if self.initial_condition is None:
+             try:
+                 self.initial_condition = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.initial_condition is None, hence params should be passed"
+                 )
+
+     @classmethod
+     def from_str(
+         cls,
+         params: Params | ParamsDict,
+         dyn_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         observations: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         initial_condition: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+     ):
+         """
+         Construct the DerivativeKeysODE from strings. For each term of the
+         loss, specify whether to differentiate wrt the neural network
+         parameters, the equation parameters or both. The `Params` object,
+         which contains the actual arrays of parameters, must be passed to
+         construct the fields with the appropriate PyTree structure.
+
+         !!! note
+             You can mix strings and `Params` if you need granularity.
+
+         Parameters
+         ----------
+         params
+             The actual Params or ParamsDict object of the problem.
+         dyn_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the dynamic loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         observations
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the observations. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         initial_condition
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the initial condition. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         """
+         return DerivativeKeysODE(
+             dyn_loss=(
+                 _get_masked_parameters(dyn_loss, params)
+                 if isinstance(dyn_loss, str)
+                 else dyn_loss
+             ),
+             observations=(
+                 _get_masked_parameters(observations, params)
+                 if isinstance(observations, str)
+                 else observations
+             ),
+             initial_condition=(
+                 _get_masked_parameters(initial_condition, params)
+                 if isinstance(initial_condition, str)
+                 else initial_condition
+             ),
+         )
+
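The two construction routes the docstring describes, sketched with the same hypothetical `params` as in the previous snippet. The `jinns.parameters` import path is an assumption based on the docstring cross-references:

```python
from jinns.parameters import DerivativeKeysODE, Params

# string route: differentiate the dynamic loss wrt both parameter groups;
# unspecified terms fall back to "nn_params"
keys = DerivativeKeysODE.from_str(params=params, dyn_loss="both")

# explicit-mask route: per-parameter granularity on eq_params
keys = DerivativeKeysODE(
    dyn_loss=Params(nn_params=True, eq_params={"alpha": True, "beta": False}),
    params=params,  # still needed so the unspecified terms can be defaulted
)
```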
 
 
  class DerivativeKeysPDEStatio(eqx.Module):
+     """
+     See [jinns.parameters.DerivativeKeysODE][].
 
-     dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
-     )
-     observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
-     )
-     boundary_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
-     )
-     norm_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-         kw_only=True, default="nn_params", static=True
-     )
+     Parameters
+     ----------
+     dyn_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         dynamic loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     observations : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         observation loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     boundary_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         boundary loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     norm_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         normalization loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     params : InitVar[Params | ParamsDict], default=None
+         The main Params object of the problem. It is required
+         if some terms are unspecified (None). This is because jinns cannot
+         infer the content of `Params.eq_params`.
+     """
+
+     dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+     observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+     boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+     norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+
+     params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
+
+     def __post_init__(self, params=None):
+         if self.dyn_loss is None:
+             try:
+                 self.dyn_loss = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError("self.dyn_loss is None, hence params should be passed")
+         if self.observations is None:
+             try:
+                 self.observations = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.observations is None, hence params should be passed"
+                 )
+         if self.boundary_loss is None:
+             try:
+                 self.boundary_loss = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.boundary_loss is None, hence params should be passed"
+                 )
+         if self.norm_loss is None:
+             try:
+                 self.norm_loss = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.norm_loss is None, hence params should be passed"
+                 )
+
+     @classmethod
+     def from_str(
+         cls,
+         params: Params | ParamsDict,
+         dyn_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         observations: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         boundary_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         norm_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+     ):
+         """
+         See [jinns.parameters.DerivativeKeysODE.from_str][].
+
+         Parameters
+         ----------
+         params
+             The actual Params or ParamsDict object of the problem.
+         dyn_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the dynamic loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         observations
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the observations. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         boundary_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the boundary loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         norm_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the normalization loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         """
+         return DerivativeKeysPDEStatio(
+             dyn_loss=(
+                 _get_masked_parameters(dyn_loss, params)
+                 if isinstance(dyn_loss, str)
+                 else dyn_loss
+             ),
+             observations=(
+                 _get_masked_parameters(observations, params)
+                 if isinstance(observations, str)
+                 else observations
+             ),
+             boundary_loss=(
+                 _get_masked_parameters(boundary_loss, params)
+                 if isinstance(boundary_loss, str)
+                 else boundary_loss
+             ),
+             norm_loss=(
+                 _get_masked_parameters(norm_loss, params)
+                 if isinstance(norm_loss, str)
+                 else norm_loss
+             ),
+         )
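A sketch of the "mix strings and `Params`" note applied to the stationary PDE case, under the same hypothetical `alpha`/`beta` setup as above:

```python
from jinns.parameters import DerivativeKeysPDEStatio, Params

keys = DerivativeKeysPDEStatio.from_str(
    params=params,
    dyn_loss="both",  # a plain string for this term
    norm_loss=Params(  # an explicit mask where granularity is needed
        nn_params=True, eq_params={"alpha": False, "beta": True}
    ),
    # observations and boundary_loss fall back to "nn_params"
)
```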
 
 
  class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
+     """
+     See [jinns.parameters.DerivativeKeysODE][].
 
-     initial_condition: Literal["nn_params", "eq_params", "both"] = eqx.field(
-         kw_only=True, default="nn_params", static=True
+     Parameters
+     ----------
+     dyn_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         dynamic loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     observations : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         observation loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     boundary_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         boundary loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     norm_loss : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         normalization loss. To do so, the fields of `Params` contain True (if
+         differentiation) or False (if no differentiation).
+     initial_condition : Params | ParamsDict | None, default=None
+         Tells wrt which nodes of `Params` we will differentiate the
+         initial condition loss. To do so, the fields of `Params` contain
+         True (if differentiation) or False (if no differentiation).
+     params : InitVar[Params | ParamsDict], default=None
+         The main Params object of the problem. It is required
+         if some terms are unspecified (None). This is because jinns cannot
+         infer the content of `Params.eq_params`.
+     """
+
+     initial_condition: Params | ParamsDict | None = eqx.field(
+         kw_only=True, default=None
      )
 
+     def __post_init__(self, params=None):
+         super().__post_init__(params=params)
+         if self.initial_condition is None:
+             try:
+                 self.initial_condition = _get_masked_parameters("nn_params", params)
+             except AttributeError:
+                 raise ValueError(
+                     "self.initial_condition is None, hence params should be passed"
+                 )
+
+     @classmethod
+     def from_str(
+         cls,
+         params: Params | ParamsDict,
+         dyn_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         observations: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         boundary_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         norm_loss: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+         initial_condition: (
+             Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+         ) = "nn_params",
+     ):
+         """
+         See [jinns.parameters.DerivativeKeysODE.from_str][].
+
+         Parameters
+         ----------
+         params
+             The actual Params or ParamsDict object of the problem.
+         dyn_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the dynamic loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         observations
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the observations. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         boundary_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the boundary loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         norm_loss
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the normalization loss. Default is
+             `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+         initial_condition
+             Tells wrt which parameters among `"nn_params"`, `"eq_params"` or
+             `"both"` we will differentiate the initial condition loss.
+             Default is `"nn_params"`. Specifying a Params or ParamsDict is
+             also possible.
+         """
+         return DerivativeKeysPDENonStatio(
+             dyn_loss=(
+                 _get_masked_parameters(dyn_loss, params)
+                 if isinstance(dyn_loss, str)
+                 else dyn_loss
+             ),
+             observations=(
+                 _get_masked_parameters(observations, params)
+                 if isinstance(observations, str)
+                 else observations
+             ),
+             boundary_loss=(
+                 _get_masked_parameters(boundary_loss, params)
+                 if isinstance(boundary_loss, str)
+                 else boundary_loss
+             ),
+             norm_loss=(
+                 _get_masked_parameters(norm_loss, params)
+                 if isinstance(norm_loss, str)
+                 else norm_loss
+             ),
+             initial_condition=(
+                 _get_masked_parameters(initial_condition, params)
+                 if isinstance(initial_condition, str)
+                 else initial_condition
+             ),
+         )
+
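For the non-stationary case, the only addition over `DerivativeKeysPDEStatio` is the `initial_condition` term; a sketch under the same assumptions:

```python
from jinns.parameters import DerivativeKeysPDENonStatio

keys = DerivativeKeysPDENonStatio.from_str(
    params=params,
    dyn_loss="both",
    # differentiate the initial condition wrt the equation parameters only
    initial_condition="eq_params",
)
```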
 
 
  def _set_derivatives(params, derivative_keys):
      """
@@ -53,42 +471,51 @@ def _set_derivatives(params, derivative_keys):
      has a copy of the params with appropriate derivatives set
      """
 
-     def _set_derivatives_(loss_term_derivative):
-         if loss_term_derivative == "both":
-             return params
-         # the next line put a stop_gradient around the fields that do not
-         # appear in loss_term_derivative. Currently there are only two possible
-         # values nn_params and eq_params but there might be more in the future
-         return eqx.tree_at(
-             lambda p: tuple(
-                 getattr(p, f.name)
-                 for f in fields(Params)
-                 if f.name != loss_term_derivative
+     def _set_derivatives_ParamsDict(params_, derivative_mask):
+         """
+         The next lines put a stop_gradient around the fields wrt which the
+         loss term is not differentiated.
+
+         **Note:** **No granularity inside `ParamsDict.nn_params` is currently
+         supported.**
+         This means a typical ParamsDict specification is of the form:
+         `ParamsDict(nn_params=True | False, eq_params={"0": {"alpha": True | False,
+         "beta": True | False}, "1": {"alpha": True | False, "beta": True | False}})`.
+         """
+         # a ParamsDict object is reconstructed by hand since we do not want
+         # to traverse nn_params, for now...
+         return ParamsDict(
+             nn_params=jax.lax.cond(
+                 derivative_mask.nn_params,
+                 lambda p: p,
+                 jax.lax.stop_gradient,
+                 params_.nn_params,
+             ),
+             eq_params=jax.tree.map(
+                 lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
+                 params_.eq_params,
+                 derivative_mask.eq_params,
              ),
-             params,
-             replace_fn=jax.lax.stop_gradient,
          )
 
-     def _set_derivatives_dict(loss_term_derivative):
-         if loss_term_derivative == "both":
-             return params
-         # the next line put a stop_gradient around the fields that do not
-         # appear in loss_term_derivative. Currently there are only two possible
-         # values nn_params and eq_params but there might be more in the future
-         return {
-             k: eqx.tree_at(
-                 lambda p: tuple(
-                     getattr(p, f.name)
-                     for f in fields(Params)
-                     if f.name != loss_term_derivative
-                 ),
-                 params_,
-                 replace_fn=jax.lax.stop_gradient,
-             )
-             for k, params_ in params
-         }
+     def _set_derivatives_(params_, derivative_mask):
+         """
+         The next lines put a stop_gradient around the fields wrt which the
+         loss term is not differentiated.
+
+         **Note:** **No granularity inside `Params.nn_params` is currently
+         supported.**
+         This means a typical Params specification is of the form:
+         `Params(nn_params=True | False, eq_params={"alpha": True | False,
+         "beta": True | False})`.
+         """
+         return jax.tree.map(
+             lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
+             params_,
+             derivative_mask,
+             is_leaf=lambda x: isinstance(x, eqx.Module)
+             and not isinstance(x, Params),  # do not traverse nn_params, more
+             # granularity could be imagined here, in the future
+         )
 
-     if not isinstance(params, dict):
-         return _set_derivatives_(derivative_keys)
-     else:
-         return _set_derivatives_dict(derivative_keys)
+     if isinstance(params, ParamsDict):
+         return _set_derivatives_ParamsDict(params, derivative_keys)
+     return _set_derivatives_(params, derivative_keys)
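Finally, a sketch of what `_set_derivatives` does with such a mask on the `Params` branch, continuing the hypothetical `alpha`/`beta` setup from the first snippet: leaves masked with False are wrapped in `jax.lax.stop_gradient`, so `jax.grad` only sees the True leaves.

```python
import jax
from jinns.parameters._params import Params

mask = Params(nn_params=True, eq_params={"alpha": True, "beta": False})

def loss(p):
    frozen = _set_derivatives(p, mask)
    # any loss built from `frozen` has zero gradient wrt eq_params["beta"]
    return frozen.eq_params["alpha"] * frozen.eq_params["beta"]

grads = jax.grad(loss)(params)
# grads.eq_params["beta"] == 0.0 while grads.eq_params["alpha"] equals
# params.eq_params["beta"]
```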