jinns 1.6.1-py3-none-any.whl → 1.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
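The headline change is the new `jinns/solver/_solve_alternate.py` module (+885 lines, shown in full below), which exposes `jinns.solve_alternate()` for inverse problems. The following is a minimal call sketch based only on the signature and docstring visible in the diff; `loss`, `data` and `init_params` are placeholders assumed to come from an existing jinns setup, and the optimizer choices, iteration counts and the `"nu"` parameter name are illustrative, not part of the release:

import optax
import jinns
from jinns.parameters import Params

# Hypothetical objects from an existing jinns inverse problem:
# `loss` (AbstractLoss), `data` (AbstractDataGenerator), `init_params` (Params).
params, total_loss_values, *rest = jinns.solve_alternate(
    n_iter=100,  # number of alternate cycles
    optimizers=Params(
        nn_params=optax.adam(1e-3),          # optimizer for the network weights
        eq_params={"nu": optax.adam(1e-2)},  # one optimizer per physical parameter
    ),
    # gradient steps taken by each optimizer within one cycle
    n_iter_by_solver=Params(nn_params=50, eq_params={"nu": 10}),
    init_params=init_params,
    data=data,
    loss=loss,
    tracked_params=Params(nn_params=None, eq_params={"nu": True}),
)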
jinns/solver/_solve_alternate.py (new file)
@@ -0,0 +1,885 @@
"""
`jinns.solve_alternate()` to efficiently solve inverse problems
"""

from __future__ import annotations

import time
import operator
from dataclasses import fields
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, PRNGKeyArray, Float
import equinox as eqx

from jinns.parameters._params import Params
from jinns.solver._utils import (
    _init_stored_weights_terms,
    _init_stored_params,
    _get_break_fun,
    _loss_evaluate_and_gradient_step,
    _build_get_batch,
    _store_loss_and_params,
    _print_fn,
)
from jinns.utils._containers import (
    DataGeneratorContainer,
    OptimizationContainer,
    OptimizationExtraContainer,
    LossContainer,
    StoredObjectContainer,
)

if TYPE_CHECKING:
    from typing import Any
    from jinns.utils._types import AnyLossComponents
    from jinns.loss._abstract_loss import AbstractLoss
    from jinns.data._AbstractDataGenerator import AbstractDataGenerator
    from jinns.data._DataGeneratorObservations import DataGeneratorObservations
    from jinns.data._DataGeneratorParameter import DataGeneratorParameter


def solve_alternate(
    *,
    n_iter: int,
    optimizers: Params[optax.GradientTransformation],
    n_iter_by_solver: Params[int],
    init_params: Params[Array],
    data: AbstractDataGenerator,
    loss: AbstractLoss,
    print_loss_every: int = 10,
    tracked_params: Params[Any | None] | None = None,
    verbose: bool = True,
    obs_data: DataGeneratorObservations | None = None,
    param_data: DataGeneratorParameter | None = None,
    opt_state_fields_for_acceleration: Params[str] | None = None,
    key: PRNGKeyArray | None = None,
) -> tuple[
    Params[Array],
    Float[Array, " n_iter_total"],
    AnyLossComponents[Float[Array, " n_iter_total"]],
    AbstractDataGenerator,
    AbstractLoss,
    optax.OptState,
    Params[Array | None],
    AnyLossComponents[Float[Array, " n_iter_total"]],
    DataGeneratorObservations | None,
    DataGeneratorParameter | None,
]:
    """
    Efficient implementation of the alternate minimization scheme between
    `Params.nn_params` and `Params.eq_params`. This function is recommended
    for inverse problems where `Params.nn_params` is arbitrarily big while
    `Params.eq_params` represents only a few physical parameters.

    In this function, both types of parameters (`eq` and `nn`) are handled
    separately, as well as all related quantities such as gradient updates,
    opt_states, etc. This approach becomes more efficient than solely
    relying on optax masked transforms and `jinns.parameters.DerivativeKeys`
    when `Params.nn_params` is big while `Params.eq_params` is much smaller,
    which is often the case. Indeed, `DerivativeKeys` only prevents some
    gradient computations, but a major computational bottleneck comes from
    passing huge optax states filled with dummy zero updates (for frozen
    parameters) at each iteration ([see the `optax` issue that we raised](https://www.github.com/google-deepmind/optax/issues/993)).

    Using `solve_alternate` improves this situation by handling optax
    optimization states separately for `nn` and `eq` params. This makes it
    possible to pass `None` instead of huge dummy zero updates for "frozen"
    parameters in the optimization states. Internally, this is done thanks
    to the `params_mask` PyTree of booleans used for `eqx.partition` and
    `eqx.combine`.

    Parameters
    ----------
    n_iter
        The maximum number of cycles of alternate iterations.
    optimizers
        A `jinns.parameters.Params` object, where each leaf is an optax
        optimizer. Note that when using an `optax.chain` with a scheduler for
        a certain parameter, the iteration count considered is that of this
        particular parameter. That is, for parameter `theta`, the scheduler
        is spread over `n_iter_by_solver.eq_params.theta * n_iter` steps.
    n_iter_by_solver
        A `Params` object, where each leaf gives the number of iterations of
        the corresponding optimizer within one alternate cycle.
    init_params
        The initial `jinns.parameters.Params` object.
    data
        A `jinns.data.AbstractDataGenerator` object to retrieve batches of
        collocation points.
    loss
        The loss function to minimize.
    print_loss_every
        Default 10. The rate at which we print the loss value in the
        gradient step loop.
    tracked_params
        Default `None`. A `jinns.parameters.Params` object with non-`None`
        values for parameters that need to be tracked along the iterations.
        The user can provide something like `tracked_params =
        jinns.parameters.Params(nn_params=None, eq_params={"nu": True})` even
        if `init_params.nn_params` is a complex data structure.
    verbose
        Default `True`. If `False`, no output (loss or cause of
        exiting the optimization loop) will be produced.
    obs_data
        Default `None`. A `jinns.data.DataGeneratorObservations`
        object which can be used to sample minibatches of observations.
    param_data
        Default `None`. A `jinns.data.DataGeneratorParameter` object which
        can be used to sample equation parameters.
    opt_state_fields_for_acceleration
        A `jinns.parameters.Params` object, where each leaf
        is an `opt_state_field_for_acceleration` as
        described in `jinns.solve`.
    key
        Default `None`. A JAX random key that can be used for diverse
        purposes in the main iteration loop.

    Returns
    -------
    params
        The last non-NaN value of the params at the end of the
        optimization process.
    total_loss_values
        An array of the total loss term along the gradient steps.
    stored_loss_terms
        A PyTree with attributes being arrays of all the values for each loss
        term.
    data
        The data generator object passed as input, possibly modified.
    loss
        The loss object passed as input, possibly modified.
    opt_state
        The final `jinns.parameters.Params` PyTree with opt_state as leaves.
    stored_params
        An object with the stored values of the desired parameters (as
        specified in the `tracked_params` argument).
    stored_weights_terms
        A PyTree with leaves being arrays of all the values for each loss
        weight. Note that if `Loss.update_weight_method is None`, we return
        `None`, because loss weights are never updated and we can then save
        some computations.
    obs_data
        The `jinns.data.DataGeneratorObservations` object passed as input or
        `None`.
    param_data
        The `jinns.data.DataGeneratorParameter` object passed as input or
        `None`.
    """
    # The key functions that perform the partitions are
    # `_get_masked_optimization_stuff` and `_get_unmasked_optimization_stuff`
    # in `jinns/solver/_utils.py`.

    # The `solve_alternate()` main loop efficiently alternates between a local
    # optimization on `nn_params` and local optimizations on all `eq_params`.
    # There is then a main while loop with a main carry, and several local
    # `jax.lax.while_loop`s for each local optimization, with local carry
    # structures. Local optimizations (local loops and carries) are defined
    # in AOT jitted functions
    # (`nn_params_train_fun_compiled` and the elements of the dict
    # `eq_params_train_fun_compiled`). Those AOT jitted functions comprise the
    # body of the local loop (`_nn_params_one_iteration` and
    # `_eq_params_one_iteration`) as well as 3 steps:

    # 1) Step 1. Prepare the local carry. Make the junction with the main
    #    carry and make the appropriate initializations. See the function
    #    `_init_before_local_optimization`.
    # 2) Step 2. Perform the local gradient steps (local `jax.lax.while_loop`).
    # 3) Step 3. Extract the needed elements from the local carry at the end
    #    of the local loop to the main carry. See the function
    #    `_get_loss_and_objects_container`.

    initialization_time = time.time()
    if n_iter < 1:
        raise ValueError("Cannot run jinns.solve_alternate for n_iter<1")

    main_break_fun = _get_break_fun(
        n_iter, verbose, conditions_str=("bool_max_iter", "bool_nan_in_params")
    )
    get_batch = _build_get_batch(None)

    nn_n_iter = n_iter_by_solver.nn_params
    eq_n_iters = n_iter_by_solver.eq_params

    nn_optimizer = optimizers.nn_params
    eq_optimizers = optimizers.eq_params

    # NOTE below we have opt_states that are shaped as Params
    # but this seems OK since the real gain is to not differentiate
    # w.r.t. unwanted params
    nn_opt_state = nn_optimizer.init(init_params)

    if opt_state_fields_for_acceleration is None:
        nn_opt_state_field_for_acceleration = None
        eq_params_opt_state_field_for_accel = jax.tree.map(
            lambda l: None,
            eq_optimizers,
            is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
        )
    else:
        nn_opt_state_field_for_acceleration = (
            opt_state_fields_for_acceleration.nn_params
        )
        eq_params_opt_state_field_for_accel = (
            opt_state_fields_for_acceleration.eq_params
        )

    eq_opt_states = jax.tree.map(
        lambda opt_: opt_.init(init_params),
        eq_optimizers,
        is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
        # do not traverse further
    )

    # params mask to be able to optimize only on nn_params
    # NOTE we can imagine that later on, params_mask is given as user input
    # and we could then have a more refined scheme than just nn_params and
    # eq_params.
    nn_params_mask = Params(
        nn_params=True, eq_params=jax.tree.map(lambda ll: False, init_params.eq_params)
    )
    # derivative keys with only nn_params updates for the gradient steps over
    # nn_params: this is a standard derivative key, with True for nn_params
    # and False for all leaves of eq_params
    nn_gd_steps_derivative_keys = jax.tree.map(
        lambda l: nn_params_mask,
        loss.derivative_keys,
        is_leaf=lambda x: isinstance(x, Params),
    )

    # and get the complementary masks to optimize only on eq_params, FOR EACH
    # eq_param. Hence the PyTree we need to construct to tree.map over is a
    # little more complex since we need to keep the overall dict structure

    eq_params_masks, eq_gd_steps_derivative_keys = (
        _get_eq_param_masks_and_derivative_keys(eq_optimizers, init_params, loss)
    )

    #######################################
    # SOME INITIALIZATIONS FOR CONTAINERS #
    #######################################

    # initialize the PyTree for stored loss values
    total_iter_all_solvers = jax.tree.reduce(operator.add, n_iter_by_solver, 0)

    # initialize parameter tracking
    if tracked_params is None:
        tracked_params = jax.tree.map(lambda p: None, init_params)
    stored_params = _init_stored_params(
        tracked_params, init_params, n_iter * total_iter_all_solvers
    )

    # initialize the dict for stored parameter values
    # we need to get a loss_term to init stuff
    # NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
    # structure
    batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
    _, loss_terms = jax.eval_shape(loss, init_params, batch_ini)

    stored_loss_terms = jax.tree_util.tree_map(
        lambda _: jnp.zeros((n_iter * total_iter_all_solvers)), loss_terms
    )
    n_iter_list_eq_params = jax.tree.leaves(n_iter_by_solver.eq_params)
    train_loss_values = jnp.zeros((n_iter * total_iter_all_solvers))

    # initialize the PyTree for stored loss weights values
    if loss.update_weight_method is not None:
        stored_weights_terms = _init_stored_weights_terms(
            loss, n_iter * total_iter_all_solvers
        )
    else:
        stored_weights_terms = None

    train_data = DataGeneratorContainer(
        data=data, param_data=param_data, obs_data=obs_data
    )
    optimization = OptimizationContainer(
        params=init_params,
        last_non_nan_params=init_params,
        opt_state=(nn_opt_state, eq_opt_states),  # NOTE that this field changes
        # between the outer while loop and inner loops
    )
    optimization_extra = OptimizationExtraContainer(
        curr_seq=None,
        best_iter_id=None,
        best_val_criterion=None,
        best_val_params=None,
    )
    loss_container = LossContainer(
        stored_loss_terms=stored_loss_terms,
        train_loss_values=train_loss_values,
        stored_weights_terms=stored_weights_terms,
    )
    stored_objects = StoredObjectContainer(
        stored_params=stored_params,
    )

    # Main carry defined here
    carry = (
        0,
        loss,
        optimization,
        optimization_extra,
        train_data,
        loss_container,
        stored_objects,
        key,
    )
    ###

    # NOTE we precompile the eq_n_iters[eq_params]-iterations over eq_params
    # that we will repeat many times. This gets the compilation cost out of
    # the loop. This is done for each equation parameter; those functions are
    # stored in a dictionary.

    eq_param_eq_optim = tuple(
        (f.name, getattr(eq_optimizers, f.name)) for f in fields(eq_optimizers)
    )

    eq_params_train_fun_compiled = {}
    for idx_params, (eq_param, eq_optim) in enumerate(eq_param_eq_optim):
        n_iter_for_params = getattr(eq_n_iters, eq_param)

        def eq_train_fun(_, carry):
            i = carry[0]
            loss_container = carry[5]
            stored_objects = carry[6]

            def _eq_params_one_iteration(carry):
                (
                    i,
                    loss,
                    optimization,
                    _,
                    train_data,
                    loss_container,
                    stored_objects,
                    key,
                ) = carry

                (nn_opt_state, eq_opt_states) = optimization.opt_state

                batch, data, param_data, obs_data = get_batch(
                    train_data.data, train_data.param_data, train_data.obs_data
                )

                if key is not None:
                    key, subkey = jax.random.split(key)
                else:
                    subkey = None
                # Gradient step
                (
                    train_loss_value,
                    params,
                    last_non_nan_params,
                    eq_opt_state,
                    loss,
                    loss_terms,
                ) = _loss_evaluate_and_gradient_step(
                    i,
                    batch,
                    loss,
                    optimization.params,
                    optimization.last_non_nan_params,
                    getattr(eq_opt_states, eq_param),
                    eq_optim,
                    loss_container,
                    subkey,
                    getattr(eq_params_masks, eq_param),
                    getattr(eq_params_opt_state_field_for_accel, eq_param),
                    with_loss_weight_update=True,
                )

                # save loss value and selected parameters
                stored_objects_, loss_container_ = _store_loss_and_params(
                    i,
                    params,
                    stored_objects.stored_params,
                    loss_container,
                    train_loss_value,
                    loss_terms,
                    loss.loss_weights,
                    tracked_params,
                )

                carry = (
                    i + 1,
                    loss,
                    OptimizationContainer(
                        params,
                        last_non_nan_params,
                        (
                            nn_opt_state,
                            eqx.tree_at(
                                lambda pt: (getattr(pt, eq_param),),
                                eq_opt_states,
                                (eq_opt_state,),
                            ),
                        ),
                    ),
                    carry[3],
                    DataGeneratorContainer(
                        data=data, param_data=param_data, obs_data=obs_data
                    ),
                    loss_container_,
                    stored_objects_,
                    carry[7],
                )

                return carry

            break_fun_ = _get_break_fun(
                n_iter_for_params,
                verbose=False,
                conditions_str=("bool_max_iter", "bool_nan_in_params"),
            )

            # STEP 1 (see main docstring)
            start_idx = i * (sum(n_iter_list_eq_params) + nn_n_iter) + sum(
                n_iter_list_eq_params[:idx_params]
            )

            loss_, loss_container_, stored_objects_ = _init_before_local_optimization(
                eq_gd_steps_derivative_keys[eq_param],
                n_iter_for_params,
                loss_terms,
                carry[1],
                loss_container,
                start_idx,
                tracked_params,
                init_params,
            )

            carry_ = (
                0,
                loss_,
                carry[2],
                carry[3],
                carry[4],
                loss_container_,
                stored_objects_,
                carry[7],
            )
            # STEP 2 (see main docstring)
            carry_ = jax.lax.while_loop(break_fun_, _eq_params_one_iteration, carry_)

            # STEP 3 (see main docstring)
            loss_container, stored_objects = _get_loss_and_objects_container(
                loss_container, carry_[5], stored_objects, carry_[6], start_idx
            )

            carry = (
                i,
                carry_[1],
                carry_[2],
                carry_[3],
                carry_[4],
                loss_container,
                stored_objects,
                carry_[7],
            )
            return carry

        eq_params_train_fun_compiled[eq_param] = (
            jax.jit(eq_train_fun, static_argnums=0)
            .trace(n_iter_for_params, jax.eval_shape(lambda _: carry, (None,)))
            .lower()
            .compile()
        )

    # NOTE we precompile the local optimization loop on the nn params.
    # With a plain Python while loop, the compilation would be costly each
    # time; with jax.lax.while_loop the compilation is better but AOT is
    # disallowed there.
    nn_break_fun_ = _get_break_fun(
        nn_n_iter, verbose=False, conditions_str=("bool_max_iter", "bool_nan_in_params")
    )

    def nn_train_fun(carry):
        i = carry[0]
        loss_container = carry[5]
        stored_objects = carry[6]

        def _nn_params_one_iteration(carry):
            (
                i,
                loss,
                optimization,
                _,
                train_data,
                loss_container,
                stored_objects,
                key,
            ) = carry

            (nn_opt_state, eq_opt_states) = optimization.opt_state

            batch, data, param_data, obs_data = get_batch(
                train_data.data, train_data.param_data, train_data.obs_data
            )

            # Gradient step
            if key is not None:
                key, subkey = jax.random.split(key)
            else:
                subkey = None
            (
                train_loss_value,
                params,
                last_non_nan_params,
                nn_opt_state,
                loss,
                loss_terms,
            ) = _loss_evaluate_and_gradient_step(
                i,
                batch,
                loss,
                optimization.params,
                optimization.last_non_nan_params,
                nn_opt_state,
                nn_optimizer,
                loss_container,
                subkey,
                nn_params_mask,
                nn_opt_state_field_for_acceleration,
                with_loss_weight_update=True,
            )

            # save loss value and selected parameters
            stored_objects_, loss_container_ = _store_loss_and_params(
                i,
                params,
                stored_objects.stored_params,
                loss_container,
                train_loss_value,
                loss_terms,
                loss.loss_weights,
                tracked_params,
            )

            carry = (
                i + 1,
                loss,
                OptimizationContainer(
                    params, last_non_nan_params, (nn_opt_state, eq_opt_states)
                ),
                carry[3],
                DataGeneratorContainer(
                    data=data, param_data=param_data, obs_data=obs_data
                ),
                loss_container_,
                stored_objects_,
                carry[7],
            )

            return carry

        # STEP 1 (see main docstring)
        start_idx = i * (sum(n_iter_list_eq_params) + nn_n_iter) + sum(
            n_iter_list_eq_params
        )
        loss_, loss_container_, stored_objects_ = _init_before_local_optimization(
            nn_gd_steps_derivative_keys,
            nn_n_iter,
            loss_terms,
            carry[1],
            loss_container,
            start_idx,
            tracked_params,
            init_params,
        )
        carry_ = (
            0,
            loss_,
            carry[2],
            carry[3],
            carry[4],
            loss_container_,
            stored_objects_,
            carry[7],
        )
        # STEP 2 (see main docstring)
        carry_ = jax.lax.while_loop(nn_break_fun_, _nn_params_one_iteration, carry_)

        # Now we prepare the main carry again
        # STEP 3 (see main docstring)
        loss_container, stored_objects = _get_loss_and_objects_container(
            loss_container, carry_[5], stored_objects, carry_[6], start_idx
        )

        carry = (
            i,
            carry_[1],
            carry_[2],
            carry_[3],
            carry_[4],
            loss_container,
            stored_objects,
            carry_[7],
        )
        return carry

    nn_params_train_fun_compiled = (
        jax.jit(nn_train_fun)
        .trace(jax.eval_shape(lambda _: carry, (None,)))
        .lower()
        .compile()
    )

    if verbose:
        print("Initialization time:", time.time() - initialization_time)

    def _one_alternate_iteration(carry):
        (
            i,
            loss,
            optimization,
            optimization_extra,
            train_data,
            loss_container,
            stored_objects,
            key,
        ) = carry

        ###### OPTIMIZATION ON EQ_PARAMS ###########

        for eq_param, _ in eq_param_eq_optim:
            carry = eq_params_train_fun_compiled[eq_param](carry)

        ###### OPTIMIZATION ON NN_PARAMS ###########

        carry = nn_params_train_fun_compiled(carry)

        ############################################

        if verbose:
            n_iter_total = (
                i * (sum(n_iter_list_eq_params) + nn_n_iter)
                + sum(n_iter_list_eq_params)
                + nn_n_iter
            )
            _print_fn(
                i,
                carry[5].train_loss_values[n_iter_total - 1],
                print_loss_every,
                prefix="[train alternate]",
            )

        i += 1
        return (i, carry[1], carry[2], carry[3], carry[4], carry[5], carry[6], carry[7])

    start = time.time()
    # jax.lax.while_loop jits its content so it cannot be used when we try to
    # precompile what is inside: JAX transformations are not compatible with AOT
    while main_break_fun(carry):
        carry = _one_alternate_iteration(carry)
    jax.block_until_ready(carry)
    end = time.time()

    if verbose:
        n_iter_total = (carry[0]) * (sum(n_iter_list_eq_params) + nn_n_iter)
        jax.debug.print(
            "\nFinal alternate iteration {i}: loss value = {train_loss_val}",
            i=carry[0],
            train_loss_val=carry[5].train_loss_values[n_iter_total - 1],
        )

    if verbose:
        print("\nTraining took\n", end - start, "\n")

    return (
        carry[2].params,
        carry[5].train_loss_values,
        carry[5].stored_loss_terms,
        carry[4].data,
        carry[1],  # loss
        carry[2].opt_state,
        carry[6].stored_params,
        carry[5].stored_weights_terms,
        carry[4].obs_data,
        carry[4].param_data,
    )


def _get_loss_and_objects_container(
    loss_container, loss_container_, stored_objects, stored_objects_, start_idx
):
    """
    This function contains what needs to be done at the end of a local
    optimization on `nn_params` or on one of the `eq_params`. This mainly
    consists in extracting from the local carry what needs to be transferred
    to the global carry:

    - loss_container content (to get the continuity of loss values, etc.)
    - stored_objects content (to get the continuity of stored params, etc.)
    """
    loss_container = LossContainer(
        stored_loss_terms=jax.tree.map(
            lambda s, l: jax.lax.dynamic_update_slice(s, l, (start_idx,)),
            loss_container.stored_loss_terms,
            loss_container_.stored_loss_terms,
        ),
        train_loss_values=jax.lax.dynamic_update_slice(
            loss_container.train_loss_values,
            loss_container_.train_loss_values,
            (start_idx,),
        ),
        stored_weights_terms=jax.tree.map(
            lambda s, l: jax.lax.dynamic_update_slice(s, l, (start_idx,)),
            loss_container.stored_weights_terms,
            loss_container_.stored_weights_terms,
        ),
    )
    stored_objects = StoredObjectContainer(
        stored_params=jax.tree.map(
            lambda s, l: jax.lax.dynamic_update_slice(s, l, (start_idx,) + s[0].shape),
            stored_objects.stored_params,
            stored_objects_.stored_params,
        )
    )
    return loss_container, stored_objects


def _init_before_local_optimization(
    derivative_keys,
    n_iter_local,
    loss_terms,
    loss,
    loss_container,
    start_idx,
    tracked_params,
    init_params,
):
    """
    This function contains what needs to be done at the beginning of a local
    optimization on `nn_params` or on one of the `eq_params`. This mainly
    consists in initializing the local carry with objects having the correct
    shape for the incoming local while loop.
    This also consists in extracting from the global carry what needs to be
    transferred to the local carry:

    - loss weight values, to get the continuity of the loss weight update
      methods
    """
    loss_ = eqx.tree_at(
        lambda pt: (pt.derivative_keys,),
        loss,
        (derivative_keys,),
    )
    # Reinit a loss container for this inner loop
    stored_loss_terms_ = jax.tree_util.tree_map(
        lambda _: jnp.zeros((n_iter_local)), loss_terms
    )
    train_loss_values_ = jnp.zeros((n_iter_local,))
    if loss_.update_weight_method is not None:
        stored_weights_terms_ = _init_stored_weights_terms(loss_, n_iter_local)
        # ensure continuity between steps for loss weights
        # this is important for update weight methods which require
        # previous weight values
        stored_weights_terms_ = jax.tree_util.tree_map(
            lambda st_, st: st_.at[-1].set(st[start_idx - 1]),
            stored_weights_terms_,
            loss_container.stored_weights_terms,
        )
    else:
        stored_weights_terms_ = None
    loss_container_ = LossContainer(
        stored_loss_terms=stored_loss_terms_,
        train_loss_values=train_loss_values_,
        stored_weights_terms=stored_weights_terms_,
    )

    # Reinit a stored_objects for this inner loop
    stored_params_ = _init_stored_params(tracked_params, init_params, n_iter_local)
    stored_objects_ = StoredObjectContainer(stored_params=stored_params_)
    return loss_, loss_container_, stored_objects_


def _get_eq_param_masks_and_derivative_keys(eq_optimizers, init_params, loss):
    nb_eq_params = len(
        jax.tree.leaves(
            eq_optimizers, is_leaf=lambda x: isinstance(x, optax.GradientTransformation)
        )
    )
    # masks_ is a sort of one-hot encoding for each eq_param
    masks_ = tuple(jnp.eye(nb_eq_params)[i] for i in range(nb_eq_params))
    # eq_params_masks_ is an EqParams with each leaf getting the one-hot
    # encoding of the eq_param it represents
    eq_params_masks_ = jax.tree.unflatten(
        jax.tree.structure(
            eq_optimizers, is_leaf=lambda x: isinstance(x, optax.GradientTransformation)
        ),
        masks_,
    )
    # If you forget about the broadcast below, eq_params_masks is an EqParams
    # where each leaf is a Params with a 1 where the subleaf of Params is the
    # same as the upper leaf of the EqParams.
    # Now add the broadcast: it is needed because, e.g., ll=[0, 0, 1] has just
    # been unflattened into 3 eq_params (from the eq_optimizers structure).
    # The problem is that here a float (0 or 0 or 1) has been assigned, all
    # with struct (). This is problematic since it will not match the struct
    # of Params.eq_params leaves that are, e.g., tuples. If
    # Params.eq_params=(alpha=(0., 0.), beta=(1.,), gamma=(4., 4.,
    # jnp.array([4., 4.]))) then the result of the unflatten will be
    # modified into the correct structures, i.e.
    # (alpha=(0, 0), beta=(0,), gamma=(1, 1, 1)) instead of
    # (alpha=0, beta=0, gamma=1).
    # The tree.broadcast has been added to prevent a bug in the tree.map of
    # `_set_derivatives` of jinns DerivativeKeys.

    eq_params_masks = jax.tree.map(
        lambda l, ll, p: Params(
            nn_params=False,
            eq_params=jax.tree.broadcast(
                jax.tree.unflatten(
                    jax.tree.structure(
                        eq_optimizers,
                        is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
                    ),
                    ll,
                ),
                init_params.eq_params,
            ),
        ),
        eq_optimizers,
        eq_params_masks_,
        init_params.eq_params,
        is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
    )

    def replace_float(leaf):
        if isinstance(leaf, bool):
            return leaf
        elif leaf == 1:
            return True
        elif leaf == 0:
            return False
        else:
            raise ValueError

    # Note that we need to replace with plain bool:
    # 1. filter_spec does not even accept onp.array
    # 2. filter_spec does not accept non-static arguments, so any jnp array is
    #    non-hashable and we will not be able to make it static
    # params_mask cannot be inside the carry of course, just like the
    # optimizer
    eq_params_masks = jax.tree.map(lambda l: replace_float(l), eq_params_masks)

    # derivative keys with only eq_params updates for the gradient steps over
    # eq_params. Here we make a dict for simplicity:
    # key = an eq_param; value = the content to form the jinns DerivativeKeys
    # for this eq_param, with True only where needed
    eq_gd_steps_derivative_keys = {
        f.name: jax.tree.map(
            lambda l: getattr(eq_params_masks, f.name),
            loss.derivative_keys,
            is_leaf=lambda x: isinstance(x, Params),
        )
        for f in fields(eq_params_masks)
    }

    return eq_params_masks, eq_gd_steps_derivative_keys
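
For context on the `params_mask` mechanism the docstring above refers to: freezing a subtree amounts to partitioning the parameter PyTree with a boolean mask, so that frozen leaves are simply absent (`None`) instead of carrying dummy zero updates through the optimizer state. A self-contained sketch with toy dictionaries (illustrative values, not actual jinns `Params` objects):

import jax.numpy as jnp
import equinox as eqx

params = {"nn_params": {"w": jnp.ones((3,))}, "eq_params": {"nu": jnp.array(0.5)}}
mask = {"nn_params": {"w": True}, "eq_params": {"nu": False}}  # True = trainable

trainable, frozen = eqx.partition(params, mask)  # frozen leaves become None in `trainable`
# ... compute gradients / optimizer updates on `trainable` only ...
params = eqx.combine(trainable, frozen)          # reassemble the full PyTree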