jinns 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/__init__.py CHANGED
@@ -7,8 +7,9 @@ from jinns import parameters as parameters
from jinns import plot as plot
from jinns import nn as nn
from jinns.solver._solve import solve
+from jinns.solver._solve_alternate import solve_alternate

-__all__ = ["nn", "solve"]
+__all__ = ["nn", "solve", "solve_alternate"]

import warnings

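Version 1.7.0 adds an alternate solver entry point next to `solve` and re-exports it from the package root. A minimal sketch of the new import surface; the call signature of `solve_alternate` is not part of this diff, so the commented call and its arguments are placeholders:

```python
# Both solvers are re-exported from the package root as of 1.7.0.
from jinns import solve, solve_alternate

# Hypothetical call, for illustration only: the actual signature of
# solve_alternate lives in jinns/solver/_solve_alternate.py and is not
# shown in this diff.
# params, history = solve_alternate(...)
```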
jinns/loss/_LossODE.py CHANGED
@@ -47,7 +47,11 @@ if TYPE_CHECKING:
    )


-class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]]):
+class LossODE(
+    AbstractLoss[
+        LossWeightsODE, ODEBatch, ODEComponents[Array | None], DerivativeKeysODE
+    ]
+):
    r"""Loss object for an ordinary differential equation

    $$
@@ -57,7 +61,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
    where $\mathcal{N}[\cdot]$ is a differential operator and the
    initial condition is $u(t_0)=u_0$.

-
    Parameters
    ----------
    u : eqx.Module
@@ -107,7 +110,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
    # (ie. jax.Array cannot be static) and that we do not expect to change
    u: AbstractPINN
    dynamic_loss: ODE | None
-    vmap_in_axes: tuple[int] = eqx.field(static=True)
    derivative_keys: DerivativeKeysODE
    loss_weights: LossWeightsODE
    initial_condition: InitialCondition | None
@@ -131,10 +133,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
        else:
            self.loss_weights = loss_weights

-        super().__init__(loss_weights=self.loss_weights, **kwargs)
-        self.u = u
-        self.dynamic_loss = dynamic_loss
-        self.vmap_in_axes = (0,)
        if derivative_keys is None:
            # by default we only take gradient wrt nn_params
            if params is None:
@@ -142,9 +140,27 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
                    "Problem at derivative_keys initialization "
                    f"received {derivative_keys=} and {params=}"
                )
-            self.derivative_keys = DerivativeKeysODE(params=params)
+            derivative_keys = DerivativeKeysODE(params=params)
        else:
-            self.derivative_keys = derivative_keys
+            derivative_keys = derivative_keys
+
+        super().__init__(
+            loss_weights=self.loss_weights,
+            derivative_keys=derivative_keys,
+            vmap_in_axes=(0,),
+            **kwargs,
+        )
+        self.u = u
+        self.dynamic_loss = dynamic_loss
+        if self.update_weight_method is not None and jnp.any(
+            jnp.array(jax.tree.leaves(self.loss_weights)) == 0
+        ):
+            warnings.warn(
+                "self.update_weight_method is activated while some loss "
+                "weights are zero. The update weight method will likely "
+                "update the zero weight to some non-zero value. Check that "
+                "this is the desired behaviour."
+            )

        if initial_condition is None:
            warnings.warn(
@@ -220,7 +236,11 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
            self.obs_slice = obs_slice

    def evaluate_by_terms(
-        self, params: Params[Array], batch: ODEBatch
+        self,
+        opt_params: Params[Array],
+        batch: ODEBatch,
+        *,
+        non_opt_params: Params[Array] | None = None,
    ) -> tuple[
        ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
    ]:
@@ -231,15 +251,22 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]

        Parameters
        ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
        batch
            Composed of a batch of points in the
            domain, a batch of points in the domain
            border and an optional additional batch of parameters (eg. for
            metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are not optimized, at which the loss is evaluated
        """
+        if non_opt_params is not None:
+            params = eqx.combine(opt_params, non_opt_params)
+        else:
+            params = opt_params
+
        temporal_batch = batch.temporal_batch

        # Retrieve the optional eq_params_batch
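The `LossODE.evaluate_by_terms` signature now separates optimized from non-optimized parameters and recombines them internally with `eqx.combine`. A hedged sketch of the calling convention, assuming `loss`, `params`, `batch` and a `Params[bool]` mask named `opt_mask` have already been built (these names are illustrative):

```python
import equinox as eqx

# Split the full parameter PyTree into an optimized part and a frozen part;
# eqx.partition returns the selected leaves and the remainder.
opt_params, non_opt_params = eqx.partition(params, opt_mask)

# The frozen leaves travel through the new keyword-only argument and are put
# back together internally via eqx.combine(opt_params, non_opt_params).
terms, _ = loss.evaluate_by_terms(opt_params, batch, non_opt_params=non_opt_params)

# Leaving non_opt_params=None (the default) keeps the previous behaviour:
# the loss is simply evaluated at opt_params.
loss_val, terms = loss.evaluate(opt_params, batch)
```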
jinns/loss/_LossPDE.py CHANGED
@@ -68,11 +68,14 @@ B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
C = TypeVar(
    "C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
)
-D = TypeVar("D", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
+DKPDE = TypeVar("DKPDE", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)


-class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
+class _LossPDEAbstract(
+    AbstractLoss[L, B, C, DKPDE],
+    Generic[L, B, C, DKPDE, Y],
+):
    r"""
    Parameters
    ----------
@@ -158,9 +161,24 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
        norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
        obs_slice: EllipsisType | slice | None = None,
        key: PRNGKeyArray | None = None,
+        derivative_keys: DKPDE,
        **kwargs: Any, # for arguments for super()
    ):
-        super().__init__(loss_weights=self.loss_weights, **kwargs)
+        super().__init__(
+            loss_weights=self.loss_weights,
+            derivative_keys=derivative_keys,
+            **kwargs,
+        )
+
+        if self.update_weight_method is not None and jnp.any(
+            jnp.array(jax.tree.leaves(self.loss_weights)) == 0
+        ):
+            warnings.warn(
+                "self.update_weight_method is activated while some loss "
+                "weights are zero. The update weight method will likely "
+                "update the zero weight to some non-zero value. Check that "
+                "this is the desired behaviour."
+            )

        if obs_slice is None:
            self.obs_slice = jnp.s_[...]
@@ -497,7 +515,6 @@ class LossPDEStatio(
    dynamic_loss: PDEStatio | None
    loss_weights: LossWeightsPDEStatio
    derivative_keys: DerivativeKeysPDEStatio
-    vmap_in_axes: tuple[int] = eqx.field(static=True)

    params: InitVar[Params[Array] | None]

@@ -516,25 +533,25 @@ class LossPDEStatio(
            self.loss_weights = LossWeightsPDEStatio()
        else:
            self.loss_weights = loss_weights
-        self.dynamic_loss = dynamic_loss
-
-        super().__init__(
-            **kwargs,
-        )

        if derivative_keys is None:
            # be default we only take gradient wrt nn_params
            try:
-                self.derivative_keys = DerivativeKeysPDEStatio(params=params)
+                derivative_keys = DerivativeKeysPDEStatio(params=params)
            except ValueError as exc:
                raise ValueError(
                    "Problem at derivative_keys initialization "
                    f"received {derivative_keys=} and {params=}"
                ) from exc
        else:
-            self.derivative_keys = derivative_keys
+            derivative_keys = derivative_keys

-        self.vmap_in_axes = (0,)
+        super().__init__(
+            derivative_keys=derivative_keys,
+            vmap_in_axes=(0,),
+            **kwargs,
+        )
+        self.dynamic_loss = dynamic_loss

    def _get_dynamic_loss_batch(
        self, batch: PDEStatioBatch
@@ -549,11 +566,12 @@ class LossPDEStatio(
        # we could have used typing.cast though

    def evaluate_by_terms(
-        self, params: Params[Array], batch: PDEStatioBatch
-    ) -> tuple[
-        PDEStatioComponents[Float[Array, ""] | None],
-        PDEStatioComponents[Float[Array, ""] | None],
-    ]:
+        self,
+        opt_params: Params[Array],
+        batch: PDEStatioBatch,
+        *,
+        non_opt_params: Params[Array] | None = None,
+    ) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
        """
        Evaluate the loss function at a batch of points for given parameters.

@@ -561,15 +579,22 @@ class LossPDEStatio(

        Parameters
        ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
        batch
            Composed of a batch of points in the
            domain, a batch of points in the domain
            border and an optional additional batch of parameters (eg. for
            metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are non optimized, at which the loss is evaluated
        """
+        if non_opt_params is not None:
+            params = eqx.combine(opt_params, non_opt_params)
+        else:
+            params = opt_params
+
        # Retrieve the optional eq_params_batch
        # and update eq_params with the latter
        # and update vmap_in_axes
@@ -740,7 +765,6 @@ class LossPDENonStatio(
    initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
        eqx.field(static=True)
    )
-    vmap_in_axes: tuple[int] = eqx.field(static=True)
    max_norm_samples_omega: int = eqx.field(static=True)
    max_norm_time_slices: int = eqx.field(static=True)

@@ -766,25 +790,26 @@ class LossPDENonStatio(
            self.loss_weights = LossWeightsPDENonStatio()
        else:
            self.loss_weights = loss_weights
-        self.dynamic_loss = dynamic_loss
-
-        super().__init__(
-            **kwargs,
-        )

        if derivative_keys is None:
            # be default we only take gradient wrt nn_params
            try:
-                self.derivative_keys = DerivativeKeysPDENonStatio(params=params)
+                derivative_keys = DerivativeKeysPDENonStatio(params=params)
            except ValueError as exc:
                raise ValueError(
                    "Problem at derivative_keys initialization "
                    f"received {derivative_keys=} and {params=}"
                ) from exc
        else:
-            self.derivative_keys = derivative_keys
+            derivative_keys = derivative_keys

-        self.vmap_in_axes = (0,) # for t_x
+        super().__init__(
+            derivative_keys=derivative_keys,
+            vmap_in_axes=(0,), # for t_x
+            **kwargs,
+        )
+
+        self.dynamic_loss = dynamic_loss

        if initial_condition_fun is None:
            warnings.warn(
@@ -820,7 +845,11 @@ class LossPDENonStatio(
            )

    def evaluate_by_terms(
-        self, params: Params[Array], batch: PDENonStatioBatch
+        self,
+        opt_params: Params[Array],
+        batch: PDENonStatioBatch,
+        *,
+        non_opt_params: Params[Array] | None = None,
    ) -> tuple[
        PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
    ]:
@@ -831,15 +860,22 @@ class LossPDENonStatio(

        Parameters
        ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
        batch
            Composed of a batch of points in the
            domain, a batch of points in the domain
            border and an optional additional batch of parameters (eg. for
            metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are non optimized, at which the loss is evaluated
        """
+        if non_opt_params is not None:
+            params = eqx.combine(opt_params, non_opt_params)
+        else:
+            params = opt_params
+
        omega_initial_batch = batch.initial_batch
        assert omega_initial_batch is not None

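Both PDE loss constructors now emit the same warning as `LossODE` when an adaptive `update_weight_method` is combined with zero loss weights. The check simply flattens the loss-weights PyTree and looks for zeros; a standalone illustration with a plain dict standing in for a `LossWeightsPDEStatio` instance:

```python
import jax
import jax.numpy as jnp

# Plain-dict stand-in for a loss-weights PyTree; the keys are illustrative.
loss_weights = {"dyn_loss": 1.0, "boundary_loss": 0.0, "observations": 0.5}

# Same test as in the __post_init__ hunks above.
if jnp.any(jnp.array(jax.tree.leaves(loss_weights)) == 0):
    print(
        "an update_weight_method would likely move the zero weight "
        "to a non-zero value"
    )
```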
@@ -9,7 +9,12 @@ import jax.numpy as jnp
import optax
from jinns.parameters._params import Params
from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
-from jinns.utils._types import AnyLossComponents, AnyBatch, AnyLossWeights
+from jinns.utils._types import (
+    AnyLossComponents,
+    AnyBatch,
+    AnyLossWeights,
+    AnyDerivativeKeys,
+)

L = TypeVar(
    "L", bound=AnyLossWeights
@@ -25,31 +30,47 @@ C = TypeVar(
    "C", bound=AnyLossComponents[Array | None]
) # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)

+DK = TypeVar("DK", bound=AnyDerivativeKeys)
+
# In the cases above, without the bound, we could not have covariance on
# the type because it would break LSP. Note that covariance on the return type
# is authorized in LSP hence we do not need the same TypeVar instruction for
# the return types of evaluate_by_terms for example!


-class AbstractLoss(eqx.Module, Generic[L, B, C]):
+class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
    """
    About the call:
    https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
    """

+    derivative_keys: eqx.AbstractVar[DK]
    loss_weights: eqx.AbstractVar[L]
    update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
        eqx.field(kw_only=True, default=None, static=True)
    )
+    vmap_in_axes: tuple[int] = eqx.field(static=True)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.evaluate(*args, **kwargs)

    @abc.abstractmethod
-    def evaluate_by_terms(self, params: Params[Array], batch: B) -> tuple[C, C]:
+    def evaluate_by_terms(
+        self,
+        opt_params: Params[Array],
+        batch: B,
+        *,
+        non_opt_params: Params[Array] | None = None,
+    ) -> tuple[C, C]:
        pass

-    def evaluate(self, params: Params[Array], batch: B) -> tuple[Float[Array, " "], C]:
+    def evaluate(
+        self,
+        opt_params: Params[Array],
+        batch: B,
+        *,
+        non_opt_params: Params[Array] | None = None,
+    ) -> tuple[Float[Array, " "], C]:
        """
        Evaluate the loss function at a batch of points for given parameters.

@@ -57,16 +78,20 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):

        Parameters
        ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
        batch
            Composed of a batch of points in the
            domain, a batch of points in the domain
            border and an optional additional batch of parameters (eg. for
            metamodeling) and an optional additional batch of observed
            inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are non optimized, at which the loss is evaluated
        """
-        loss_terms, _ = self.evaluate_by_terms(params, batch)
+        loss_terms, _ = self.evaluate_by_terms(
+            opt_params, batch, non_opt_params=non_opt_params
+        )

        loss_val = self.ponderate_and_sum_loss(loss_terms)

@@ -105,7 +130,7 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
            f" got {len(weights)} and {len(terms_list)}"
        )

-    def ponderate_and_sum_gradient(self, terms: C) -> C:
+    def ponderate_and_sum_gradient(self, terms: C) -> Params[Array | None]:
        """
        Get total gradients from individual loss gradients and weights
        for each parameter
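`evaluate` now forwards `non_opt_params` to `evaluate_by_terms` and still reduces the per-term losses through `ponderate_and_sum_loss`. The implementation of that reduction is not shown in this diff; a rough illustration of the idea (weight each term, then sum into a scalar), using plain dicts in place of the `*Components` and `LossWeights*` PyTrees:

```python
import jax
import jax.numpy as jnp

# Illustrative stand-ins; the real containers are equinox modules in jinns.
terms = {"dyn_loss": jnp.array(0.12), "initial_condition": jnp.array(0.03)}
weights = {"dyn_loss": 1.0, "initial_condition": 10.0}

# Ponderate (weight) each loss term and sum, which is roughly what
# ponderate_and_sum_loss produces for evaluate().
total = sum(w * t for w, t in zip(jax.tree.leaves(weights), jax.tree.leaves(terms)))
```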
@@ -47,9 +47,9 @@ class DerivativeKeysODE(eqx.Module):
    [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
    respect to the neural network parameters *and* the equation parameters, or only some of them.

-    To do so, user can either use strings or a `Params` object
-    with PyTree structure matching the parameters of the problem at
-    hand, and booleans indicating if gradient is to be taken or not. Internally,
+    To do so, user can either use strings or a `Params[bool]` object
+    with PyTree structure matching the parameters of the problem (`Params[Array]`) at
+    hand, and leaves being booleans indicating if gradient is to be taken or not. Internally,
    a `jax.lax.stop_gradient()` is appropriately set to each `True` node when
    computing each loss term.

@@ -156,12 +156,12 @@ class DerivativeKeysODE(eqx.Module):
        """
        Construct the DerivativeKeysODE from strings. For each term of the
        loss, specify whether to differentiate wrt the neural network
-        parameters, the equation parameters or both. The `Params` object, which
+        parameters, the equation parameters or both. The `Params[Array]` object, which
        contains the actual array of parameters must be passed to
        construct the fields with the appropriate PyTree structure.

        !!! note
-            You can mix strings and `Params` if you need granularity.
+            You can mix strings and `Params[bool]` if you need granularity.

        Parameters
        ----------
@@ -498,7 +498,14 @@ def _set_derivatives(
    `Params(nn_params=True | False, eq_params={"alpha":True | False,
    "beta":True | False})`.
    """
-
+    assert jax.tree.structure(params_.eq_params) == jax.tree.structure(
+        derivative_mask.eq_params
+    ), (
+        "The derivative "
+        "mask for eq_params does not have the same tree structure as "
+        "Params.eq_params. This is often due to a wrong Params[bool] "
+        "passed when initializing the derivative key object."
+    )
    return Params(
        nn_params=jax.lax.cond(
            derivative_mask.nn_params,
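`_set_derivatives` now asserts that the derivative mask and the parameters share the same `eq_params` tree structure before applying `jax.lax.cond`. A standalone illustration of that structure check, using plain dicts in place of `Params.eq_params` and the `Params[bool]` mask:

```python
import jax

# Stand-ins for Params.eq_params (arrays) and a Params[bool] derivative mask.
eq_params = {"alpha": 1.0, "beta": 2.0}
matching_mask = {"alpha": True, "beta": False}
wrong_mask = {"alpha": True}  # "beta" is missing: different tree structure

# This is the comparison performed by the new assertion; leaf values do not
# matter, only the PyTree structure does.
assert jax.tree.structure(eq_params) == jax.tree.structure(matching_mask)
print(jax.tree.structure(eq_params) == jax.tree.structure(wrong_mask))  # False
```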
@@ -2,6 +2,7 @@
Formalize the data structure for the parameters
"""

+from __future__ import annotations
from dataclasses import fields
from typing import Generic, TypeVar
import equinox as eqx
@@ -60,6 +61,15 @@ class Params(eqx.Module, Generic[T]):
        else:
            self.eq_params = eq_params

+    def partition(self, mask: Params[bool] | None):
+        """
+        following the boolean mask, partition into two Params
+        """
+        if mask is not None:
+            return eqx.partition(self, mask)
+        else:
+            return self, None
+


def update_eq_params(