jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@ from jaxtyping import Float, Array, PyTree
|
|
|
16
16
|
import jax
|
|
17
17
|
import jax.numpy as jnp
|
|
18
18
|
from jinns.parameters._params import EqParams
|
|
19
|
+
from jinns.nn import SPINN
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
# See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
|
|
@@ -38,6 +39,7 @@ def _decorator_heteregeneous_params(evaluate):
|
|
|
38
39
|
self._eval_heterogeneous_parameters(
|
|
39
40
|
inputs, u, params, self.eq_params_heterogeneity
|
|
40
41
|
),
|
|
42
|
+
is_leaf=lambda x: x is None,
|
|
41
43
|
)
|
|
42
44
|
new_args = args[:-1] + (_params,)
|
|
43
45
|
res = evaluate(*new_args)
|
|
@@ -152,7 +154,7 @@ class DynamicLoss(eqx.Module, Generic[InputDim]):
|
|
|
152
154
|
"The output of dynamic loss must be vectorial, "
|
|
153
155
|
"i.e. of shape (d,) with d >= 1"
|
|
154
156
|
)
|
|
155
|
-
if len(evaluation.shape) > 1:
|
|
157
|
+
if len(evaluation.shape) > 1 and not isinstance(u, SPINN):
|
|
156
158
|
warnings.warn(
|
|
157
159
|
"Return value from DynamicLoss' equation has more "
|
|
158
160
|
"than one dimension. This is in general a mistake (probably from "
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -28,13 +28,13 @@ from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysOD
|
|
|
28
28
|
from jinns.loss._loss_weights import LossWeightsODE
|
|
29
29
|
from jinns.loss._abstract_loss import AbstractLoss
|
|
30
30
|
from jinns.loss._loss_components import ODEComponents
|
|
31
|
+
from jinns.loss import ODE
|
|
31
32
|
from jinns.parameters._params import Params
|
|
32
33
|
from jinns.data._Batchs import ODEBatch
|
|
33
34
|
|
|
34
35
|
if TYPE_CHECKING:
|
|
35
36
|
# imports only used in type hints
|
|
36
37
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
37
|
-
from jinns.loss import ODE
|
|
38
38
|
|
|
39
39
|
InitialConditionUser = (
|
|
40
40
|
tuple[Float[Array, " n_cond "], Float[Array, " n_cond dim"]]
|
|
@@ -47,7 +47,11 @@ if TYPE_CHECKING:
|
|
|
47
47
|
)
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
class LossODE(
|
|
50
|
+
class LossODE(
|
|
51
|
+
AbstractLoss[
|
|
52
|
+
LossWeightsODE, ODEBatch, ODEComponents[Array | None], DerivativeKeysODE
|
|
53
|
+
]
|
|
54
|
+
):
|
|
51
55
|
r"""Loss object for an ordinary differential equation
|
|
52
56
|
|
|
53
57
|
$$
|
|
@@ -57,44 +61,46 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
57
61
|
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
58
62
|
initial condition is $u(t_0)=u_0$.
|
|
59
63
|
|
|
60
|
-
|
|
61
64
|
Parameters
|
|
62
65
|
----------
|
|
63
66
|
u : eqx.Module
|
|
64
67
|
the PINN
|
|
65
|
-
dynamic_loss : ODE
|
|
68
|
+
dynamic_loss : tuple[ODE, ...] | ODE | None
|
|
66
69
|
the ODE dynamic part of the loss, basically the differential
|
|
67
70
|
operator $\mathcal{N}[u](t)$. Should implement a method
|
|
68
71
|
`dynamic_loss.evaluate(t, u, params)`.
|
|
69
72
|
Can be None in order to access only some part of the evaluate call.
|
|
70
|
-
loss_weights : LossWeightsODE, default=None
|
|
73
|
+
loss_weights : LossWeightsODE | None, default=None
|
|
71
74
|
The loss weights for the different terms: dynamic loss,
|
|
72
75
|
initial condition and possibly observations, if any.
|
|
73
76
|
Can be updated according to a specific algorithm. See
|
|
74
77
|
`update_weight_method`
|
|
75
|
-
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'], default=None
|
|
78
|
+
update_weight_method : Literal['soft_adapt', 'lr_annealing', 'ReLoBRaLo'] | None, default=None
|
|
76
79
|
Default is None meaning no update for loss weights. Otherwise a string
|
|
77
|
-
|
|
80
|
+
keep_initial_loss_weight_scales : bool, default=True
|
|
81
|
+
Only used if an update weight method is specified. It decides whether
|
|
82
|
+
the updated loss weights are multiplied by the initial `loss_weights`
|
|
83
|
+
passed by the user at initialization. This is useful to force some
|
|
84
|
+
scale difference between the adaptive loss weights even after the
|
|
85
|
+
update method is applied.
|
|
86
|
+
derivative_keys : DerivativeKeysODE | None, default=None
|
|
78
87
|
Specify which field of `params` should be differentiated for each
|
|
79
88
|
component of the total loss. Particularly useful for inverse problems.
|
|
80
89
|
Fields can be "nn_params", "eq_params" or "both". Those that should not
|
|
81
90
|
be updated will have a `jax.lax.stop_gradient` called on them. Default
|
|
82
91
|
is `"nn_params"` for each component of the loss.
|
|
83
|
-
initial_condition :
|
|
84
|
-
Float[Array, "n_cond "],
|
|
85
|
-
Float[Array, "n_cond dim"]
|
|
86
|
-
] |
|
|
87
|
-
tuple[int | float | Float[Array, " "],
|
|
88
|
-
int | float | Float[Array, " dim"]
|
|
89
|
-
], default=None
|
|
92
|
+
initial_condition : InitialConditionUser | None, default=None
|
|
90
93
|
Most of the time, a tuple of length 2 with initial condition $(t_0, u_0)$.
|
|
91
94
|
From jinns v1.5.1 we accept tuples of jnp arrays with shape (n_cond, 1) for t0 and (n_cond, dim) for u0. This is useful to include observed conditions at different time points, *e.g.* final conditions. It was designed to implement $\mathcal{L}^{aux}$ from _Systems biology informed deep learning for inferring parameters and hidden dynamics_, Alireza Yazdani et al., 2020
|
|
92
|
-
obs_slice : EllipsisType | slice, default=None
|
|
95
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
93
96
|
Slice object specifying the beginning/ending
|
|
94
97
|
slice of u output(s) that is observed. This is useful for
|
|
95
98
|
multidimensional PINN, with partially observed outputs.
|
|
96
99
|
Default is None (whole output is observed).
|
|
97
|
-
|
|
100
|
+
**Note**: If several observation datasets are passed, this argument needs to be set as a
|
|
101
|
+
tuple of slice objects (e.g. built with `jnp.s_`) with the same length as the number of
|
|
102
|
+
observation datasets
|
|
103
|
+
params : InitVar[Params[Array]] | None, default=None
|
|
98
104
|
The main Params object of the problem needed to instantiate the
|
|
99
105
|
DerivativeKeysODE if the latter is not specified.
|
|
100
106
|
Raises
|
|
@@ -106,23 +112,26 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
106
112
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
107
113
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
108
114
|
u: AbstractPINN
|
|
109
|
-
dynamic_loss: ODE | None
|
|
115
|
+
dynamic_loss: tuple[ODE | None, ...]
|
|
110
116
|
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
111
117
|
derivative_keys: DerivativeKeysODE
|
|
112
118
|
loss_weights: LossWeightsODE
|
|
113
119
|
initial_condition: InitialCondition | None
|
|
114
|
-
obs_slice: EllipsisType | slice = eqx.field(static=True)
|
|
120
|
+
obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
|
|
115
121
|
params: InitVar[Params[Array] | None]
|
|
116
122
|
|
|
117
123
|
def __init__(
|
|
118
124
|
self,
|
|
119
125
|
*,
|
|
120
126
|
u: AbstractPINN,
|
|
121
|
-
dynamic_loss: ODE | None,
|
|
127
|
+
dynamic_loss: tuple[ODE, ...] | ODE | None,
|
|
122
128
|
loss_weights: LossWeightsODE | None = None,
|
|
123
129
|
derivative_keys: DerivativeKeysODE | None = None,
|
|
124
130
|
initial_condition: InitialConditionUser | None = None,
|
|
125
|
-
obs_slice: EllipsisType | slice
|
|
131
|
+
obs_slice: tuple[EllipsisType | slice, ...]
|
|
132
|
+
| EllipsisType
|
|
133
|
+
| slice
|
|
134
|
+
| None = None,
|
|
126
135
|
params: Params[Array] | None = None,
|
|
127
136
|
**kwargs: Any, # this is for arguments for super()
|
|
128
137
|
):
|
|
@@ -131,10 +140,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
131
140
|
else:
|
|
132
141
|
self.loss_weights = loss_weights
|
|
133
142
|
|
|
134
|
-
super().__init__(loss_weights=self.loss_weights, **kwargs)
|
|
135
|
-
self.u = u
|
|
136
|
-
self.dynamic_loss = dynamic_loss
|
|
137
|
-
self.vmap_in_axes = (0,)
|
|
138
143
|
if derivative_keys is None:
|
|
139
144
|
# by default we only take gradient wrt nn_params
|
|
140
145
|
if params is None:
|
|
@@ -142,9 +147,28 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
142
147
|
"Problem at derivative_keys initialization "
|
|
143
148
|
f"received {derivative_keys=} and {params=}"
|
|
144
149
|
)
|
|
145
|
-
|
|
150
|
+
derivative_keys = DerivativeKeysODE(params=params)
|
|
151
|
+
|
|
152
|
+
super().__init__(
|
|
153
|
+
loss_weights=self.loss_weights,
|
|
154
|
+
derivative_keys=derivative_keys,
|
|
155
|
+
vmap_in_axes=(0,),
|
|
156
|
+
**kwargs,
|
|
157
|
+
)
|
|
158
|
+
self.u = u
|
|
159
|
+
if not isinstance(dynamic_loss, tuple):
|
|
160
|
+
self.dynamic_loss = (dynamic_loss,)
|
|
146
161
|
else:
|
|
147
|
-
self.
|
|
162
|
+
self.dynamic_loss = dynamic_loss
|
|
163
|
+
if self.update_weight_method is not None and jnp.any(
|
|
164
|
+
jnp.array(jax.tree.leaves(self.loss_weights)) == 0
|
|
165
|
+
):
|
|
166
|
+
warnings.warn(
|
|
167
|
+
"self.update_weight_method is activated while some loss "
|
|
168
|
+
"weights are zero. The update weight method will likely "
|
|
169
|
+
"update the zero weight to some non-zero value. Check that "
|
|
170
|
+
"this is the desired behaviour."
|
|
171
|
+
)
|
|
148
172
|
|
|
149
173
|
if initial_condition is None:
|
|
150
174
|
warnings.warn(
|
|
@@ -215,12 +239,21 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
215
239
|
self.initial_condition = (t0, u0)
|
|
216
240
|
|
|
217
241
|
if obs_slice is None:
|
|
218
|
-
self.obs_slice = jnp.s_[...]
|
|
242
|
+
self.obs_slice = (jnp.s_[...],)
|
|
243
|
+
elif not isinstance(obs_slice, tuple):
|
|
244
|
+
self.obs_slice = (obs_slice,)
|
|
219
245
|
else:
|
|
220
246
|
self.obs_slice = obs_slice
|
|
221
247
|
|
|
248
|
+
if self.loss_weights is None:
|
|
249
|
+
self.loss_weights = LossWeightsODE()
|
|
250
|
+
|
|
222
251
|
def evaluate_by_terms(
|
|
223
|
-
self,
|
|
252
|
+
self,
|
|
253
|
+
opt_params: Params[Array],
|
|
254
|
+
batch: ODEBatch,
|
|
255
|
+
*,
|
|
256
|
+
non_opt_params: Params[Array] | None = None,
|
|
224
257
|
) -> tuple[
|
|
225
258
|
ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
|
|
226
259
|
]:
|
|
@@ -231,15 +264,22 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
231
264
|
|
|
232
265
|
Parameters
|
|
233
266
|
---------
|
|
234
|
-
|
|
235
|
-
Parameters at which the loss is evaluated
|
|
267
|
+
opt_params
|
|
268
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
236
269
|
batch
|
|
237
270
|
Composed of a batch of points in the
|
|
238
271
|
domain, a batch of points in the domain
|
|
239
272
|
border and an optional additional batch of parameters (eg. for
|
|
240
273
|
metamodeling) and an optional additional batch of observed
|
|
241
274
|
inputs/outputs/parameters
|
|
275
|
+
non_opt_params
|
|
276
|
+
Parameters, which are not optimized, at which the loss is evaluated
|
|
242
277
|
"""
|
|
278
|
+
if non_opt_params is not None:
|
|
279
|
+
params = eqx.combine(opt_params, non_opt_params)
|
|
280
|
+
else:
|
|
281
|
+
params = opt_params
|
|
282
|
+
|
|
243
283
|
temporal_batch = batch.temporal_batch
|
|
244
284
|
|
|
245
285
|
# Retrieve the optional eq_params_batch
|
|
@@ -253,23 +293,29 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
253
293
|
cast(eqx.Module, batch.param_batch_dict), params
|
|
254
294
|
)
|
|
255
295
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
296
|
+
if self.dynamic_loss != (None,):
|
|
297
|
+
# Note, for the record, multiple dynamic losses
|
|
298
|
+
# have been introduced in MR 92
|
|
299
|
+
dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
300
|
+
jax.tree.map(
|
|
301
|
+
lambda d: lambda p: dynamic_loss_apply(
|
|
302
|
+
d.evaluate,
|
|
303
|
+
self.u,
|
|
304
|
+
temporal_batch,
|
|
305
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss),
|
|
306
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
307
|
+
),
|
|
308
|
+
self.dynamic_loss,
|
|
309
|
+
is_leaf=lambda x: isinstance(x, ODE), # do not traverse
|
|
310
|
+
# further than first level
|
|
266
311
|
)
|
|
267
312
|
)
|
|
268
313
|
else:
|
|
269
314
|
dyn_loss_fun = None
|
|
270
315
|
|
|
271
316
|
if self.initial_condition is not None:
|
|
272
|
-
# initial
|
|
317
|
+
# Note, for the record, multiple initial conditions for LossODEs
|
|
318
|
+
# have been introduced in MR 77
|
|
273
319
|
t0, u0 = self.initial_condition
|
|
274
320
|
|
|
275
321
|
# first construct the plain init loss no vmaping
|
|
@@ -304,34 +350,50 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
304
350
|
# if there is no parameter batch to vmap over we cannot call
|
|
305
351
|
# vmap because calling vmap must be done with at least one non
|
|
306
352
|
# None in_axes or out_axes
|
|
307
|
-
initial_condition_fun = initial_condition_fun_
|
|
353
|
+
initial_condition_fun = (initial_condition_fun_,)
|
|
308
354
|
else:
|
|
309
|
-
initial_condition_fun:
|
|
355
|
+
initial_condition_fun: (
|
|
356
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
357
|
+
) = (
|
|
310
358
|
lambda p: jnp.mean(
|
|
311
359
|
vmap(initial_condition_fun_, vmap_in_axes_params)(p)
|
|
312
|
-
)
|
|
360
|
+
),
|
|
313
361
|
)
|
|
362
|
+
# Note that since MR 92
|
|
363
|
+
# initial_condition_fun is formed as a tuple for
|
|
364
|
+
# consistency with dynamic and observation
|
|
365
|
+
# losses and more modularity for later
|
|
314
366
|
else:
|
|
315
367
|
initial_condition_fun = None
|
|
316
368
|
|
|
317
369
|
if batch.obs_batch_dict is not None:
|
|
318
|
-
#
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
lambda
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
370
|
+
# Note, for the record, multiple DGObs
|
|
371
|
+
# (leading to batch.obs_batch_dict being tuple | None)
|
|
372
|
+
# have been introduced in MR 92
|
|
373
|
+
if len(batch.obs_batch_dict) != len(self.obs_slice):
|
|
374
|
+
raise ValueError(
|
|
375
|
+
"There must be the same number of "
|
|
376
|
+
"observation datasets as the number of "
|
|
377
|
+
"obs_slice"
|
|
378
|
+
)
|
|
379
|
+
params_obs = jax.tree.map(
|
|
380
|
+
lambda d: update_eq_params(params, d["eq_params"]),
|
|
381
|
+
batch.obs_batch_dict,
|
|
382
|
+
is_leaf=lambda x: isinstance(x, dict),
|
|
383
|
+
)
|
|
384
|
+
obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
385
|
+
jax.tree.map(
|
|
386
|
+
lambda d, slice_: lambda po: observations_loss_apply(
|
|
387
|
+
self.u,
|
|
388
|
+
d["pinn_in"],
|
|
389
|
+
_set_derivatives(po, self.derivative_keys.observations),
|
|
390
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
391
|
+
d["val"],
|
|
392
|
+
slice_,
|
|
393
|
+
),
|
|
394
|
+
batch.obs_batch_dict,
|
|
334
395
|
self.obs_slice,
|
|
396
|
+
is_leaf=lambda x: isinstance(x, dict),
|
|
335
397
|
)
|
|
336
398
|
)
|
|
337
399
|
else:
|
|
@@ -339,33 +401,36 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
339
401
|
obs_loss_fun = None
|
|
340
402
|
|
|
341
403
|
# get the unweighted mses for each loss term as well as the gradients
|
|
342
|
-
all_funs: ODEComponents[Callable[[Params[Array]], Array] | None] = (
|
|
404
|
+
all_funs: ODEComponents[tuple[Callable[[Params[Array]], Array], ...] | None] = (
|
|
343
405
|
ODEComponents(dyn_loss_fun, initial_condition_fun, obs_loss_fun)
|
|
344
406
|
)
|
|
345
|
-
all_params: ODEComponents[Params[Array] | None] = ODEComponents(
|
|
346
|
-
params,
|
|
407
|
+
all_params: ODEComponents[tuple[Params[Array], ...] | None] = ODEComponents(
|
|
408
|
+
jax.tree.map(lambda l: params, dyn_loss_fun),
|
|
409
|
+
jax.tree.map(lambda l: params, initial_condition_fun),
|
|
410
|
+
params_obs,
|
|
347
411
|
)
|
|
348
412
|
|
|
349
|
-
# Note that the lambda functions below are with type: ignore just
|
|
350
|
-
# because the lambda are not type annotated, but there is no proper way
|
|
351
|
-
# to do this and we should assign the lambda to a type hinted variable
|
|
352
|
-
# before hand: this is not practical, let us not get mad at this
|
|
353
413
|
mses_grads = jax.tree.map(
|
|
354
414
|
self.get_gradients,
|
|
355
415
|
all_funs,
|
|
356
416
|
all_params,
|
|
357
417
|
is_leaf=lambda x: x is None,
|
|
358
418
|
)
|
|
359
|
-
|
|
419
|
+
# NOTE the is_leaf below is more complex since it must possibly traverse the tuple
|
|
420
|
+
# of dyn_loss and then stop (while also making sure it does not stop early when
|
|
421
|
+
# the tuple of dyn_loss is itself of length 2)
|
|
360
422
|
mses = jax.tree.map(
|
|
361
|
-
lambda leaf: leaf[0],
|
|
423
|
+
lambda leaf: leaf[0],
|
|
362
424
|
mses_grads,
|
|
363
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
425
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
426
|
+
and len(x) == 2
|
|
427
|
+
and isinstance(x[1], Params),
|
|
364
428
|
)
|
|
365
429
|
grads = jax.tree.map(
|
|
366
|
-
lambda leaf: leaf[1],
|
|
430
|
+
lambda leaf: leaf[1],
|
|
367
431
|
mses_grads,
|
|
368
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
432
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
433
|
+
and len(x) == 2
|
|
434
|
+
and isinstance(x[1], Params),
|
|
369
435
|
)
|
|
370
|
-
|
|
371
436
|
return mses, grads
|