jinns 1.6.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/solver/_solve.py CHANGED
@@ -8,17 +8,23 @@ from __future__ import (
 ) # https://docs.python.org/3/library/typing.html#constant
 
 import time
-from typing import TYPE_CHECKING, Any, TypeAlias, Callable
-from functools import partial
+from typing import TYPE_CHECKING, Any
 import optax
 import jax
-from jax import jit
 import jax.numpy as jnp
-from jaxtyping import Float, Array, PyTree, PRNGKeyArray
-import equinox as eqx
+from jaxtyping import Float, Array, PRNGKeyArray
 from jinns.solver._rar import init_rar, trigger_rar
-from jinns.utils._utils import _check_nan_in_pytree
-from jinns.solver._utils import _check_batch_size
+from jinns.solver._utils import (
+    _check_batch_size,
+    _init_stored_weights_terms,
+    _init_stored_params,
+    _get_break_fun,
+    _loss_evaluate_and_gradient_step,
+    _build_get_batch,
+    _store_loss_and_params,
+    _print_fn,
+)
+from jinns.parameters._params import Params
 from jinns.utils._containers import (
     DataGeneratorContainer,
     OptimizationContainer,
@@ -26,32 +32,18 @@ from jinns.utils._containers import (
     LossContainer,
     StoredObjectContainer,
 )
-from jinns.data._utils import append_param_batch, append_obs_batch
 
 if TYPE_CHECKING:
-    from jinns.parameters._params import Params
-    from jinns.utils._types import AnyBatch
+    from jinns.utils._types import AnyLossComponents, SolveCarry
     from jinns.loss._abstract_loss import AbstractLoss
     from jinns.validation._validation import AbstractValidationModule
     from jinns.data._DataGeneratorParameter import DataGeneratorParameter
     from jinns.data._DataGeneratorObservations import DataGeneratorObservations
     from jinns.data._AbstractDataGenerator import AbstractDataGenerator
 
-main_carry: TypeAlias = tuple[
-    int,
-    AbstractLoss,
-    OptimizationContainer,
-    OptimizationExtraContainer,
-    DataGeneratorContainer,
-    AbstractValidationModule | None,
-    LossContainer,
-    StoredObjectContainer,
-    Float[Array, " n_iter"] | None,
-    PRNGKeyArray | None,
-]
-
 
 def solve(
+    *,
     n_iter: int,
     init_params: Params[Array],
     data: AbstractDataGenerator,
@@ -64,24 +56,27 @@ def solve(
     obs_data: DataGeneratorObservations | None = None,
     validation: AbstractValidationModule | None = None,
     obs_batch_sharding: jax.sharding.Sharding | None = None,
+    opt_state_field_for_acceleration: str | None = None,
     verbose: bool = True,
     ahead_of_time: bool = True,
     key: PRNGKeyArray | None = None,
 ) -> tuple[
     Params[Array],
     Float[Array, " n_iter"],
-    PyTree,
+    AnyLossComponents[Float[Array, " n_iter"]],
     AbstractDataGenerator,
     AbstractLoss,
     optax.OptState,
     Params[Array | None],
-    PyTree,
+    AnyLossComponents[Float[Array, " n_iter"]],
+    DataGeneratorObservations | None,
+    DataGeneratorParameter | None,
     Float[Array, " n_iter"] | None,
-    Params[Array],
+    Params[Array] | None,
 ]:
     """
     Performs the optimization process via stochastic gradient descent
-    algorithm. We minimize the function defined `loss.evaluate()` with
+    algorithm. We minimize the function defined in `loss.evaluate()` with
     respect to the learnable parameters of the problem whose initial values
     are given in `init_params`.
 
@@ -91,9 +86,9 @@ def solve(
     n_iter
         The maximum number of iterations in the optimization.
     init_params
-        The initial jinns.parameters.Params object.
+        The initial `jinns.parameters.Params` object.
     data
-        A DataGenerator object to retrieve batches of collocation points.
+        A `jinns.data.AbstractDataGenerator` object to retrieve batches of collocation points.
     loss
         The loss function to minimize.
     optimizer
@@ -102,22 +97,21 @@ def solve(
         Default 1000. The rate at which we print the loss value in the
         gradient step loop.
     opt_state
-        Provide an optional initial state to the optimizer.
+        Default `None`. Provides an optional initial state to the optimizer.
     tracked_params
-        Default None. An eqx.Module of type Params with non-None values for
+        Default `None`. A `jinns.parameters.Params` object with non-`None` values for
         parameters that need to be tracked along the iterations.
-        None values in tracked_params will not be traversed. Thus
-        the user can provide something like `tracked_params = jinns.parameters.Params(
-        nn_params=None, eq_params={"nu": True})` while init_params.nn_params
+        The user can provide something like `tracked_params = jinns.parameters.Params(
+        nn_params=None, eq_params={"nu": True})` even when `init_params.nn_params`
         is a complex data structure.
     param_data
-        Default None. A DataGeneratorParameter object which can be used to
+        Default `None`. A `jinns.data.DataGeneratorParameter` object which can be used to
         sample equation parameters.
     obs_data
-        Default None. A DataGeneratorObservations
+        Default `None`. A `jinns.data.DataGeneratorObservations`
         object which can be used to sample minibatches of observations.
     validation
-        Default None. Otherwise, a callable ``eqx.Module`` which implements a
+        Default `None`. Otherwise, a callable `eqx.Module` which implements a
         validation strategy. See documentation of `jinns.validation.
         _validation.AbstractValidationModule` for the general interface, and
         `jinns.validation._validation.ValidationLoss` for a practical
@@ -131,53 +125,70 @@ def solve(
         validation strategy of their choice, and to decide on the early
         stopping criterion.
     obs_batch_sharding
-        Default None. An optional sharding object to constraint the obs_batch.
-        Typically, a SingleDeviceSharding(gpu_device) when obs_data has been
-        created with sharding_device=SingleDeviceSharding(cpu_device) to avoid
+        Default `None`. An optional sharding object to constrain the
+        `obs_batch`.
+        Typically, a `SingleDeviceSharding(gpu_device)` when `obs_data` has been
+        created with `sharding_device=SingleDeviceSharding(cpu_device)` to avoid
         loading huge datasets of observations onto the GPU.
+    opt_state_field_for_acceleration
+        A string. Default `None`, i.e. the optimizer is used without
+        acceleration. Some optimization schemes use so-called acceleration,
+        where the loss is evaluated at accelerated parameter values that
+        differ from the actual parameter values. These accelerated parameters
+        can be stored as a field of the optimizer state. If this field name is
+        passed to `opt_state_field_for_acceleration`, the gradient step
+        evaluates gradients at the parameter values stored in
+        `opt_state.opt_state_field_for_acceleration`.
     verbose
-        Default True. If False, no std output (loss or cause of
+        Default `True`. If `False`, no output (loss or cause of
         exiting the optimization loop) will be produced.
     ahead_of_time
-        Default True. Separate the compilation of the main training loop from
+        Default `True`. Separate the compilation of the main training loop from
         the execution to get both timings. You might need to avoid this
         behaviour if you need to perform JAX transforms over chunks of code
         containing `jinns.solve()` since AOT-compiled functions cannot be JAX
         transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
-        When False, jinns does not provide any timing information (which would
+        When `False`, jinns does not provide any timing information (which would
         be nonsense in a JIT transformed `solve()` function).
     key
-        Default None. A JAX random key that can be used for diverse purpose in
+        Default `None`. A JAX random key that can be used for diverse purposes in
         the main iteration loop.
 
     Returns
     -------
     params
-        The last non NaN value of the params at then end of the
-        optimization process
+        The last non-NaN value of the params at the end of the
+        optimization process.
     total_loss_values
-        An array of the total loss term along the gradient steps
+        An array of the total loss term along the gradient steps.
     stored_loss_terms
         A PyTree with attributes being arrays of all the values for each loss
-        term
+        term.
     data
-        The input data object
+        The data generator object passed as input, possibly modified.
     loss
-        The input loss object
+        The loss object passed as input, possibly modified.
     opt_state
-        The final optimized state
+        The final optimizer state.
     stored_params
-        A Params objects with the stored values of the desired parameters (as
-        signified in tracked_params argument)
+        An object with the stored values of the desired parameters (as
+        specified by the `tracked_params` argument).
     stored_weights_terms
-        A PyTree with attributes being arrays of all the values for each loss
-        weight. Note that if Loss.update_weight_method is None, we return None,
+        A PyTree with leaves being arrays of all the values for each loss
+        weight. Note that if `Loss.update_weight_method is None`, we return
+        `None`,
         because loss weights are never updated and we can then save some
-        computations
+        computations.
+    obs_data
+        The `jinns.data.DataGeneratorObservations` object passed as input or
+        `None`.
+    param_data
+        The `jinns.data.DataGeneratorParameter` object passed as input or
+        `None`.
     validation_crit_values
-        An array containing the validation criterion values of the training
+        An array containing the validation criterion values of the training.
     best_val_params
-        The best parameters according to the validation criterion
+        The best parameters according to the validation criterion.
     """
     initialization_time = time.time()
     if n_iter < 1:
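
For readers tracking the API change above, here is a minimal sketch of a 1.7.0-style call, reflecting the now keyword-only signature and the expanded return tuple documented in this hunk. The `init_params`, `my_data` and `my_loss` objects are placeholders assumed to have been built beforehand with the usual jinns components; this snippet is not taken from the package itself.

    import jax
    import optax
    import jinns

    # Placeholders: my_loss (AbstractLoss), my_data (AbstractDataGenerator)
    # and init_params (jinns.parameters.Params) are assumed to exist.
    tx = optax.adam(learning_rate=1e-3)

    (
        params,
        total_loss_values,
        stored_loss_terms,
        data,
        loss,
        opt_state,
        stored_params,
        stored_weights_terms,
        obs_data,
        param_data,
        validation_crit_values,
        best_val_params,
    ) = jinns.solve(
        n_iter=5000,                # arguments are now keyword-only
        init_params=init_params,
        data=my_data,
        loss=my_loss,
        optimizer=tx,
        key=jax.random.PRNGKey(0),  # optional, e.g. for loss-weight updates
    )
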
@@ -224,24 +235,12 @@ def solve(
     train_loss_values = jnp.zeros((n_iter))
     # depending on obs_batch_sharding we will get the simple get_batch or the
     # get_batch with device_put, the latter is not jittable
-    get_batch = _get_get_batch(obs_batch_sharding)
+    get_batch = _build_get_batch(obs_batch_sharding)
 
     # initialize parameter tracking
     if tracked_params is None:
         tracked_params = jax.tree.map(lambda p: None, init_params)
-    stored_params = jax.tree_util.tree_map(
-        lambda tracked_param, param: (
-            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
-            if tracked_param is not None
-            else None
-        ),
-        tracked_params,
-        init_params,
-        is_leaf=lambda x: x is None,  # None values in tracked_params will not
-        # be traversed. Thus the user can provide something like `tracked_params = jinns.parameters.Params(
-        # nn_params=None, eq_params={"nu": True})` while init_params.nn_params
-        # being a complex data structure
-    )
+    stored_params = _init_stored_params(tracked_params, init_params, n_iter)
 
     # initialize the dict for stored parameter values
     # we need to get a loss_term to init stuff
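
The pre-allocation logic removed in this hunk now presumably lives in `_init_stored_params` (its body is not part of this diff). A small self-contained sketch of the same pattern, using plain dicts in place of `jinns.parameters.Params`: only the leaves selected by the `tracked_params` mask get an `(n_iter, ...)` buffer.

    import jax
    import jax.numpy as jnp

    n_iter = 100
    # Plain dicts stand in for jinns.parameters.Params here.
    init_params = {"nn_params": jnp.ones((3, 2)), "eq_params": {"nu": jnp.array(0.5)}}
    tracked_params = {"nn_params": None, "eq_params": {"nu": True}}

    stored_params = jax.tree_util.tree_map(
        lambda tracked, param: (
            jnp.zeros((n_iter,) + jnp.asarray(param).shape)
            if tracked is not None
            else None
        ),
        tracked_params,
        init_params,
        is_leaf=lambda x: x is None,  # None leaves of the mask are not traversed
    )
    # stored_params["eq_params"]["nu"].shape == (n_iter,)
    # stored_params["nn_params"] is None
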
@@ -257,23 +256,7 @@ def solve(
 
     # initialize the PyTree for stored loss weights values
     if loss.update_weight_method is not None:
-        stored_weights_terms = eqx.tree_at(
-            lambda pt: jax.tree.leaves(
-                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
-            ),
-            loss.loss_weights,
-            tuple(
-                jnp.zeros((n_iter))
-                for n in range(
-                    len(
-                        jax.tree.leaves(
-                            loss.loss_weights,
-                            is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
-                        )
-                    )
-                )
-            ),
-        )
+        stored_weights_terms = _init_stored_weights_terms(loss, n_iter)
     else:
         stored_weights_terms = None
     if loss.update_weight_method is not None and key is None:
@@ -326,7 +309,11 @@ def solve(
         key,
     )
 
-    def _one_iteration(carry: main_carry) -> main_carry:
+    def _one_iteration(carry: SolveCarry) -> SolveCarry:
+        # Note that the optimizer is not part of the carry: it is not a valid
+        # JAX type to be carried and, while it could be hashable, it must be
+        # static because of the equinox `filter_spec` (https://github.com/patrick-kidger/equinox/issues/1036)
+
         (
             i,
             loss,
@@ -344,43 +331,24 @@ def solve(
             train_data.data, train_data.param_data, train_data.obs_data
         )
 
-        # ---------------------------------------------------------------------
-        # The following part is the equivalent of a
-        # > train_loss_value, grads = jax.values_and_grad(total_loss.evaluate)(params, ...)
-        # but it is decomposed on individual loss terms so that we can use it
-        # if needed for updating loss weights.
-        # Since the total loss is a weighted sum of individual loss terms, so
-        # are its total gradients.
-
-        # Compute individual losses and individual gradients
-        loss_terms, grad_terms = loss.evaluate_by_terms(optimization.params, batch)
-
-        if loss.update_weight_method is not None:
-            key, subkey = jax.random.split(key)  # type: ignore because key can
-            # still be None currently
-            # avoid computations of tree_at if no updates
-            loss = loss.update_weights(
-                i, loss_terms, loss_container.stored_loss_terms, grad_terms, subkey
+        if key is not None:
+            key, subkey = jax.random.split(key)
+        else:
+            subkey = None
+        (train_loss_value, params, last_non_nan_params, opt_state, loss, loss_terms) = (
+            _loss_evaluate_and_gradient_step(
+                i,
+                batch,
+                loss,
+                optimization.params,
+                optimization.last_non_nan_params,
+                optimization.opt_state,
+                optimizer,
+                loss_container,
+                subkey,
+                None,
+                opt_state_field_for_acceleration,
             )
-
-        # total grad
-        grads = loss.ponderate_and_sum_gradient(grad_terms)
-
-        # total loss
-        train_loss_value = loss.ponderate_and_sum_loss(loss_terms)
-        # ---------------------------------------------------------------------
-
-        # gradient step
-        (
-            params,
-            opt_state,
-            last_non_nan_params,
-        ) = _gradient_step(
-            grads,
-            optimizer,
-            optimization.params,
-            optimization.opt_state,
-            optimization.last_non_nan_params,
         )
 
         # Print train loss value during optimization
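
The removed block above spelled out the per-term evaluation that `_loss_evaluate_and_gradient_step` now bundles (its 1.7.0 body is not part of this diff). A toy, self-contained sketch of the same idea, i.e. evaluating each loss term and its gradient separately and recombining them as a weighted sum before an optax update; the two quadratic terms stand in for the actual PINN loss terms.

    import jax
    import jax.numpy as jnp
    import optax

    params = jnp.array([1.0, -2.0])
    weights = {"dyn_loss": 1.0, "boundary": 10.0}
    loss_terms_fns = {
        "dyn_loss": lambda p: jnp.sum(p**2),
        "boundary": lambda p: jnp.sum((p - 1.0) ** 2),
    }

    # Per-term values and gradients (the analogue of loss.evaluate_by_terms)
    values_and_grads = {
        name: jax.value_and_grad(fn)(params) for name, fn in loss_terms_fns.items()
    }
    loss_terms = {name: vg[0] for name, vg in values_and_grads.items()}
    grad_terms = {name: vg[1] for name, vg in values_and_grads.items()}

    # Weighted sums: since the total loss is a weighted sum of the terms,
    # so is its gradient (analogue of ponderate_and_sum_loss / _gradient).
    train_loss_value = sum(weights[n] * loss_terms[n] for n in loss_terms)
    grads = jax.tree_util.tree_map(
        lambda *leaves: sum(w * leaf for w, leaf in zip(weights.values(), leaves)),
        *grad_terms.values(),
    )

    # Standard optax gradient step (analogue of the removed _gradient_step)
    tx = optax.sgd(1e-2)
    opt_state = tx.init(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
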
@@ -552,249 +520,13 @@ def solve(
         optimization.last_non_nan_params,
         loss_container.train_loss_values,
         loss_container.stored_loss_terms,
-        train_data.data,  # return the DataGenerator if needed (no in-place modif)
-        loss,  # return the Loss if needed (no-inplace modif)
+        train_data.data,
+        loss,
         optimization.opt_state,
         stored_objects.stored_params,
         loss_container.stored_weights_terms,
+        train_data.obs_data,
+        train_data.param_data,
         validation_crit_values if validation is not None else None,
         validation_parameters,
     )
-
-
-@partial(jit, static_argnames=["optimizer"])
-def _gradient_step(
-    grads: Params[Array],
-    optimizer: optax.GradientTransformation,
-    params: Params[Array],
-    opt_state: optax.OptState,
-    last_non_nan_params: Params[Array],
-) -> tuple[
-    Params[Array],
-    optax.OptState,
-    Params[Array],
-]:
-    """
-    optimizer cannot be jit-ted.
-    """
-
-    updates, opt_state = optimizer.update(
-        grads,  # type: ignore
-        opt_state,
-        params,  # type: ignore
-    )  # see optimizer.init for explaination for the ignore(s) here
-    params = optax.apply_updates(params, updates)  # type: ignore
-
-    # check if any of the parameters is NaN
-    last_non_nan_params = jax.lax.cond(
-        _check_nan_in_pytree(params),
-        lambda _: last_non_nan_params,
-        lambda _: params,
-        None,
-    )
-
-    return (
-        params,
-        opt_state,
-        last_non_nan_params,
-    )
-
-
-@partial(jit, static_argnames=["prefix"])
-def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
-    # note that if the following is not jitted in the main lor loop, it is
-    # super slow
-    _ = jax.lax.cond(
-        i % print_loss_every == 0,
-        lambda _: jax.debug.print(
-            prefix + "Iteration {i}: loss value = {loss_val}",
-            i=i,
-            loss_val=loss_val,
-        ),
-        lambda _: None,
-        (None,),
-    )
-
-
-@jit
-def _store_loss_and_params(
-    i: int,
-    params: Params[Array],
-    stored_params: Params[Array | None],
-    loss_container: LossContainer,
-    train_loss_val: float,
-    loss_terms: PyTree[Array],
-    weight_terms: PyTree[Array],
-    tracked_params: Params,
-) -> tuple[StoredObjectContainer, LossContainer]:
-    stored_params = jax.tree_util.tree_map(
-        lambda stored_value, param, tracked_param: (
-            None
-            if stored_value is None
-            else jax.lax.cond(
-                tracked_param,
-                lambda ope: ope[0].at[i].set(ope[1]),
-                lambda ope: ope[0],
-                (stored_value, param),
-            )
-        ),
-        stored_params,
-        params,
-        tracked_params,
-        is_leaf=lambda x: x is None,
-    )
-    stored_loss_terms = jax.tree_util.tree_map(
-        lambda stored_term, loss_term: stored_term.at[i].set(loss_term),
-        loss_container.stored_loss_terms,
-        loss_terms,
-    )
-
-    if loss_container.stored_weights_terms is not None:
-        stored_weights_terms = jax.tree_util.tree_map(
-            lambda stored_term, weight_term: stored_term.at[i].set(weight_term),
-            jax.tree.leaves(
-                loss_container.stored_weights_terms,
-                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
-            ),
-            jax.tree.leaves(
-                weight_terms,
-                is_leaf=lambda x: x is not None and eqx.is_inexact_array(x),
-            ),
-        )
-        stored_weights_terms = eqx.tree_at(
-            lambda pt: jax.tree.leaves(
-                pt, is_leaf=lambda x: x is not None and eqx.is_inexact_array(x)
-            ),
-            loss_container.stored_weights_terms,
-            stored_weights_terms,
-        )
-    else:
-        stored_weights_terms = None
-
-    train_loss_values = loss_container.train_loss_values.at[i].set(train_loss_val)
-    loss_container = LossContainer(
-        stored_loss_terms, stored_weights_terms, train_loss_values
-    )
-    stored_objects = StoredObjectContainer(stored_params)
-    return stored_objects, loss_container
-
-
-def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
-    """
-    Wrapper to get the break_fun with appropriate `n_iter`.
-    The verbose argument is here to control printing (or not) when exiting
-    the optimisation loop. It can be convenient is jinns.solve is itself
-    called in a loop and user want to avoid std output.
-    """
-
-    @jit
-    def break_fun(carry: tuple):
-        """
-        Function to break from the main optimization loop whe the following
-        conditions are met : maximum number of iterations, NaN
-        appearing in the parameters, and early stopping criterion.
-        """
-
-        def stop_while_loop(msg):
-            """
-            Note that the message is wrapped in the jax.lax.cond because a
-            string is not a valid JAX type that can be fed into the operands
-            """
-            if verbose:
-                jax.debug.print(f"\nStopping main optimization loop, cause: {msg}")
-            return False
-
-        def continue_while_loop(_):
-            return True
-
-        (i, _, optimization, optimization_extra, _, _, _, _, _, _) = carry
-
-        # Condition 1
-        bool_max_iter = jax.lax.cond(
-            i >= n_iter,
-            lambda _: stop_while_loop("max iteration is reached"),
-            continue_while_loop,
-            None,
-        )
-        # Condition 2
-        bool_nan_in_params = jax.lax.cond(
-            _check_nan_in_pytree(optimization.params),
-            lambda _: stop_while_loop(
-                "NaN values in parameters (returning last non NaN values)"
-            ),
-            continue_while_loop,
-            None,
-        )
-        # Condition 3
-        bool_early_stopping = jax.lax.cond(
-            optimization_extra.early_stopping,
-            lambda _: stop_while_loop("early stopping"),
-            continue_while_loop,
-            _,
-        )
-
-        # stop when one of the cond to continue is False
-        return jax.tree_util.tree_reduce(
-            lambda x, y: jnp.logical_and(jnp.array(x), jnp.array(y)),
-            (bool_max_iter, bool_nan_in_params, bool_early_stopping),
-        )
-
-    return break_fun
-
-
-def _get_get_batch(
-    obs_batch_sharding: jax.sharding.Sharding | None,
-) -> Callable[
-    [
-        AbstractDataGenerator,
-        DataGeneratorParameter | None,
-        DataGeneratorObservations | None,
-    ],
-    tuple[
-        AnyBatch,
-        AbstractDataGenerator,
-        DataGeneratorParameter | None,
-        DataGeneratorObservations | None,
-    ],
-]:
-    """
-    Return the get_batch function that will be used either the jittable one or
-    the non-jittable one with sharding using jax.device.put()
-    """
-
-    def get_batch_sharding(data, param_data, obs_data):
-        """
-        This function is used at each loop but it cannot be jitted because of
-        device_put
-        """
-        data, batch = data.get_batch()
-        if param_data is not None:
-            param_data, param_batch = param_data.get_batch()
-            batch = append_param_batch(batch, param_batch)
-        if obs_data is not None:
-            # This is the part that motivated the transition from scan to for loop
-            # Indeed we need to be transit obs_batch from CPU to GPU when we have
-            # huge observations that cannot fit on GPU. Such transfer wasn't meant
-            # to be jitted, i.e. in a scan loop
-            obs_data, obs_batch = obs_data.get_batch()
-            obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
-            batch = append_obs_batch(batch, obs_batch)
-        return batch, data, param_data, obs_data
-
-    @jit
-    def get_batch(data, param_data, obs_data):
-        """
-        Original get_batch with no sharding
-        """
-        data, batch = data.get_batch()
-        if param_data is not None:
-            param_data, param_batch = param_data.get_batch()
-            batch = append_param_batch(batch, param_batch)
-        if obs_data is not None:
-            obs_data, obs_batch = obs_data.get_batch()
-            batch = append_obs_batch(batch, obs_batch)
-        return batch, data, param_data, obs_data
-
-    if obs_batch_sharding is not None:
-        return get_batch_sharding
-    return get_batch
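
To connect the removed `get_batch_sharding` helper above with the `obs_batch_sharding` argument documented earlier in this diff: a small sketch of how the two `SingleDeviceSharding` objects could be set up, and of the per-iteration `jax.device_put` transfer the helper performs. Device choices and shapes are illustrative only; on a CPU-only machine both shardings point to the same device and the transfer is a no-op.

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    cpu_device = jax.devices("cpu")[0]
    accel_device = jax.devices()[0]  # first default-backend device (GPU if present)

    # Keep the full observation dataset on CPU, move each minibatch to the accelerator.
    cpu_sharding = SingleDeviceSharding(cpu_device)
    obs_batch_sharding = SingleDeviceSharding(accel_device)

    # A stand-in observation minibatch living on CPU:
    obs_batch = jax.device_put(jnp.ones((128, 3)), cpu_sharding)
    # The per-iteration transfer done inside the (removed) get_batch_sharding:
    obs_batch = jax.device_put(obs_batch, obs_batch_sharding)
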