jinns 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/loss/_LossODE.py +39 -12
- jinns/loss/_LossPDE.py +67 -31
- jinns/loss/_abstract_loss.py +33 -8
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_solve.py +98 -366
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +503 -0
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.0.dist-info}/METADATA +16 -14
- {jinns-1.6.1.dist-info → jinns-1.7.0.dist-info}/RECORD +18 -17
- {jinns-1.6.1.dist-info → jinns-1.7.0.dist-info}/WHEEL +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
|
@@ -7,8 +7,9 @@ from jinns import parameters as parameters
|
|
|
7
7
|
from jinns import plot as plot
|
|
8
8
|
from jinns import nn as nn
|
|
9
9
|
from jinns.solver._solve import solve
|
|
10
|
+
from jinns.solver._solve_alternate import solve_alternate
|
|
10
11
|
|
|
11
|
-
__all__ = ["nn", "solve"]
|
|
12
|
+
__all__ = ["nn", "solve", "solve_alternate"]
|
|
12
13
|
|
|
13
14
|
import warnings
|
|
14
15
|
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -47,7 +47,11 @@ if TYPE_CHECKING:
|
|
|
47
47
|
)
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
class LossODE(
|
|
50
|
+
class LossODE(
|
|
51
|
+
AbstractLoss[
|
|
52
|
+
LossWeightsODE, ODEBatch, ODEComponents[Array | None], DerivativeKeysODE
|
|
53
|
+
]
|
|
54
|
+
):
|
|
51
55
|
r"""Loss object for an ordinary differential equation
|
|
52
56
|
|
|
53
57
|
$$
|
|
@@ -57,7 +61,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
57
61
|
where $\mathcal{N}[\cdot]$ is a differential operator and the
|
|
58
62
|
initial condition is $u(t_0)=u_0$.
|
|
59
63
|
|
|
60
|
-
|
|
61
64
|
Parameters
|
|
62
65
|
----------
|
|
63
66
|
u : eqx.Module
|
|
@@ -107,7 +110,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
107
110
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
108
111
|
u: AbstractPINN
|
|
109
112
|
dynamic_loss: ODE | None
|
|
110
|
-
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
111
113
|
derivative_keys: DerivativeKeysODE
|
|
112
114
|
loss_weights: LossWeightsODE
|
|
113
115
|
initial_condition: InitialCondition | None
|
|
@@ -131,10 +133,6 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
131
133
|
else:
|
|
132
134
|
self.loss_weights = loss_weights
|
|
133
135
|
|
|
134
|
-
super().__init__(loss_weights=self.loss_weights, **kwargs)
|
|
135
|
-
self.u = u
|
|
136
|
-
self.dynamic_loss = dynamic_loss
|
|
137
|
-
self.vmap_in_axes = (0,)
|
|
138
136
|
if derivative_keys is None:
|
|
139
137
|
# by default we only take gradient wrt nn_params
|
|
140
138
|
if params is None:
|
|
@@ -142,9 +140,27 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
142
140
|
"Problem at derivative_keys initialization "
|
|
143
141
|
f"received {derivative_keys=} and {params=}"
|
|
144
142
|
)
|
|
145
|
-
|
|
143
|
+
derivative_keys = DerivativeKeysODE(params=params)
|
|
146
144
|
else:
|
|
147
|
-
|
|
145
|
+
derivative_keys = derivative_keys
|
|
146
|
+
|
|
147
|
+
super().__init__(
|
|
148
|
+
loss_weights=self.loss_weights,
|
|
149
|
+
derivative_keys=derivative_keys,
|
|
150
|
+
vmap_in_axes=(0,),
|
|
151
|
+
**kwargs,
|
|
152
|
+
)
|
|
153
|
+
self.u = u
|
|
154
|
+
self.dynamic_loss = dynamic_loss
|
|
155
|
+
if self.update_weight_method is not None and jnp.any(
|
|
156
|
+
jnp.array(jax.tree.leaves(self.loss_weights)) == 0
|
|
157
|
+
):
|
|
158
|
+
warnings.warn(
|
|
159
|
+
"self.update_weight_method is activated while some loss "
|
|
160
|
+
"weights are zero. The update weight method will likely "
|
|
161
|
+
"update the zero weight to some non-zero value. Check that "
|
|
162
|
+
"this is the desired behaviour."
|
|
163
|
+
)
|
|
148
164
|
|
|
149
165
|
if initial_condition is None:
|
|
150
166
|
warnings.warn(
|
|
@@ -220,7 +236,11 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
220
236
|
self.obs_slice = obs_slice
|
|
221
237
|
|
|
222
238
|
def evaluate_by_terms(
|
|
223
|
-
self,
|
|
239
|
+
self,
|
|
240
|
+
opt_params: Params[Array],
|
|
241
|
+
batch: ODEBatch,
|
|
242
|
+
*,
|
|
243
|
+
non_opt_params: Params[Array] | None = None,
|
|
224
244
|
) -> tuple[
|
|
225
245
|
ODEComponents[Float[Array, " "] | None], ODEComponents[Float[Array, " "] | None]
|
|
226
246
|
]:
|
|
@@ -231,15 +251,22 @@ class LossODE(AbstractLoss[LossWeightsODE, ODEBatch, ODEComponents[Array | None]
|
|
|
231
251
|
|
|
232
252
|
Parameters
|
|
233
253
|
---------
|
|
234
|
-
|
|
235
|
-
Parameters at which the loss is evaluated
|
|
254
|
+
opt_params
|
|
255
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
236
256
|
batch
|
|
237
257
|
Composed of a batch of points in the
|
|
238
258
|
domain, a batch of points in the domain
|
|
239
259
|
border and an optional additional batch of parameters (eg. for
|
|
240
260
|
metamodeling) and an optional additional batch of observed
|
|
241
261
|
inputs/outputs/parameters
|
|
262
|
+
non_opt_params
|
|
263
|
+
Parameters, which are not optimized, at which the loss is evaluated
|
|
242
264
|
"""
|
|
265
|
+
if non_opt_params is not None:
|
|
266
|
+
params = eqx.combine(opt_params, non_opt_params)
|
|
267
|
+
else:
|
|
268
|
+
params = opt_params
|
|
269
|
+
|
|
243
270
|
temporal_batch = batch.temporal_batch
|
|
244
271
|
|
|
245
272
|
# Retrieve the optional eq_params_batch
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -68,11 +68,14 @@ B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
|
|
|
68
68
|
C = TypeVar(
|
|
69
69
|
"C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
|
|
70
70
|
)
|
|
71
|
-
|
|
71
|
+
DKPDE = TypeVar("DKPDE", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
|
|
72
72
|
Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
class _LossPDEAbstract(
|
|
75
|
+
class _LossPDEAbstract(
|
|
76
|
+
AbstractLoss[L, B, C, DKPDE],
|
|
77
|
+
Generic[L, B, C, DKPDE, Y],
|
|
78
|
+
):
|
|
76
79
|
r"""
|
|
77
80
|
Parameters
|
|
78
81
|
----------
|
|
@@ -158,9 +161,24 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
158
161
|
norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
|
|
159
162
|
obs_slice: EllipsisType | slice | None = None,
|
|
160
163
|
key: PRNGKeyArray | None = None,
|
|
164
|
+
derivative_keys: DKPDE,
|
|
161
165
|
**kwargs: Any, # for arguments for super()
|
|
162
166
|
):
|
|
163
|
-
super().__init__(
|
|
167
|
+
super().__init__(
|
|
168
|
+
loss_weights=self.loss_weights,
|
|
169
|
+
derivative_keys=derivative_keys,
|
|
170
|
+
**kwargs,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
if self.update_weight_method is not None and jnp.any(
|
|
174
|
+
jnp.array(jax.tree.leaves(self.loss_weights)) == 0
|
|
175
|
+
):
|
|
176
|
+
warnings.warn(
|
|
177
|
+
"self.update_weight_method is activated while some loss "
|
|
178
|
+
"weights are zero. The update weight method will likely "
|
|
179
|
+
"update the zero weight to some non-zero value. Check that "
|
|
180
|
+
"this is the desired behaviour."
|
|
181
|
+
)
|
|
164
182
|
|
|
165
183
|
if obs_slice is None:
|
|
166
184
|
self.obs_slice = jnp.s_[...]
|
|
@@ -497,7 +515,6 @@ class LossPDEStatio(
|
|
|
497
515
|
dynamic_loss: PDEStatio | None
|
|
498
516
|
loss_weights: LossWeightsPDEStatio
|
|
499
517
|
derivative_keys: DerivativeKeysPDEStatio
|
|
500
|
-
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
501
518
|
|
|
502
519
|
params: InitVar[Params[Array] | None]
|
|
503
520
|
|
|
@@ -516,25 +533,25 @@ class LossPDEStatio(
|
|
|
516
533
|
self.loss_weights = LossWeightsPDEStatio()
|
|
517
534
|
else:
|
|
518
535
|
self.loss_weights = loss_weights
|
|
519
|
-
self.dynamic_loss = dynamic_loss
|
|
520
|
-
|
|
521
|
-
super().__init__(
|
|
522
|
-
**kwargs,
|
|
523
|
-
)
|
|
524
536
|
|
|
525
537
|
if derivative_keys is None:
|
|
526
538
|
# be default we only take gradient wrt nn_params
|
|
527
539
|
try:
|
|
528
|
-
|
|
540
|
+
derivative_keys = DerivativeKeysPDEStatio(params=params)
|
|
529
541
|
except ValueError as exc:
|
|
530
542
|
raise ValueError(
|
|
531
543
|
"Problem at derivative_keys initialization "
|
|
532
544
|
f"received {derivative_keys=} and {params=}"
|
|
533
545
|
) from exc
|
|
534
546
|
else:
|
|
535
|
-
|
|
547
|
+
derivative_keys = derivative_keys
|
|
536
548
|
|
|
537
|
-
|
|
549
|
+
super().__init__(
|
|
550
|
+
derivative_keys=derivative_keys,
|
|
551
|
+
vmap_in_axes=(0,),
|
|
552
|
+
**kwargs,
|
|
553
|
+
)
|
|
554
|
+
self.dynamic_loss = dynamic_loss
|
|
538
555
|
|
|
539
556
|
def _get_dynamic_loss_batch(
|
|
540
557
|
self, batch: PDEStatioBatch
|
|
@@ -549,11 +566,12 @@ class LossPDEStatio(
|
|
|
549
566
|
# we could have used typing.cast though
|
|
550
567
|
|
|
551
568
|
def evaluate_by_terms(
|
|
552
|
-
self,
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
569
|
+
self,
|
|
570
|
+
opt_params: Params[Array],
|
|
571
|
+
batch: PDEStatioBatch,
|
|
572
|
+
*,
|
|
573
|
+
non_opt_params: Params[Array] | None = None,
|
|
574
|
+
) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
|
|
557
575
|
"""
|
|
558
576
|
Evaluate the loss function at a batch of points for given parameters.
|
|
559
577
|
|
|
@@ -561,15 +579,22 @@ class LossPDEStatio(
|
|
|
561
579
|
|
|
562
580
|
Parameters
|
|
563
581
|
---------
|
|
564
|
-
|
|
565
|
-
Parameters at which the loss is evaluated
|
|
582
|
+
opt_params
|
|
583
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
566
584
|
batch
|
|
567
585
|
Composed of a batch of points in the
|
|
568
586
|
domain, a batch of points in the domain
|
|
569
587
|
border and an optional additional batch of parameters (eg. for
|
|
570
588
|
metamodeling) and an optional additional batch of observed
|
|
571
589
|
inputs/outputs/parameters
|
|
590
|
+
non_opt_params
|
|
591
|
+
Parameters, which are non optimized, at which the loss is evaluated
|
|
572
592
|
"""
|
|
593
|
+
if non_opt_params is not None:
|
|
594
|
+
params = eqx.combine(opt_params, non_opt_params)
|
|
595
|
+
else:
|
|
596
|
+
params = opt_params
|
|
597
|
+
|
|
573
598
|
# Retrieve the optional eq_params_batch
|
|
574
599
|
# and update eq_params with the latter
|
|
575
600
|
# and update vmap_in_axes
|
|
@@ -740,7 +765,6 @@ class LossPDENonStatio(
|
|
|
740
765
|
initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
|
|
741
766
|
eqx.field(static=True)
|
|
742
767
|
)
|
|
743
|
-
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
744
768
|
max_norm_samples_omega: int = eqx.field(static=True)
|
|
745
769
|
max_norm_time_slices: int = eqx.field(static=True)
|
|
746
770
|
|
|
@@ -766,25 +790,26 @@ class LossPDENonStatio(
|
|
|
766
790
|
self.loss_weights = LossWeightsPDENonStatio()
|
|
767
791
|
else:
|
|
768
792
|
self.loss_weights = loss_weights
|
|
769
|
-
self.dynamic_loss = dynamic_loss
|
|
770
|
-
|
|
771
|
-
super().__init__(
|
|
772
|
-
**kwargs,
|
|
773
|
-
)
|
|
774
793
|
|
|
775
794
|
if derivative_keys is None:
|
|
776
795
|
# be default we only take gradient wrt nn_params
|
|
777
796
|
try:
|
|
778
|
-
|
|
797
|
+
derivative_keys = DerivativeKeysPDENonStatio(params=params)
|
|
779
798
|
except ValueError as exc:
|
|
780
799
|
raise ValueError(
|
|
781
800
|
"Problem at derivative_keys initialization "
|
|
782
801
|
f"received {derivative_keys=} and {params=}"
|
|
783
802
|
) from exc
|
|
784
803
|
else:
|
|
785
|
-
|
|
804
|
+
derivative_keys = derivative_keys
|
|
786
805
|
|
|
787
|
-
|
|
806
|
+
super().__init__(
|
|
807
|
+
derivative_keys=derivative_keys,
|
|
808
|
+
vmap_in_axes=(0,), # for t_x
|
|
809
|
+
**kwargs,
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
self.dynamic_loss = dynamic_loss
|
|
788
813
|
|
|
789
814
|
if initial_condition_fun is None:
|
|
790
815
|
warnings.warn(
|
|
@@ -820,7 +845,11 @@ class LossPDENonStatio(
|
|
|
820
845
|
)
|
|
821
846
|
|
|
822
847
|
def evaluate_by_terms(
|
|
823
|
-
self,
|
|
848
|
+
self,
|
|
849
|
+
opt_params: Params[Array],
|
|
850
|
+
batch: PDENonStatioBatch,
|
|
851
|
+
*,
|
|
852
|
+
non_opt_params: Params[Array] | None = None,
|
|
824
853
|
) -> tuple[
|
|
825
854
|
PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
|
|
826
855
|
]:
|
|
@@ -831,15 +860,22 @@ class LossPDENonStatio(
|
|
|
831
860
|
|
|
832
861
|
Parameters
|
|
833
862
|
---------
|
|
834
|
-
|
|
835
|
-
Parameters at which the loss is evaluated
|
|
863
|
+
opt_params
|
|
864
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
836
865
|
batch
|
|
837
866
|
Composed of a batch of points in the
|
|
838
867
|
domain, a batch of points in the domain
|
|
839
868
|
border and an optional additional batch of parameters (eg. for
|
|
840
869
|
metamodeling) and an optional additional batch of observed
|
|
841
870
|
inputs/outputs/parameters
|
|
871
|
+
non_opt_params
|
|
872
|
+
Parameters, which are non optimized, at which the loss is evaluated
|
|
842
873
|
"""
|
|
874
|
+
if non_opt_params is not None:
|
|
875
|
+
params = eqx.combine(opt_params, non_opt_params)
|
|
876
|
+
else:
|
|
877
|
+
params = opt_params
|
|
878
|
+
|
|
843
879
|
omega_initial_batch = batch.initial_batch
|
|
844
880
|
assert omega_initial_batch is not None
|
|
845
881
|
|
jinns/loss/_abstract_loss.py
CHANGED
|
@@ -9,7 +9,12 @@ import jax.numpy as jnp
|
|
|
9
9
|
import optax
|
|
10
10
|
from jinns.parameters._params import Params
|
|
11
11
|
from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
|
|
12
|
-
from jinns.utils._types import
|
|
12
|
+
from jinns.utils._types import (
|
|
13
|
+
AnyLossComponents,
|
|
14
|
+
AnyBatch,
|
|
15
|
+
AnyLossWeights,
|
|
16
|
+
AnyDerivativeKeys,
|
|
17
|
+
)
|
|
13
18
|
|
|
14
19
|
L = TypeVar(
|
|
15
20
|
"L", bound=AnyLossWeights
|
|
@@ -25,31 +30,47 @@ C = TypeVar(
|
|
|
25
30
|
"C", bound=AnyLossComponents[Array | None]
|
|
26
31
|
) # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)
|
|
27
32
|
|
|
33
|
+
DK = TypeVar("DK", bound=AnyDerivativeKeys)
|
|
34
|
+
|
|
28
35
|
# In the cases above, without the bound, we could not have covariance on
|
|
29
36
|
# the type because it would break LSP. Note that covariance on the return type
|
|
30
37
|
# is authorized in LSP hence we do not need the same TypeVar instruction for
|
|
31
38
|
# the return types of evaluate_by_terms for example!
|
|
32
39
|
|
|
33
40
|
|
|
34
|
-
class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
41
|
+
class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
35
42
|
"""
|
|
36
43
|
About the call:
|
|
37
44
|
https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
|
|
38
45
|
"""
|
|
39
46
|
|
|
47
|
+
derivative_keys: eqx.AbstractVar[DK]
|
|
40
48
|
loss_weights: eqx.AbstractVar[L]
|
|
41
49
|
update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
|
|
42
50
|
eqx.field(kw_only=True, default=None, static=True)
|
|
43
51
|
)
|
|
52
|
+
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
44
53
|
|
|
45
54
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
46
55
|
return self.evaluate(*args, **kwargs)
|
|
47
56
|
|
|
48
57
|
@abc.abstractmethod
|
|
49
|
-
def evaluate_by_terms(
|
|
58
|
+
def evaluate_by_terms(
|
|
59
|
+
self,
|
|
60
|
+
opt_params: Params[Array],
|
|
61
|
+
batch: B,
|
|
62
|
+
*,
|
|
63
|
+
non_opt_params: Params[Array] | None = None,
|
|
64
|
+
) -> tuple[C, C]:
|
|
50
65
|
pass
|
|
51
66
|
|
|
52
|
-
def evaluate(
|
|
67
|
+
def evaluate(
|
|
68
|
+
self,
|
|
69
|
+
opt_params: Params[Array],
|
|
70
|
+
batch: B,
|
|
71
|
+
*,
|
|
72
|
+
non_opt_params: Params[Array] | None = None,
|
|
73
|
+
) -> tuple[Float[Array, " "], C]:
|
|
53
74
|
"""
|
|
54
75
|
Evaluate the loss function at a batch of points for given parameters.
|
|
55
76
|
|
|
@@ -57,16 +78,20 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
|
57
78
|
|
|
58
79
|
Parameters
|
|
59
80
|
---------
|
|
60
|
-
|
|
61
|
-
Parameters at which the loss is evaluated
|
|
81
|
+
opt_params
|
|
82
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
62
83
|
batch
|
|
63
84
|
Composed of a batch of points in the
|
|
64
85
|
domain, a batch of points in the domain
|
|
65
86
|
border and an optional additional batch of parameters (eg. for
|
|
66
87
|
metamodeling) and an optional additional batch of observed
|
|
67
88
|
inputs/outputs/parameters
|
|
89
|
+
non_opt_params
|
|
90
|
+
Parameters, which are non optimized, at which the loss is evaluated
|
|
68
91
|
"""
|
|
69
|
-
loss_terms, _ = self.evaluate_by_terms(
|
|
92
|
+
loss_terms, _ = self.evaluate_by_terms(
|
|
93
|
+
opt_params, batch, non_opt_params=non_opt_params
|
|
94
|
+
)
|
|
70
95
|
|
|
71
96
|
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
72
97
|
|
|
@@ -105,7 +130,7 @@ class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
|
105
130
|
f" got {len(weights)} and {len(terms_list)}"
|
|
106
131
|
)
|
|
107
132
|
|
|
108
|
-
def ponderate_and_sum_gradient(self, terms: C) ->
|
|
133
|
+
def ponderate_and_sum_gradient(self, terms: C) -> Params[Array | None]:
|
|
109
134
|
"""
|
|
110
135
|
Get total gradients from individual loss gradients and weights
|
|
111
136
|
for each parameter
|
|
@@ -47,9 +47,9 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
47
47
|
[`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
|
|
48
48
|
respect to the neural network parameters *and* the equation parameters, or only some of them.
|
|
49
49
|
|
|
50
|
-
To do so, user can either use strings or a `Params` object
|
|
51
|
-
with PyTree structure matching the parameters of the problem at
|
|
52
|
-
hand, and booleans indicating if gradient is to be taken or not. Internally,
|
|
50
|
+
To do so, user can either use strings or a `Params[bool]` object
|
|
51
|
+
with PyTree structure matching the parameters of the problem (`Params[Array]`) at
|
|
52
|
+
hand, and leaves being booleans indicating if gradient is to be taken or not. Internally,
|
|
53
53
|
a `jax.lax.stop_gradient()` is appropriately set to each `True` node when
|
|
54
54
|
computing each loss term.
|
|
55
55
|
|
|
@@ -156,12 +156,12 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
156
156
|
"""
|
|
157
157
|
Construct the DerivativeKeysODE from strings. For each term of the
|
|
158
158
|
loss, specify whether to differentiate wrt the neural network
|
|
159
|
-
parameters, the equation parameters or both. The `Params` object, which
|
|
159
|
+
parameters, the equation parameters or both. The `Params[Array]` object, which
|
|
160
160
|
contains the actual array of parameters must be passed to
|
|
161
161
|
construct the fields with the appropriate PyTree structure.
|
|
162
162
|
|
|
163
163
|
!!! note
|
|
164
|
-
You can mix strings and `Params` if you need granularity.
|
|
164
|
+
You can mix strings and `Params[bool]` if you need granularity.
|
|
165
165
|
|
|
166
166
|
Parameters
|
|
167
167
|
----------
|
|
@@ -498,7 +498,14 @@ def _set_derivatives(
|
|
|
498
498
|
`Params(nn_params=True | False, eq_params={"alpha":True | False,
|
|
499
499
|
"beta":True | False})`.
|
|
500
500
|
"""
|
|
501
|
-
|
|
501
|
+
assert jax.tree.structure(params_.eq_params) == jax.tree.structure(
|
|
502
|
+
derivative_mask.eq_params
|
|
503
|
+
), (
|
|
504
|
+
"The derivative "
|
|
505
|
+
"mask for eq_params does not have the same tree structure as "
|
|
506
|
+
"Params.eq_params. This is often due to a wrong Params[bool] "
|
|
507
|
+
"passed when initializing the derivative key object."
|
|
508
|
+
)
|
|
502
509
|
return Params(
|
|
503
510
|
nn_params=jax.lax.cond(
|
|
504
511
|
derivative_mask.nn_params,
|
jinns/parameters/_params.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Formalize the data structure for the parameters
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from __future__ import annotations
|
|
5
6
|
from dataclasses import fields
|
|
6
7
|
from typing import Generic, TypeVar
|
|
7
8
|
import equinox as eqx
|
|
@@ -60,6 +61,15 @@ class Params(eqx.Module, Generic[T]):
|
|
|
60
61
|
else:
|
|
61
62
|
self.eq_params = eq_params
|
|
62
63
|
|
|
64
|
+
def partition(self, mask: Params[bool] | None):
|
|
65
|
+
"""
|
|
66
|
+
following the boolean mask, partition into two Params
|
|
67
|
+
"""
|
|
68
|
+
if mask is not None:
|
|
69
|
+
return eqx.partition(self, mask)
|
|
70
|
+
else:
|
|
71
|
+
return self, None
|
|
72
|
+
|
|
63
73
|
|
|
64
74
|
def update_eq_params(
|
|
65
75
|
params: Params[Array],
|