jinns 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,885 @@
1
+ """
2
+ `jinns.solve_alternate()` to efficiently solve inverse problems
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import time
8
+ import operator
9
+ from dataclasses import fields
10
+ from typing import TYPE_CHECKING
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import optax
14
+ from jaxtyping import Array, PRNGKeyArray, Float
15
+ import equinox as eqx
16
+
17
+ from jinns.parameters._params import Params
18
+ from jinns.solver._utils import (
19
+ _init_stored_weights_terms,
20
+ _init_stored_params,
21
+ _get_break_fun,
22
+ _loss_evaluate_and_gradient_step,
23
+ _build_get_batch,
24
+ _store_loss_and_params,
25
+ _print_fn,
26
+ )
27
+ from jinns.utils._containers import (
28
+ DataGeneratorContainer,
29
+ OptimizationContainer,
30
+ OptimizationExtraContainer,
31
+ LossContainer,
32
+ StoredObjectContainer,
33
+ )
34
+
35
+ if TYPE_CHECKING:
36
+ from typing import Any
37
+ from jinns.utils._types import AnyLossComponents
38
+ from jinns.loss._abstract_loss import AbstractLoss
39
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
40
+ from jinns.data._DataGeneratorObservations import DataGeneratorObservations
41
+ from jinns.data._DataGeneratorParameter import DataGeneratorParameter
42
+
43
+
44
+ def solve_alternate(
45
+ *,
46
+ n_iter: int,
47
+ optimizers: Params[optax.GradientTransformation],
48
+ n_iter_by_solver: Params[int],
49
+ init_params: Params[Array],
50
+ data: AbstractDataGenerator,
51
+ loss: AbstractLoss,
52
+ print_loss_every: int = 10,
53
+ tracked_params: Params[Any | None] | None = None,
54
+ verbose: bool = True,
55
+ obs_data: DataGeneratorObservations | None = None,
56
+ param_data: DataGeneratorParameter | None = None,
57
+ opt_state_fields_for_acceleration: Params[str] | None = None,
58
+ key: PRNGKeyArray | None = None,
59
+ ) -> tuple[
60
+ Params[Array],
61
+ Float[Array, " n_iter_total"],
62
+ AnyLossComponents[Float[Array, " n_iter_total"]],
63
+ AbstractDataGenerator,
64
+ AbstractLoss,
65
+ optax.OptState,
66
+ Params[Array | None],
67
+ AnyLossComponents[Float[Array, " n_iter_total"]],
68
+ DataGeneratorObservations | None,
69
+ DataGeneratorParameter | None,
70
+ ]:
71
+ """
72
+ Efficient implementation of the alternate minimization scheme between
73
+ `Params.nn_params` and `Params.eq_params`. This function is recommended for inverse problems where `Params.nn_params` is arbitrarily big, but
74
+ `Params.eq_params` represents only a few physical parameters.
75
+
76
+
77
+ In this function, both types of parameters (`eq` and `nn`) are handled
78
+ separately, as well as all related quantities such as gradient updates,
79
+ opt_states, etc. This approach becomes more efficient than solely
80
+ relying on optax masked transforms and `jinns.parameters.DerivativeKeys`
81
+ when `Params.nn_params` is big while `Params.eq_params` is much smaller,
82
+ which is often the case. Indeed, `DerivativeKeys` only prevents some
83
+ gradient computations, but a major computational bottleneck comes from
84
+ passing huge optax states filled with dummy zero updates (for frozen
85
+ parameters) at each iteration (see [the `optax` issue we raised](https://www.github.com/google-deepmind/optax/issues/993)).
86
+
87
+ Using `solve_alternate` improves this situation by handling Optax
88
+ optimization states separately for `nn` and `eq` params. This makes it possible to
89
+ pass `None` instead of huge dummy zero updates for "frozen" parameters in
90
+ the optimization states. Internally, this is done thanks to the
91
+ `params_mask` PyTree of booleans used for `eqx.partition` and `eqx.combine`.
92
+
93
+
94
+ Parameters
95
+ ----------
96
+ n_iter
97
+ The maximum number of cycles of alternate iterations.
98
+ optimizers
99
+ A `jinns.parameters.Params` object, where each leaf is an optax
100
+ optimizer. Note that when using an `optax.chain` with a scheduler for a
101
+ given parameter, the iteration count used is that of this
102
+ particular parameter. That is, for parameter `theta`, the scheduler is
103
+ spread over `n_iter_by_solver.eq_params.theta * n_iter` steps.
104
+ n_iter_by_solver
105
+ A `jinns.parameters.Params` object, where each leaf gives the number of iterations of the
106
+ corresponding optimizer, within one alternate cycle.
107
+ init_params
108
+ The initial `jinns.parameters.Params` object.
109
+ data
110
+ A `jinns.data.AbstractDataGenerator` object to retrieve batches of collocation points.
111
+ loss
112
+ The loss function to minimize.
113
+ print_loss_every
114
+ Default 10. The rate at which we print the loss value in the
115
+ gradient step loop.
116
+ tracked_params
117
+ Default `None`. A `jinns.parameters.Params` object with non-`None` values for
118
+ parameters that need to be tracked along the iterations.
119
+ The user can provide something like `tracked_params = jinns.parameters.Params(
120
+ nn_params=None, eq_params={"nu": True})` even if `init_params.nn_params`
121
+ is a complex data structure.
122
+ verbose
123
+ Default `True`. If `False`, no output (loss or cause of
124
+ exiting the optimization loop) will be produced.
125
+ obs_data
126
+ Default `None`. A `jinns.data.DataGeneratorObservations`
127
+ object which can be used to sample minibatches of observations.
128
+ param_data
129
+ Default `None`. A `jinns.data.DataGeneratorParameter` object which can be used to
130
+ sample equation parameters.
131
+ opt_state_fields_for_acceleration
132
+ Default `None`. A `jinns.parameters.Params` object, where each leaf
133
+ is an `opt_state_field_for_acceleration` as
134
+ described in `jinns.solve`.
135
+ key
136
+ Default `None`. A JAX random key that can be used for various purposes in
137
+ the main iteration loop.
138
+
139
+ Returns
140
+ -------
141
+
142
+ params
143
+ The last non-NaN value of the params at the end of the
144
+ optimization process.
145
+ total_loss_values
146
+ An array of the total loss term along the gradient steps.
147
+ stored_loss_terms
148
+ A PyTree with attributes being arrays of all the values for each loss
149
+ term.
150
+ data
151
+ The data generator object passed as input, possibly modified.
152
+ loss
153
+ The loss object passed as input, possibly modified.
154
+ opt_state
155
+ The final `jinns.parameters.Params` PyTree with opt_state as leaves.
156
+ stored_params
157
+ An object with the stored values of the desired parameters (as
158
+ specified by the `tracked_params` argument).
159
+ stored_weights_terms
160
+ A PyTree with leaves being arrays of all the values for each loss
161
+ weight. Note that if `Loss.update_weight_method is None`, we return
162
+ `None`,
163
+ because loss weights are never updated and we can then save some
164
+ computations.
165
+ obs_data
166
+ The `jinns.data.DataGeneratorObservations` object passed as input or
167
+ `None`.
168
+ param_data
169
+ The `jinns.data.DataGeneratorParameter` object passed as input or
170
+ `None`.
171
+ """
172
+ # The key functions that perform the partitions are
173
+ # `_get_masked_optimization_stuff` and `_get_unmasked_optimization_stuff` in
174
+ # `jinns/solver/_utils.py`.
175
+
176
+ # The `solve_alternate()` main loop efficiently alternates between a local
177
+ # optimization on `nn_params` and local optimizations on all `eq_params`.
178
+ # There is then a main Python `while` loop with a main carry, and several
179
+ # local `jax.lax.while_loop` for each local optimization, with local carry
180
+ # structures. Local optimizations (local loops and carries) are defined
181
+ # in AOT jitted functions
182
+ # (`nn_params_train_fun_compiled` and the elements of the dict
183
+ # `eq_params_train_fun_compiled`). Those AOT jitted functions comprise the
184
+ # body of the local loop (`_nn_params_one_iteration` and
185
+ # `_eq_params_one_iteration`) as well as 3 steps:
186
+
187
+ # 1) Step 1. Prepare the local carry. Connect it to the main carry
188
+ # and make the appropriate initializations. See the function
189
+ # `_init_before_local_optimization`.
190
+ # 2) Step 2. Perform the local gradient steps (local `jax.lax.while_loop`).
191
+ # 3) Step 3. At the end of the local loop, copy the needed elements from
192
+ # the local carry back into the main carry. See the function
193
+ # `_get_loss_and_objects_container`.
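+ 
+ # The local loops are compiled ahead of time (AOT) so that the XLA
+ # compilation cost is paid once, before the outer Python while loop starts.
+ # A minimal, hedged sketch of that pattern (`_body` and `carry0` are
+ # hypothetical; the real compiled callables are `nn_train_fun` and the
+ # per-parameter `eq_train_fun` defined below, built with the equivalent
+ # `jax.jit(...).trace(...).lower().compile()` chain):
+ #
+ #     _body_compiled = jax.jit(_body).lower(carry0).compile()
+ #     while not done:
+ #         carry0 = _body_compiled(carry0)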
194
+
195
+ initialization_time = time.time()
196
+ if n_iter < 1:
197
+ raise ValueError("Cannot run jinns.solve for n_iter<1")
198
+
199
+ main_break_fun = _get_break_fun(
200
+ n_iter, verbose, conditions_str=("bool_max_iter", "bool_nan_in_params")
201
+ )
202
+ get_batch = _build_get_batch(None)
203
+
204
+ nn_n_iter = n_iter_by_solver.nn_params
205
+ eq_n_iters = n_iter_by_solver.eq_params
206
+
207
+ nn_optimizer = optimizers.nn_params
208
+ eq_optimizers = optimizers.eq_params
209
+
210
+ # NOTE below we have opt_states that are shaped as Params
211
+ # but this seems OK since the real gain is to not differentiate
212
+ # w.r.t. unwanted params
213
+ nn_opt_state = nn_optimizer.init(init_params)
214
+
215
+ if opt_state_fields_for_acceleration is None:
216
+ nn_opt_state_field_for_acceleration = None
217
+ eq_params_opt_state_field_for_accel = jax.tree.map(
218
+ lambda l: None,
219
+ eq_optimizers,
220
+ is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
221
+ )
222
+ else:
223
+ nn_opt_state_field_for_acceleration = (
224
+ opt_state_fields_for_acceleration.nn_params
225
+ )
226
+ eq_params_opt_state_field_for_accel = (
227
+ opt_state_fields_for_acceleration.eq_params
228
+ )
229
+
230
+ eq_opt_states = jax.tree.map(
231
+ lambda opt_: opt_.init(init_params),
232
+ eq_optimizers,
233
+ is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
234
+ # do not traverse further
235
+ )
236
+
237
+ # params mask to be able to optimize only on nn_params
238
+ # NOTE we can imagine that later on, params mask is given as user input and
239
+ # we could then have a more refined scheme than just nn_params and eq_params.
240
+ nn_params_mask = Params(
241
+ nn_params=True, eq_params=jax.tree.map(lambda ll: False, init_params.eq_params)
242
+ )
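+ # Hedged sketch of how such a boolean mask drives a masked gradient step
+ # (the actual logic lives in `_loss_evaluate_and_gradient_step`; the names
+ # and call order below are illustrative only):
+ #
+ #     trainable, frozen = eqx.partition(params, nn_params_mask)
+ #     grads = jax.grad(
+ #         lambda t: loss(eqx.combine(t, frozen), batch)[0]
+ #     )(trainable)
+ #     updates, opt_state = optimizer.update(grads, opt_state, trainable)
+ #     params = eqx.combine(optax.apply_updates(trainable, updates), frozen)
+ 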
243
+ # derivative keys with only nn_params updates for the gradient steps over nn_params
244
+ # this is a standard derivative key, with True for nn_params and False for
245
+ # all leaves of eq_params
246
+ nn_gd_steps_derivative_keys = jax.tree.map(
247
+ lambda l: nn_params_mask,
248
+ loss.derivative_keys,
249
+ is_leaf=lambda x: isinstance(x, Params),
250
+ )
251
+
252
+ # and get the negated masks to optimize only on eq_params, FOR EACH EQ_PARAM.
253
+ # Hence the PyTree we need to construct to tree.map over is a little more
254
+ # complex since we need to keep the overall dict structure
255
+
256
+ eq_params_masks, eq_gd_steps_derivative_keys = (
257
+ _get_eq_param_masks_and_derivative_keys(eq_optimizers, init_params, loss)
258
+ )
259
+
260
+ #######################################
261
+ # SOME INITIALIZATIONS FOR CONTAINERS #
262
+ #######################################
263
+
264
+ # total number of gradient steps within one alternate cycle
265
+ total_iter_all_solvers = jax.tree.reduce(operator.add, n_iter_by_solver, 0)
266
+
267
+ # initialize parameter tracking
268
+ if tracked_params is None:
269
+ tracked_params = jax.tree.map(lambda p: None, init_params)
270
+ stored_params = _init_stored_params(
271
+ tracked_params, init_params, n_iter * total_iter_all_solvers
272
+ )
273
+
274
+ # initialize the PyTree for stored loss term values
275
+ # we need to get a loss_term to init stuff
276
+ # NOTE: we use jax.eval_shape to avoid FLOPS since we only need the tree
277
+ # structure
278
+ batch_ini, data, param_data, obs_data = get_batch(data, param_data, obs_data)
279
+ _, loss_terms = jax.eval_shape(loss, init_params, batch_ini)
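+ # Hedged aside on the eval_shape trick: `jax.eval_shape(f, *args)` returns
+ # only `jax.ShapeDtypeStruct` leaves (shapes and dtypes, no values), e.g.
+ #
+ #     jax.eval_shape(jnp.matmul, jnp.ones((3, 4)), jnp.ones((4, 5)))
+ #     # -> ShapeDtypeStruct(shape=(3, 5), dtype=float32)
+ #
+ # so the tree structure of `loss_terms` is obtained without any FLOPs.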
280
+
281
+ stored_loss_terms = jax.tree_util.tree_map(
282
+ lambda _: jnp.zeros((n_iter * total_iter_all_solvers)), loss_terms
283
+ )
284
+ n_iter_list_eq_params = jax.tree.leaves(n_iter_by_solver.eq_params)
285
+ train_loss_values = jnp.zeros((n_iter * total_iter_all_solvers))
286
+
287
+ # initialize the PyTree for stored loss weights values
288
+ if loss.update_weight_method is not None:
289
+ stored_weights_terms = _init_stored_weights_terms(
290
+ loss, n_iter * total_iter_all_solvers
291
+ )
292
+ else:
293
+ stored_weights_terms = None
294
+
295
+ train_data = DataGeneratorContainer(
296
+ data=data, param_data=param_data, obs_data=obs_data
297
+ )
298
+ optimization = OptimizationContainer(
299
+ params=init_params,
300
+ last_non_nan_params=init_params,
301
+ opt_state=(nn_opt_state, eq_opt_states), # NOTE that this field changes
302
+ # between the outer while loop and inner loops
303
+ )
304
+ optimization_extra = OptimizationExtraContainer(
305
+ curr_seq=None,
306
+ best_iter_id=None,
307
+ best_val_criterion=None,
308
+ best_val_params=None,
309
+ )
310
+ loss_container = LossContainer(
311
+ stored_loss_terms=stored_loss_terms,
312
+ train_loss_values=train_loss_values,
313
+ stored_weights_terms=stored_weights_terms,
314
+ )
315
+ stored_objects = StoredObjectContainer(
316
+ stored_params=stored_params,
317
+ )
318
+
319
+ # Main carry defined here
320
+ carry = (
321
+ 0,
322
+ loss,
323
+ optimization,
324
+ optimization_extra,
325
+ train_data,
326
+ loss_container,
327
+ stored_objects,
328
+ key,
329
+ )
330
+ ###
331
+
332
+ # NOTE we precompile the eq_n_iters[eq_params]-iterations over eq_params
333
+ # that we will repeat many times. This gets the compilation cost out of the
334
+ # loop. This is done for each equation parameter; the compiled functions are
335
+ # stored in a dictionary.
336
+
337
+ eq_param_eq_optim = tuple(
338
+ (f.name, getattr(eq_optimizers, f.name)) for f in fields(eq_optimizers)
339
+ )
340
+
341
+ eq_params_train_fun_compiled = {}
342
+ for idx_params, (eq_param, eq_optim) in enumerate(eq_param_eq_optim):
343
+ n_iter_for_params = getattr(eq_n_iters, eq_param)
344
+
345
+ def eq_train_fun(_, carry):
346
+ i = carry[0]
347
+ loss_container = carry[5]
348
+ stored_objects = carry[6]
349
+
350
+ def _eq_params_one_iteration(carry):
351
+ (
352
+ i,
353
+ loss,
354
+ optimization,
355
+ _,
356
+ train_data,
357
+ loss_container,
358
+ stored_objects,
359
+ key,
360
+ ) = carry
361
+
362
+ (nn_opt_state, eq_opt_states) = optimization.opt_state
363
+
364
+ batch, data, param_data, obs_data = get_batch(
365
+ train_data.data, train_data.param_data, train_data.obs_data
366
+ )
367
+
368
+ if key is not None:
369
+ key, subkey = jax.random.split(key)
370
+ else:
371
+ subkey = None
372
+ # Gradient step
373
+ (
374
+ train_loss_value,
375
+ params,
376
+ last_non_nan_params,
377
+ eq_opt_state,
378
+ loss,
379
+ loss_terms,
380
+ ) = _loss_evaluate_and_gradient_step(
381
+ i,
382
+ batch,
383
+ loss,
384
+ optimization.params,
385
+ optimization.last_non_nan_params,
386
+ getattr(eq_opt_states, eq_param),
387
+ eq_optim,
388
+ loss_container,
389
+ subkey,
390
+ getattr(eq_params_masks, eq_param),
391
+ getattr(eq_params_opt_state_field_for_accel, eq_param),
392
+ with_loss_weight_update=True,
393
+ )
394
+
395
+ # save loss value and selected parameters
396
+ stored_objects_, loss_container_ = _store_loss_and_params(
397
+ i,
398
+ params,
399
+ stored_objects.stored_params,
400
+ loss_container,
401
+ train_loss_value,
402
+ loss_terms,
403
+ loss.loss_weights,
404
+ tracked_params,
405
+ )
406
+
407
+ carry = (
408
+ i + 1,
409
+ loss,
410
+ OptimizationContainer(
411
+ params,
412
+ last_non_nan_params,
413
+ (
414
+ nn_opt_state,
415
+ eqx.tree_at(
416
+ lambda pt: (getattr(pt, eq_param),),
417
+ eq_opt_states,
418
+ (eq_opt_state,),
419
+ ),
420
+ ),
421
+ ),
422
+ carry[3],
423
+ DataGeneratorContainer(
424
+ data=data, param_data=param_data, obs_data=obs_data
425
+ ),
426
+ loss_container_,
427
+ stored_objects_,
428
+ carry[7],
429
+ )
430
+
431
+ return carry
432
+
433
+ break_fun_ = _get_break_fun(
434
+ n_iter_for_params,
435
+ verbose=False,
436
+ conditions_str=("bool_max_iter", "bool_nan_in_params"),
437
+ )
438
+
439
+ # STEP 1 (see main docstring)
440
+ start_idx = i * (sum(n_iter_list_eq_params) + nn_n_iter) + sum(
441
+ n_iter_list_eq_params[:idx_params]
442
+ )
443
+
444
+ loss_, loss_container_, stored_objects_ = _init_before_local_optimization(
445
+ eq_gd_steps_derivative_keys[eq_param],
446
+ n_iter_for_params,
447
+ loss_terms,
448
+ carry[1],
449
+ loss_container,
450
+ start_idx,
451
+ tracked_params,
452
+ init_params,
453
+ )
454
+
455
+ carry_ = (
456
+ 0,
457
+ loss_,
458
+ carry[2],
459
+ carry[3],
460
+ carry[4],
461
+ loss_container_,
462
+ stored_objects_,
463
+ carry[7],
464
+ )
465
+ # STEP 2 (see main docstring)
466
+ carry_ = jax.lax.while_loop(break_fun_, _eq_params_one_iteration, carry_)
467
+
468
+ # STEP 3 (see main docstring)
469
+ loss_container, stored_objects = _get_loss_and_objects_container(
470
+ loss_container, carry_[5], stored_objects, carry_[6], start_idx
471
+ )
472
+
473
+ carry = (
474
+ i,
475
+ carry_[1],
476
+ carry_[2],
477
+ carry_[3],
478
+ carry_[4],
479
+ loss_container,
480
+ stored_objects,
481
+ carry_[7],
482
+ )
483
+ return carry
484
+
485
+ eq_params_train_fun_compiled[eq_param] = (
486
+ jax.jit(eq_train_fun, static_argnums=0)
487
+ .trace(n_iter_for_params, jax.eval_shape(lambda _: carry, (None,)))
488
+ .lower()
489
+ .compile()
490
+ )
491
+
492
+ # NOTE we precompile the local optimization loop on the nn params
493
+ # In the plain while loop, the compilation is costly each time
494
+ # In the jax lax while loop, the compilation is better but AOT is
495
+ # disallowed there
496
+ nn_break_fun_ = _get_break_fun(
497
+ nn_n_iter, verbose=False, conditions_str=("bool_max_iter", "bool_nan_in_params")
498
+ )
499
+
500
+ def nn_train_fun(carry):
501
+ i = carry[0]
502
+ loss_container = carry[5]
503
+ stored_objects = carry[6]
504
+
505
+ def _nn_params_one_iteration(carry):
506
+ (
507
+ i,
508
+ loss,
509
+ optimization,
510
+ _,
511
+ train_data,
512
+ loss_container,
513
+ stored_objects,
514
+ key,
515
+ ) = carry
516
+
517
+ #
518
+ (nn_opt_state, eq_opt_states) = optimization.opt_state
519
+
520
+ batch, data, param_data, obs_data = get_batch(
521
+ train_data.data, train_data.param_data, train_data.obs_data
522
+ )
523
+
524
+ # Gradient step
525
+ if key is not None:
526
+ key, subkey = jax.random.split(key)
527
+ else:
528
+ subkey = None
529
+ (
530
+ train_loss_value,
531
+ params,
532
+ last_non_nan_params,
533
+ nn_opt_state,
534
+ loss,
535
+ loss_terms,
536
+ ) = _loss_evaluate_and_gradient_step(
537
+ i,
538
+ batch,
539
+ loss,
540
+ optimization.params,
541
+ optimization.last_non_nan_params,
542
+ nn_opt_state,
543
+ nn_optimizer,
544
+ loss_container,
545
+ subkey,
546
+ nn_params_mask,
547
+ nn_opt_state_field_for_acceleration,
548
+ with_loss_weight_update=True,
549
+ )
550
+
551
+ # save loss value and selected parameters
552
+ stored_objects_, loss_container_ = _store_loss_and_params(
553
+ i,
554
+ params,
555
+ stored_objects.stored_params,
556
+ loss_container,
557
+ train_loss_value,
558
+ loss_terms,
559
+ loss.loss_weights,
560
+ tracked_params,
561
+ )
562
+
563
+ carry = (
564
+ i + 1,
565
+ loss,
566
+ OptimizationContainer(
567
+ params, last_non_nan_params, (nn_opt_state, eq_opt_states)
568
+ ),
569
+ carry[3],
570
+ DataGeneratorContainer(
571
+ data=data, param_data=param_data, obs_data=obs_data
572
+ ),
573
+ loss_container_,
574
+ stored_objects_,
575
+ carry[7],
576
+ )
577
+
578
+ return carry
579
+
580
+ # STEP 1 (see main docstring)
581
+ start_idx = i * (sum(n_iter_list_eq_params) + nn_n_iter) + sum(
582
+ n_iter_list_eq_params
583
+ )
584
+ loss_, loss_container_, stored_objects_ = _init_before_local_optimization(
585
+ nn_gd_steps_derivative_keys,
586
+ nn_n_iter,
587
+ loss_terms,
588
+ carry[1],
589
+ loss_container,
590
+ start_idx,
591
+ tracked_params,
592
+ init_params,
593
+ )
594
+ carry_ = (
595
+ 0,
596
+ loss_,
597
+ carry[2],
598
+ carry[3],
599
+ carry[4],
600
+ loss_container_,
601
+ stored_objects_,
602
+ carry[7],
603
+ )
604
+ # STEP 2 (see main docstring)
605
+ carry_ = jax.lax.while_loop(nn_break_fun_, _nn_params_one_iteration, carry_)
606
+
607
+ # Now we rebuild the main carry
608
+ # STEP 3 (see main docstring)
609
+ loss_container, stored_objects = _get_loss_and_objects_container(
610
+ loss_container, carry_[5], stored_objects, carry_[6], start_idx
611
+ )
612
+
613
+ carry = (
614
+ i,
615
+ carry_[1],
616
+ carry_[2],
617
+ carry_[3],
618
+ carry_[4],
619
+ loss_container,
620
+ stored_objects,
621
+ carry_[7],
622
+ )
623
+ return carry
624
+
625
+ nn_params_train_fun_compiled = (
626
+ jax.jit(nn_train_fun)
627
+ .trace(jax.eval_shape(lambda _: carry, (None,)))
628
+ .lower()
629
+ .compile()
630
+ )
631
+
632
+ if verbose:
633
+ print("Initialization time:", time.time() - initialization_time)
634
+
635
+ def _one_alternate_iteration(carry):
636
+ (
637
+ i,
638
+ loss,
639
+ optimization,
640
+ optimization_extra,
641
+ train_data,
642
+ loss_container,
643
+ stored_objects,
644
+ key,
645
+ ) = carry
646
+
647
+ ###### OPTIMIZATION ON EQ_PARAMS ###########
648
+
649
+ for eq_param, _ in eq_param_eq_optim:
650
+ carry = eq_params_train_fun_compiled[eq_param](carry)
651
+
652
+ ###### OPTIMIZATION ON NN_PARAMS ###########
653
+
654
+ carry = nn_params_train_fun_compiled(carry)
655
+
656
+ ############################################
657
+
658
+ if verbose:
659
+ n_iter_total = (
660
+ i * (sum(n_iter_list_eq_params) + nn_n_iter)
661
+ + sum(n_iter_list_eq_params)
662
+ + nn_n_iter
663
+ )
664
+ _print_fn(
665
+ i,
666
+ carry[5].train_loss_values[n_iter_total - 1],
667
+ print_loss_every,
668
+ prefix="[train alternate]",
669
+ )
670
+
671
+ i += 1
672
+ return (i, carry[1], carry[2], carry[3], carry[4], carry[5], carry[6], carry[7])
673
+
674
+ start = time.time()
675
+ # jax.lax.while_loop jits its content, so it cannot be used when we try to
676
+ # precompile what is inside. JAX transformations are not compatible with AOT.
677
+ while main_break_fun(carry):
678
+ carry = _one_alternate_iteration(carry)
679
+ jax.block_until_ready(carry)
680
+ end = time.time()
681
+
682
+ if verbose:
683
+ n_iter_total = (carry[0]) * (sum(n_iter_list_eq_params) + nn_n_iter)
684
+ jax.debug.print(
685
+ "\nFinal alternate iteration {i}: loss value = {train_loss_val}",
686
+ i=carry[0],
687
+ train_loss_val=carry[5].train_loss_values[n_iter_total - 1],
688
+ )
689
+
690
+ if verbose:
691
+ print("\nTraining took\n", end - start, "\n")
692
+
693
+ return (
694
+ carry[2].params,
695
+ carry[5].train_loss_values,
696
+ carry[5].stored_loss_terms,
697
+ carry[4].data,
698
+ carry[1], # loss
699
+ carry[2].opt_state,
700
+ carry[6].stored_params,
701
+ carry[5].stored_weights_terms,
702
+ carry[4].obs_data,
703
+ carry[4].param_data,
704
+ )
705
+
706
+
707
+ def _get_loss_and_objects_container(
708
+ loss_container, loss_container_, stored_objects, stored_objects_, start_idx
709
+ ):
710
+ """
711
+ This function contains what needs to be done at the end of a local
712
+ optimization on `nn_params` or on one of the `eq_params`. This mainly
713
+ consists in extracting from the local carry what needs to be transferred to
714
+ the global carry:
715
+
716
+ - loss_container content (to get the continuity of loss values, etc.)
717
+ - stored_objects content (to get the continuity of stored params etc.)
718
+ """
719
+ loss_container = LossContainer(
720
+ stored_loss_terms=jax.tree.map(
721
+ lambda s, l: jax.lax.dynamic_update_slice(s, l, (start_idx,)),
722
+ loss_container.stored_loss_terms,
723
+ loss_container_.stored_loss_terms,
724
+ ),
725
+ train_loss_values=jax.lax.dynamic_update_slice(
726
+ loss_container.train_loss_values,
727
+ loss_container_.train_loss_values,
728
+ (start_idx,),
729
+ ),
730
+ stored_weights_terms=jax.tree.map(
731
+ lambda s, l: jax.lax.dynamic_update_slice(s, l, (start_idx,)),
732
+ loss_container.stored_weights_terms,
733
+ loss_container_.stored_weights_terms,
734
+ ),
735
+ )
736
+ stored_objects = StoredObjectContainer(
737
+ stored_params=jax.tree.map(
738
+ lambda s, l: jax.lax.dynamic_update_slice(s, l, (start_idx,) + s[0].shape),
739
+ stored_objects.stored_params,
740
+ stored_objects_.stored_params,
741
+ )
742
+ )
743
+ return loss_container, stored_objects
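+ 
+ 
+ def _example_dynamic_update_slice() -> Array:
+     # Illustrative sketch only (nothing in jinns calls this): it shows the
+     # `jax.lax.dynamic_update_slice` semantics used above to splice the
+     # local-loop arrays back into the global ones starting at `start_idx`.
+     global_buf = jnp.zeros((6,))
+     local_buf = jnp.array([1.0, 2.0])
+     start_idx = 2
+     # -> Array([0., 0., 1., 2., 0., 0.])
+     return jax.lax.dynamic_update_slice(global_buf, local_buf, (start_idx,))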
744
+
745
+
746
+ def _init_before_local_optimization(
747
+ derivative_keys,
748
+ n_iter_local,
749
+ loss_terms,
750
+ loss,
751
+ loss_container,
752
+ start_idx,
753
+ tracked_params,
754
+ init_params,
755
+ ):
756
+ """
757
+ This function contains what needs to be done at the beginning of a local
758
+ optimization on `nn_params` or on one of the `eq_params`. This mainly
759
+ consists in initializing the local carry with the objects having the
760
+ correct shape for the incoming local while loop.
761
+ This also
762
+ consists in extracting from the global carry what needs to be transferred to
763
+ the local carry:
764
+
765
+ - loss weight values to get the continuity of loss weight update methods
766
+ """
767
+ loss_ = eqx.tree_at(
768
+ lambda pt: (pt.derivative_keys,),
769
+ loss,
770
+ (derivative_keys,),
771
+ )
772
+ # Reinit a loss container for this inner loop
773
+ stored_loss_terms_ = jax.tree_util.tree_map(
774
+ lambda _: jnp.zeros((n_iter_local)), loss_terms
775
+ )
776
+ train_loss_values_ = jnp.zeros((n_iter_local,))
777
+ if loss_.update_weight_method is not None:
778
+ stored_weights_terms_ = _init_stored_weights_terms(loss_, n_iter_local)
779
+ # ensure continuity between steps for loss weights
780
+ # this is important for update weight methods which require
781
+ # previous weight values
782
+ stored_weights_terms_ = jax.tree_util.tree_map(
783
+ lambda st_, st: st_.at[-1].set(st[start_idx - 1]),
784
+ stored_weights_terms_,
785
+ loss_container.stored_weights_terms,
786
+ )
787
+ else:
788
+ stored_weights_terms_ = None
789
+ loss_container_ = LossContainer(
790
+ stored_loss_terms=stored_loss_terms_,
791
+ train_loss_values=train_loss_values_,
792
+ stored_weights_terms=stored_weights_terms_,
793
+ )
794
+
795
+ # Reinit a stored_objects for this inner loop
796
+ stored_params_ = _init_stored_params(tracked_params, init_params, n_iter_local)
797
+ stored_objects_ = StoredObjectContainer(stored_params=stored_params_)
798
+ return loss_, loss_container_, stored_objects_
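+ 
+ 
+ def _example_tree_at() -> tuple[float, float]:
+     # Illustrative sketch only (not used by jinns): `eqx.tree_at` replaces
+     # leaves of a PyTree out-of-place, which is how the loss above gets its
+     # `derivative_keys` swapped before each local optimization.
+     class _Point(eqx.Module):
+         x: float
+         y: float
+ 
+     p = _Point(1.0, 2.0)
+     p_new = eqx.tree_at(lambda pt: (pt.x,), p, (3.0,))
+     return p_new.x, p_new.y  # (3.0, 2.0)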
799
+
800
+
801
+ def _get_eq_param_masks_and_derivative_keys(eq_optimizers, init_params, loss):
802
+ nb_eq_params = len(
803
+ jax.tree.leaves(
804
+ eq_optimizers, is_leaf=lambda x: isinstance(x, optax.GradientTransformation)
805
+ )
806
+ )
807
+ # masks_ is a sort of one hot encoding for each eq_param
808
+ masks_ = tuple(jnp.eye(nb_eq_params)[i] for i in range(nb_eq_params))
809
+ # eq_params_masks_ is an EqParams with each leaf getting its one hot
810
+ # encoding of the eq_param it represents
811
+ eq_params_masks_ = jax.tree.unflatten(
812
+ jax.tree.structure(
813
+ eq_optimizers, is_leaf=lambda x: isinstance(x, optax.GradientTransformation)
814
+ ),
815
+ masks_,
816
+ )
817
+ # if you forget about the broadcast below
818
+ # eq_params_masks is an EqParams where each leaf is a Params
819
+ # where we have a 1 where the subleaf of Params is the same as the upper
820
+ # leaf of the EqParams
821
+ # now add the broadcast: it is needed because e.g. ll=[0, 0, 1] has just been
822
+ # unflattened into 3 eq_params (from eq_optimizers structure). The problem
823
+ # is that here, a float (0, 0, or 1) has been assigned, all with struct
824
+ # (). This is problematic since it will not match struct of
825
+ # Params.eq_params leaves, which may be e.g. tuples. Then if
826
+ # Params.eq_params=(alpha=(0., 0.), beta=(1.,), gamma=(4., 4.,
827
+ # jnp.array([4., 4.]))) then the result of the unflatten will be
828
+ # modified into the correct structures, i.e.,
829
+ # (alpha=(0, 0), beta=(0,), gamma=(1, 1, 1)) instead of
830
+ # (alpha=0, beta=0, gamma=1)
831
+ # the tree.broadcast has been added to prevent a bug in the tree.map of
832
+ # `_set_derivatives` of jinns DerivativeKeys
833
+
834
+ eq_params_masks = jax.tree.map(
835
+ lambda l, ll, p: Params(
836
+ nn_params=False,
837
+ eq_params=jax.tree.broadcast(
838
+ jax.tree.unflatten(
839
+ jax.tree.structure(
840
+ eq_optimizers,
841
+ is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
842
+ ),
843
+ ll,
844
+ ),
845
+ init_params.eq_params,
846
+ ),
847
+ ),
848
+ eq_optimizers,
849
+ eq_params_masks_,
850
+ init_params.eq_params,
851
+ is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
852
+ )
853
+
854
+ def replace_float(leaf):
855
+ if isinstance(leaf, bool):
856
+ return leaf
857
+ elif leaf == 1:
858
+ return True
859
+ elif leaf == 0:
860
+ return False
861
+ else:
862
+ raise ValueError
863
+
864
+ # Note that we need to replace with plain bool:
865
+ # 1. filter_spec does not even accept onp.array
866
+ # 2. filter_spec does not accept non static arguments. So any jnp array is
867
+ # non hashable and we will not be able to make it static
868
+ # params_mask cannot be inside the carry of course, just like the
869
+ # optimizer
870
+ eq_params_masks = jax.tree.map(lambda l: replace_float(l), eq_params_masks)
871
+
872
+ # derivative keys with only eq_params updates for the gradient steps over eq_params
873
+ # Here we make a dict for simplicity
874
+ # Each key is an eq_param name; its value is the content used to form the jinns DerivativeKeys for that eq_param,
875
+ # with True where needed.
876
+ eq_gd_steps_derivative_keys = {
877
+ f.name: jax.tree.map(
878
+ lambda l: getattr(eq_params_masks, f.name),
879
+ loss.derivative_keys,
880
+ is_leaf=lambda x: isinstance(x, Params),
881
+ )
882
+ for f in fields(eq_params_masks)
883
+ }
884
+
885
+ return eq_params_masks, eq_gd_steps_derivative_keys
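+ 
+ 
+ def _example_one_hot_masks() -> tuple[dict, ...]:
+     # Illustrative sketch only (nothing in jinns calls this): rebuild one
+     # boolean "one-hot" mask per equation parameter from the optimizer tree
+     # structure, mirroring the construction above. The plain dict stands in
+     # for an `eq_params`-like container with three parameters.
+     eq_optims = {
+         "alpha": optax.adam(1e-3),
+         "beta": optax.sgd(1e-2),
+         "nu": optax.adam(1e-3),
+     }
+     is_opt = lambda x: isinstance(x, optax.GradientTransformation)
+     structure = jax.tree.structure(eq_optims, is_leaf=is_opt)
+     n = structure.num_leaves
+     # one mask per parameter, True only at that parameter's position, e.g.
+     # the first mask is {"alpha": True, "beta": False, "nu": False}
+     return tuple(
+         jax.tree.unflatten(structure, [i == j for j in range(n)])
+         for i in range(n)
+     )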