jinns 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/loss/_DynamicLossAbstract.py +2 -0
- jinns/loss/_LossODE.py +42 -23
- jinns/loss/_LossPDE.py +58 -48
- jinns/loss/_loss_utils.py +7 -2
- jinns/parameters/_derivative_keys.py +487 -60
- jinns-1.1.0.dist-info/AUTHORS +2 -0
- {jinns-1.0.0.dist-info → jinns-1.1.0.dist-info}/METADATA +2 -1
- {jinns-1.0.0.dist-info → jinns-1.1.0.dist-info}/RECORD +11 -10
- {jinns-1.0.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +0 -0
- {jinns-1.0.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -73,9 +73,11 @@ def _decorator_heteregeneous_params(evaluate, eq_type):
|
|
|
73
73
|
class DynamicLoss(eqx.Module):
|
|
74
74
|
r"""
|
|
75
75
|
Abstract base class for dynamic losses. Implements the physical term:
|
|
76
|
+
|
|
76
77
|
$$
|
|
77
78
|
\mathcal{N}[u](t, x) = 0
|
|
78
79
|
$$
|
|
80
|
+
|
|
79
81
|
for **one** point $t$, $x$ or $(t, x)$, depending on the context.
|
|
80
82
|
|
|
81
83
|
Parameters
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -56,6 +56,9 @@ class _LossODEAbstract(eqx.Module):
|
|
|
56
56
|
slice of u output(s) that is observed. This is useful for
|
|
57
57
|
multidimensional PINN, with partially observed outputs.
|
|
58
58
|
Default is None (whole output is observed).
|
|
59
|
+
params : InitVar[Params], default=None
|
|
60
|
+
The main Params object of the problem needed to instanciate the
|
|
61
|
+
DerivativeKeysODE if the latter is not specified.
|
|
59
62
|
"""
|
|
60
63
|
|
|
61
64
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
@@ -66,13 +69,21 @@ class _LossODEAbstract(eqx.Module):
|
|
|
66
69
|
initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
|
|
67
70
|
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
68
71
|
|
|
69
|
-
|
|
72
|
+
params: InitVar[Params] = eqx.field(default=None, kw_only=True)
|
|
73
|
+
|
|
74
|
+
def __post_init__(self, params=None):
|
|
70
75
|
if self.loss_weights is None:
|
|
71
76
|
self.loss_weights = LossWeightsODE()
|
|
72
77
|
|
|
73
78
|
if self.derivative_keys is None:
|
|
74
|
-
|
|
75
|
-
|
|
79
|
+
try:
|
|
80
|
+
# be default we only take gradient wrt nn_params
|
|
81
|
+
self.derivative_keys = DerivativeKeysODE(params=params)
|
|
82
|
+
except ValueError as exc:
|
|
83
|
+
raise ValueError(
|
|
84
|
+
"Problem at self.derivative_keys initialization "
|
|
85
|
+
f"received {self.derivative_keys=} and {params=}"
|
|
86
|
+
) from exc
|
|
76
87
|
if self.initial_condition is None:
|
|
77
88
|
warnings.warn(
|
|
78
89
|
"Initial condition wasn't provided. Be sure to cover for that"
|
|
@@ -131,6 +142,9 @@ class LossODE(_LossODEAbstract):
|
|
|
131
142
|
slice of u output(s) that is observed. This is useful for
|
|
132
143
|
multidimensional PINN, with partially observed outputs.
|
|
133
144
|
Default is None (whole output is observed).
|
|
145
|
+
params : InitVar[Params], default=None
|
|
146
|
+
The main Params object of the problem needed to instanciate the
|
|
147
|
+
DerivativeKeysODE if the latter is not specified.
|
|
134
148
|
u : eqx.Module
|
|
135
149
|
the PINN
|
|
136
150
|
dynamic_loss : DynamicLoss
|
|
@@ -152,8 +166,10 @@ class LossODE(_LossODEAbstract):
|
|
|
152
166
|
|
|
153
167
|
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
154
168
|
|
|
155
|
-
def __post_init__(self):
|
|
156
|
-
super().__post_init__(
|
|
169
|
+
def __post_init__(self, params=None):
|
|
170
|
+
super().__post_init__(
|
|
171
|
+
params=params
|
|
172
|
+
) # because __init__ or __post_init__ of Base
|
|
157
173
|
# class is not automatically called
|
|
158
174
|
|
|
159
175
|
self.vmap_in_axes = (0,)
|
|
@@ -300,6 +316,9 @@ class SystemLossODE(eqx.Module):
|
|
|
300
316
|
PINNs. Default is None. But if a value is given, all the entries of
|
|
301
317
|
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
302
318
|
if no particular slice is to be given.
|
|
319
|
+
params_dict : InitVar[ParamsDict], default=None
|
|
320
|
+
The main Params object of the problem needed to instanciate the
|
|
321
|
+
DerivativeKeysODE if the latter is not specified.
|
|
303
322
|
|
|
304
323
|
Raises
|
|
305
324
|
------
|
|
@@ -332,14 +351,16 @@ class SystemLossODE(eqx.Module):
|
|
|
332
351
|
loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
|
|
333
352
|
kw_only=True, default=None
|
|
334
353
|
)
|
|
354
|
+
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
355
|
+
|
|
335
356
|
u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
|
|
336
|
-
|
|
357
|
+
derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
|
|
337
358
|
|
|
338
359
|
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
339
360
|
# internally the loss weights are handled with a dictionary
|
|
340
361
|
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
341
362
|
|
|
342
|
-
def __post_init__(self, loss_weights):
|
|
363
|
+
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
343
364
|
# a dictionary that will be useful at different places
|
|
344
365
|
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
345
366
|
if self.initial_condition_dict is None:
|
|
@@ -369,14 +390,14 @@ class SystemLossODE(eqx.Module):
|
|
|
369
390
|
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
370
391
|
# derivative_keys_dict
|
|
371
392
|
|
|
372
|
-
#
|
|
373
|
-
#
|
|
374
|
-
#
|
|
375
|
-
|
|
376
|
-
# default values
|
|
377
|
-
for k in self.dynamic_loss_dict.keys():
|
|
393
|
+
# derivative keys for the u_constraints. Note that we create missing
|
|
394
|
+
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
395
|
+
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
396
|
+
for k in self.u_dict.keys():
|
|
378
397
|
if self.derivative_keys_dict[k] is None:
|
|
379
|
-
self.derivative_keys_dict[k] = DerivativeKeysODE(
|
|
398
|
+
self.derivative_keys_dict[k] = DerivativeKeysODE(
|
|
399
|
+
params=params_dict.extract_params(k)
|
|
400
|
+
)
|
|
380
401
|
|
|
381
402
|
self._loss_weights = self.set_loss_weights(loss_weights)
|
|
382
403
|
|
|
@@ -397,12 +418,11 @@ class SystemLossODE(eqx.Module):
|
|
|
397
418
|
obs_slice=self.obs_slice_dict[i],
|
|
398
419
|
)
|
|
399
420
|
|
|
400
|
-
#
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
}
|
|
421
|
+
# derivative keys for the dynamic loss. Note that we create a
|
|
422
|
+
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
423
|
+
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
424
|
+
# happen inside it)
|
|
425
|
+
self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
|
|
406
426
|
|
|
407
427
|
def set_loss_weights(self, loss_weights_init):
|
|
408
428
|
"""
|
|
@@ -497,13 +517,13 @@ class SystemLossODE(eqx.Module):
|
|
|
497
517
|
batch.param_batch_dict, params_dict
|
|
498
518
|
)
|
|
499
519
|
|
|
500
|
-
def dyn_loss_for_one_key(dyn_loss,
|
|
520
|
+
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
501
521
|
"""This function is used in tree_map"""
|
|
502
522
|
return dynamic_loss_apply(
|
|
503
523
|
dyn_loss.evaluate,
|
|
504
524
|
self.u_dict,
|
|
505
525
|
(temporal_batch,),
|
|
506
|
-
_set_derivatives(params_dict,
|
|
526
|
+
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
507
527
|
vmap_in_axes_t + vmap_in_axes_params,
|
|
508
528
|
loss_weight,
|
|
509
529
|
u_type=PINN,
|
|
@@ -512,7 +532,6 @@ class SystemLossODE(eqx.Module):
|
|
|
512
532
|
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
513
533
|
dyn_loss_for_one_key,
|
|
514
534
|
self.dynamic_loss_dict,
|
|
515
|
-
self.derivative_keys_dyn_loss_dict,
|
|
516
535
|
self._loss_weights["dyn_loss"],
|
|
517
536
|
is_leaf=lambda x: isinstance(x, ODE), # before when dynamic losses
|
|
518
537
|
# where plain (unregister pytree) node classes, we could not traverse
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -103,6 +103,9 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
103
103
|
obs_slice : slice, default=None
|
|
104
104
|
slice object specifying the begininning/ending of the PINN output
|
|
105
105
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
106
|
+
params : InitVar[Params], default=None
|
|
107
|
+
The main Params object of the problem needed to instanciate the
|
|
108
|
+
DerivativeKeysODE if the latter is not specified.
|
|
106
109
|
"""
|
|
107
110
|
|
|
108
111
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
@@ -129,18 +132,26 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
129
132
|
norm_int_length: float | None = eqx.field(kw_only=True, default=None)
|
|
130
133
|
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
131
134
|
|
|
132
|
-
|
|
135
|
+
params: InitVar[Params] = eqx.field(kw_only=True, default=None)
|
|
136
|
+
|
|
137
|
+
def __post_init__(self, params=None):
|
|
133
138
|
"""
|
|
134
139
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
135
140
|
Module with eqx.tree_at
|
|
136
141
|
"""
|
|
137
142
|
if self.derivative_keys is None:
|
|
138
143
|
# be default we only take gradient wrt nn_params
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
+
try:
|
|
145
|
+
self.derivative_keys = (
|
|
146
|
+
DerivativeKeysPDENonStatio(params=params)
|
|
147
|
+
if isinstance(self, LossPDENonStatio)
|
|
148
|
+
else DerivativeKeysPDEStatio(params=params)
|
|
149
|
+
)
|
|
150
|
+
except ValueError as exc:
|
|
151
|
+
raise ValueError(
|
|
152
|
+
"Problem at self.derivative_keys initialization "
|
|
153
|
+
f"received {self.derivative_keys=} and {params=}"
|
|
154
|
+
) from exc
|
|
144
155
|
|
|
145
156
|
if self.loss_weights is None:
|
|
146
157
|
self.loss_weights = (
|
|
@@ -324,6 +335,9 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
324
335
|
obs_slice : slice, default=None
|
|
325
336
|
slice object specifying the begininning/ending of the PINN output
|
|
326
337
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
338
|
+
params : InitVar[Params], default=None
|
|
339
|
+
The main Params object of the problem needed to instanciate the
|
|
340
|
+
DerivativeKeysODE if the latter is not specified.
|
|
327
341
|
|
|
328
342
|
|
|
329
343
|
Raises
|
|
@@ -342,12 +356,14 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
342
356
|
|
|
343
357
|
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
344
358
|
|
|
345
|
-
def __post_init__(self):
|
|
359
|
+
def __post_init__(self, params=None):
|
|
346
360
|
"""
|
|
347
361
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
348
362
|
Module with eqx.tree_at!
|
|
349
363
|
"""
|
|
350
|
-
super().__post_init__(
|
|
364
|
+
super().__post_init__(
|
|
365
|
+
params=params
|
|
366
|
+
) # because __init__ or __post_init__ of Base
|
|
351
367
|
# class is not automatically called
|
|
352
368
|
|
|
353
369
|
self.vmap_in_axes = (0,) # for x only here
|
|
@@ -547,6 +563,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
547
563
|
initial_condition_fun : Callable, default=None
|
|
548
564
|
A function representing the temporal initial condition. If None
|
|
549
565
|
(default) then no initial condition is applied
|
|
566
|
+
params : InitVar[Params], default=None
|
|
567
|
+
The main Params object of the problem needed to instanciate the
|
|
568
|
+
DerivativeKeysODE if the latter is not specified.
|
|
550
569
|
|
|
551
570
|
"""
|
|
552
571
|
|
|
@@ -556,12 +575,14 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
556
575
|
kw_only=True, default=None, static=True
|
|
557
576
|
)
|
|
558
577
|
|
|
559
|
-
def __post_init__(self):
|
|
578
|
+
def __post_init__(self, params=None):
|
|
560
579
|
"""
|
|
561
580
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
562
581
|
Module with eqx.tree_at!
|
|
563
582
|
"""
|
|
564
|
-
super().__post_init__(
|
|
583
|
+
super().__post_init__(
|
|
584
|
+
params=params
|
|
585
|
+
) # because __init__ or __post_init__ of Base
|
|
565
586
|
# class is not automatically called
|
|
566
587
|
|
|
567
588
|
self.vmap_in_axes = (0, 0) # for t and x
|
|
@@ -616,7 +637,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
616
637
|
of parameters (eg. for metamodeling) and an optional additional batch of observed
|
|
617
638
|
inputs/outputs/parameters
|
|
618
639
|
"""
|
|
619
|
-
|
|
620
640
|
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
621
641
|
|
|
622
642
|
# Retrieve the optional eq_params_batch
|
|
@@ -725,6 +745,9 @@ class SystemLossPDE(eqx.Module):
|
|
|
725
745
|
PINNs. Default is None. But if a value is given, all the entries of
|
|
726
746
|
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
727
747
|
if no particular slice is to be given
|
|
748
|
+
params : InitVar[ParamsDict], default=None
|
|
749
|
+
The main Params object of the problem needed to instanciate the
|
|
750
|
+
DerivativeKeysODE if the latter is not specified.
|
|
728
751
|
|
|
729
752
|
"""
|
|
730
753
|
|
|
@@ -763,22 +786,20 @@ class SystemLossPDE(eqx.Module):
|
|
|
763
786
|
loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
|
|
764
787
|
kw_only=True, default=None
|
|
765
788
|
)
|
|
789
|
+
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
766
790
|
|
|
767
791
|
# following have init=False and are set in the __post_init__
|
|
768
792
|
u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
|
|
769
793
|
init=False
|
|
770
794
|
)
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
derivative_keys_dyn_loss_dict: Dict[
|
|
775
|
-
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
|
|
776
|
-
] = eqx.field(init=False)
|
|
795
|
+
derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
|
|
796
|
+
eqx.field(init=False)
|
|
797
|
+
)
|
|
777
798
|
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
778
799
|
# internally the loss weights are handled with a dictionary
|
|
779
800
|
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
780
801
|
|
|
781
|
-
def __post_init__(self, loss_weights):
|
|
802
|
+
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
782
803
|
# a dictionary that will be useful at different places
|
|
783
804
|
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
784
805
|
# First, for all the optional dict,
|
|
@@ -818,24 +839,19 @@ class SystemLossPDE(eqx.Module):
|
|
|
818
839
|
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
819
840
|
# derivative_keys_dict
|
|
820
841
|
|
|
821
|
-
#
|
|
822
|
-
#
|
|
823
|
-
#
|
|
824
|
-
|
|
825
|
-
# default values
|
|
826
|
-
for k in self.dynamic_loss_dict.keys():
|
|
842
|
+
# derivative keys for the u_constraints. Note that we create missing
|
|
843
|
+
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
844
|
+
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
845
|
+
for k in self.u_dict.keys():
|
|
827
846
|
if self.derivative_keys_dict[k] is None:
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
|
|
837
|
-
else:
|
|
838
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
|
|
847
|
+
if self.u_dict[k].eq_type == "statio_PDE":
|
|
848
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
|
|
849
|
+
params=params_dict.extract_params(k)
|
|
850
|
+
)
|
|
851
|
+
else:
|
|
852
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
|
|
853
|
+
params=params_dict.extract_params(k)
|
|
854
|
+
)
|
|
839
855
|
|
|
840
856
|
# Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
|
|
841
857
|
if (
|
|
@@ -904,16 +920,11 @@ class SystemLossPDE(eqx.Module):
|
|
|
904
920
|
f"got {self.u_dict[i].eq_type[i]}"
|
|
905
921
|
)
|
|
906
922
|
|
|
907
|
-
#
|
|
908
|
-
#
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
}
|
|
913
|
-
self.derivative_keys_u_dict = {
|
|
914
|
-
k: self.derivative_keys_dict[k]
|
|
915
|
-
for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
|
|
916
|
-
}
|
|
923
|
+
# derivative keys for the dynamic loss. Note that we create a
|
|
924
|
+
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
925
|
+
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
926
|
+
# happen inside it)
|
|
927
|
+
self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
|
|
917
928
|
|
|
918
929
|
# also make sure we only have PINNs or SPINNs
|
|
919
930
|
if not (
|
|
@@ -1034,13 +1045,13 @@ class SystemLossPDE(eqx.Module):
|
|
|
1034
1045
|
batch.param_batch_dict, params_dict
|
|
1035
1046
|
)
|
|
1036
1047
|
|
|
1037
|
-
def dyn_loss_for_one_key(dyn_loss,
|
|
1048
|
+
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
1038
1049
|
"""The function used in tree_map"""
|
|
1039
1050
|
return dynamic_loss_apply(
|
|
1040
1051
|
dyn_loss.evaluate,
|
|
1041
1052
|
self.u_dict,
|
|
1042
1053
|
batches,
|
|
1043
|
-
_set_derivatives(params_dict,
|
|
1054
|
+
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
1044
1055
|
vmap_in_axes_x_or_x_t + vmap_in_axes_params,
|
|
1045
1056
|
loss_weight,
|
|
1046
1057
|
u_type=type(list(self.u_dict.values())[0]),
|
|
@@ -1049,7 +1060,6 @@ class SystemLossPDE(eqx.Module):
|
|
|
1049
1060
|
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
1050
1061
|
dyn_loss_for_one_key,
|
|
1051
1062
|
self.dynamic_loss_dict,
|
|
1052
|
-
self.derivative_keys_dyn_loss_dict,
|
|
1053
1063
|
self._loss_weights["dyn_loss"],
|
|
1054
1064
|
is_leaf=lambda x: isinstance(
|
|
1055
1065
|
x, (PDEStatio, PDENonStatio)
|
jinns/loss/_loss_utils.py
CHANGED
|
@@ -297,12 +297,12 @@ def constraints_system_loss_apply(
|
|
|
297
297
|
if isinstance(params_dict.nn_params, dict):
|
|
298
298
|
|
|
299
299
|
def apply_u_constraint(
|
|
300
|
-
u_constraint, nn_params, loss_weights_for_u, obs_batch_u
|
|
300
|
+
u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
|
|
301
301
|
):
|
|
302
302
|
res_dict_for_u = u_constraint.evaluate(
|
|
303
303
|
Params(
|
|
304
304
|
nn_params=nn_params,
|
|
305
|
-
eq_params=
|
|
305
|
+
eq_params=eq_params,
|
|
306
306
|
),
|
|
307
307
|
append_obs_batch(batch, obs_batch_u),
|
|
308
308
|
)[1]
|
|
@@ -319,6 +319,11 @@ def constraints_system_loss_apply(
|
|
|
319
319
|
apply_u_constraint,
|
|
320
320
|
u_constraints_dict,
|
|
321
321
|
params_dict.nn_params,
|
|
322
|
+
(
|
|
323
|
+
params_dict.eq_params
|
|
324
|
+
if params_dict.eq_params.keys() == params_dict.nn_params.keys()
|
|
325
|
+
else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
|
|
326
|
+
), # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
|
|
322
327
|
loss_weights_T,
|
|
323
328
|
batch.obs_batch_dict,
|
|
324
329
|
is_leaf=lambda x: (
|
|
@@ -2,50 +2,468 @@
|
|
|
2
2
|
Formalize the data structure for the derivative keys
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from functools import partial
|
|
6
|
+
from dataclasses import fields, InitVar
|
|
6
7
|
from typing import Literal
|
|
7
8
|
import jax
|
|
8
9
|
import equinox as eqx
|
|
9
10
|
|
|
10
|
-
from jinns.parameters._params import Params
|
|
11
|
+
from jinns.parameters._params import Params, ParamsDict
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_masked_parameters(
|
|
15
|
+
derivative_mask_str: str, params: Params | ParamsDict
|
|
16
|
+
) -> Params | ParamsDict:
|
|
17
|
+
"""
|
|
18
|
+
Creates the Params object with True values where we want to differentiate
|
|
19
|
+
"""
|
|
20
|
+
if isinstance(params, Params):
|
|
21
|
+
# start with a params object with True everywhere. We will update to False
|
|
22
|
+
# for parameters wrt which we do want not to differentiate the loss
|
|
23
|
+
diff_params = jax.tree.map(
|
|
24
|
+
lambda x: True,
|
|
25
|
+
params,
|
|
26
|
+
is_leaf=lambda x: isinstance(x, eqx.Module)
|
|
27
|
+
and not isinstance(x, Params), # do not travers nn_params, more
|
|
28
|
+
# granularity could be imagined here, in the future
|
|
29
|
+
)
|
|
30
|
+
if derivative_mask_str == "both":
|
|
31
|
+
return diff_params
|
|
32
|
+
if derivative_mask_str == "eq_params":
|
|
33
|
+
return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
|
|
34
|
+
if derivative_mask_str == "nn_params":
|
|
35
|
+
return eqx.tree_at(
|
|
36
|
+
lambda p: p.eq_params,
|
|
37
|
+
diff_params,
|
|
38
|
+
jax.tree.map(lambda x: False, params.eq_params),
|
|
39
|
+
)
|
|
40
|
+
raise ValueError(
|
|
41
|
+
"Bad value for DerivativeKeys. Got "
|
|
42
|
+
f'{derivative_mask_str}, expected "both", "nn_params" or '
|
|
43
|
+
' "eq_params"'
|
|
44
|
+
)
|
|
45
|
+
elif isinstance(params, ParamsDict):
|
|
46
|
+
# do not travers nn_params, more
|
|
47
|
+
# granularity could be imagined here, in the future
|
|
48
|
+
diff_params = ParamsDict(
|
|
49
|
+
nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
|
|
50
|
+
)
|
|
51
|
+
if derivative_mask_str == "both":
|
|
52
|
+
return diff_params
|
|
53
|
+
if derivative_mask_str == "eq_params":
|
|
54
|
+
return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
|
|
55
|
+
if derivative_mask_str == "nn_params":
|
|
56
|
+
return eqx.tree_at(
|
|
57
|
+
lambda p: p.eq_params,
|
|
58
|
+
diff_params,
|
|
59
|
+
jax.tree.map(lambda x: False, params.eq_params),
|
|
60
|
+
)
|
|
61
|
+
raise ValueError(
|
|
62
|
+
"Bad value for DerivativeKeys. Got "
|
|
63
|
+
f'{derivative_mask_str}, expected "both", "nn_params" or '
|
|
64
|
+
' "eq_params"'
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
else:
|
|
68
|
+
raise ValueError(
|
|
69
|
+
f"Bad value for params. Got {type(params)}, expected Params "
|
|
70
|
+
" or ParamsDict"
|
|
71
|
+
)
|
|
11
72
|
|
|
12
73
|
|
|
13
74
|
class DerivativeKeysODE(eqx.Module):
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
75
|
+
"""
|
|
76
|
+
A class that specifies with repect to which parameter(s) each term of the
|
|
77
|
+
loss is differentiated. For example, you can specify that the
|
|
78
|
+
[`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
|
|
79
|
+
respect to the neural network parameters *and* the equation parameters, or only some of them.
|
|
80
|
+
|
|
81
|
+
To do so, user can either use strings or a `Params` object
|
|
82
|
+
with PyTree structure matching the parameters of the problem at
|
|
83
|
+
hand, and booleans indicating if gradient is to be taken or not. Internally,
|
|
84
|
+
a `jax.lax.stop_gradient()` is appropriately set to each `True` node when
|
|
85
|
+
computing each loss term.
|
|
86
|
+
|
|
87
|
+
!!! note
|
|
88
|
+
|
|
89
|
+
1. For unspecified loss term, the default is to differentiate with
|
|
90
|
+
respect to `"nn_params"` only.
|
|
91
|
+
2. No granularity inside `Params.nn_params` is currently supported.
|
|
92
|
+
3. Note that the main Params or ParamsDict object of the problem is mandatory if initialization via `from_str()`.
|
|
93
|
+
|
|
94
|
+
A typical specification is of the form:
|
|
95
|
+
```python
|
|
96
|
+
Params(
|
|
97
|
+
nn_params=True | False,
|
|
98
|
+
eq_params={
|
|
99
|
+
"alpha":True | False,
|
|
100
|
+
"beta":True | False,
|
|
101
|
+
...
|
|
102
|
+
}
|
|
21
103
|
)
|
|
22
|
-
|
|
23
|
-
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
Parameters
|
|
107
|
+
----------
|
|
108
|
+
dyn_loss : Params | ParamsDict | None, default=None
|
|
109
|
+
Tell wrt which node of `Params` we will differentiate the
|
|
110
|
+
dynamic loss. To do so, the fields of `Params` contain True (if
|
|
111
|
+
differentiation) or False (if no differentiation).
|
|
112
|
+
observations : Params | ParamsDict | None, default=None
|
|
113
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
114
|
+
observation loss. To do so, the fields of Params contain True (if
|
|
115
|
+
differentiation) or False (if no differentiation).
|
|
116
|
+
initial_condition : Params | ParamsDict | None, default=None
|
|
117
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
118
|
+
initial condition loss. To do so, the fields of Params contain True (if
|
|
119
|
+
differentiation) or False (if no differentiation).
|
|
120
|
+
params : InitVar[Params | ParamsDict], default=None
|
|
121
|
+
The main Params object of the problem. It is required
|
|
122
|
+
if some terms are unspecified (None). This is because, jinns cannot
|
|
123
|
+
infer the content of `Params.eq_params`.
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
|
|
127
|
+
observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
|
|
128
|
+
initial_condition: Params | ParamsDict | None = eqx.field(
|
|
129
|
+
kw_only=True, default=None
|
|
24
130
|
)
|
|
25
131
|
|
|
132
|
+
params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
133
|
+
|
|
134
|
+
def __post_init__(self, params=None):
|
|
135
|
+
if self.dyn_loss is None:
|
|
136
|
+
try:
|
|
137
|
+
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
138
|
+
except AttributeError:
|
|
139
|
+
raise ValueError(
|
|
140
|
+
"self.dyn_loss is None, hence params should be " "passed"
|
|
141
|
+
)
|
|
142
|
+
if self.observations is None:
|
|
143
|
+
try:
|
|
144
|
+
self.observations = _get_masked_parameters("nn_params", params)
|
|
145
|
+
except AttributeError:
|
|
146
|
+
raise ValueError(
|
|
147
|
+
"self.observations is None, hence params should be " "passed"
|
|
148
|
+
)
|
|
149
|
+
if self.initial_condition is None:
|
|
150
|
+
try:
|
|
151
|
+
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
152
|
+
except AttributeError:
|
|
153
|
+
raise ValueError(
|
|
154
|
+
"self.initial_condition is None, hence params should be " "passed"
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
@classmethod
|
|
158
|
+
def from_str(
|
|
159
|
+
cls,
|
|
160
|
+
params: Params | ParamsDict,
|
|
161
|
+
dyn_loss: (
|
|
162
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
163
|
+
) = "nn_params",
|
|
164
|
+
observations: (
|
|
165
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
166
|
+
) = "nn_params",
|
|
167
|
+
initial_condition: (
|
|
168
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
169
|
+
) = "nn_params",
|
|
170
|
+
):
|
|
171
|
+
"""
|
|
172
|
+
Construct the DerivativeKeysODE from strings. For each term of the
|
|
173
|
+
loss, specify whether to differentiate wrt the neural network
|
|
174
|
+
parameters, the equation parameters or both. The `Params` object, which
|
|
175
|
+
contains the actual array of parameters must be passed to
|
|
176
|
+
construct the fields with the appropriate PyTree structure.
|
|
177
|
+
|
|
178
|
+
!!! note
|
|
179
|
+
You can mix strings and `Params` if you need granularity.
|
|
180
|
+
|
|
181
|
+
Parameters
|
|
182
|
+
----------
|
|
183
|
+
params
|
|
184
|
+
The actual Params or ParamsDict object of the problem.
|
|
185
|
+
dyn_loss
|
|
186
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
187
|
+
`"both"` we will differentiate the dynamic loss. Default is
|
|
188
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
189
|
+
observations
|
|
190
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
191
|
+
`"both"` we will differentiate the observations. Default is
|
|
192
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
193
|
+
initial_condition
|
|
194
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
195
|
+
`"both"` we will differentiate the initial condition. Default is
|
|
196
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
197
|
+
"""
|
|
198
|
+
return DerivativeKeysODE(
|
|
199
|
+
dyn_loss=(
|
|
200
|
+
_get_masked_parameters(dyn_loss, params)
|
|
201
|
+
if isinstance(dyn_loss, str)
|
|
202
|
+
else dyn_loss
|
|
203
|
+
),
|
|
204
|
+
observations=(
|
|
205
|
+
_get_masked_parameters(observations, params)
|
|
206
|
+
if isinstance(observations, str)
|
|
207
|
+
else observations
|
|
208
|
+
),
|
|
209
|
+
initial_condition=(
|
|
210
|
+
_get_masked_parameters(initial_condition, params)
|
|
211
|
+
if isinstance(initial_condition, str)
|
|
212
|
+
else initial_condition
|
|
213
|
+
),
|
|
214
|
+
)
|
|
215
|
+
|
|
26
216
|
|
|
27
217
|
class DerivativeKeysPDEStatio(eqx.Module):
|
|
218
|
+
"""
|
|
219
|
+
See [jinns.parameters.DerivativeKeysODE][].
|
|
28
220
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
221
|
+
Parameters
|
|
222
|
+
----------
|
|
223
|
+
dyn_loss : Params | ParamsDict | None, default=None
|
|
224
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
225
|
+
dynamic loss. To do so, the fields of Params contain True (if
|
|
226
|
+
differentiation) or False (if no differentiation).
|
|
227
|
+
observations : Params | ParamsDict | None, default=None
|
|
228
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
229
|
+
observation loss. To do so, the fields of Params contain True (if
|
|
230
|
+
differentiation) or False (if no differentiation).
|
|
231
|
+
boundary_loss : Params | ParamsDict | None, default=None
|
|
232
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
233
|
+
boundary loss. To do so, the fields of Params contain True (if
|
|
234
|
+
differentiation) or False (if no differentiation).
|
|
235
|
+
norm_loss : Params | ParamsDict | None, default=None
|
|
236
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
237
|
+
normalization loss. To do so, the fields of Params contain True (if
|
|
238
|
+
differentiation) or False (if no differentiation).
|
|
239
|
+
params : InitVar[Params | ParamsDict], default=None
|
|
240
|
+
The main Params object of the problem. It is required
|
|
241
|
+
if some terms are unspecified (None). This is because, jinns cannot infer the
|
|
242
|
+
content of `Params.eq_params`.
|
|
243
|
+
"""
|
|
244
|
+
|
|
245
|
+
dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
|
|
246
|
+
observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
|
|
247
|
+
boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
|
|
248
|
+
norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
|
|
249
|
+
|
|
250
|
+
params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
251
|
+
|
|
252
|
+
def __post_init__(self, params=None):
|
|
253
|
+
if self.dyn_loss is None:
|
|
254
|
+
try:
|
|
255
|
+
self.dyn_loss = _get_masked_parameters("nn_params", params)
|
|
256
|
+
except AttributeError:
|
|
257
|
+
raise ValueError("self.dyn_loss is None, hence params should be passed")
|
|
258
|
+
if self.observations is None:
|
|
259
|
+
try:
|
|
260
|
+
self.observations = _get_masked_parameters("nn_params", params)
|
|
261
|
+
except AttributeError:
|
|
262
|
+
raise ValueError(
|
|
263
|
+
"self.observations is None, hence params should be passed"
|
|
264
|
+
)
|
|
265
|
+
if self.boundary_loss is None:
|
|
266
|
+
try:
|
|
267
|
+
self.boundary_loss = _get_masked_parameters("nn_params", params)
|
|
268
|
+
except AttributeError:
|
|
269
|
+
raise ValueError(
|
|
270
|
+
"self.boundary_loss is None, hence params should be passed"
|
|
271
|
+
)
|
|
272
|
+
if self.norm_loss is None:
|
|
273
|
+
try:
|
|
274
|
+
self.norm_loss = _get_masked_parameters("nn_params", params)
|
|
275
|
+
except AttributeError:
|
|
276
|
+
raise ValueError(
|
|
277
|
+
"self.norm_loss is None, hence params should be passed"
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
@classmethod
|
|
281
|
+
def from_str(
|
|
282
|
+
cls,
|
|
283
|
+
params: Params | ParamsDict,
|
|
284
|
+
dyn_loss: (
|
|
285
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
286
|
+
) = "nn_params",
|
|
287
|
+
observations: (
|
|
288
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
289
|
+
) = "nn_params",
|
|
290
|
+
boundary_loss: (
|
|
291
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
292
|
+
) = "nn_params",
|
|
293
|
+
norm_loss: (
|
|
294
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
295
|
+
) = "nn_params",
|
|
296
|
+
):
|
|
297
|
+
"""
|
|
298
|
+
See [jinns.parameters.DerivativeKeysODE.from_str][].
|
|
299
|
+
|
|
300
|
+
Parameters
|
|
301
|
+
----------
|
|
302
|
+
params
|
|
303
|
+
The actual Param or ParamsDict object of the problem.
|
|
304
|
+
dyn_loss
|
|
305
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
306
|
+
`"both"` we will differentiate the dynamic loss. Default is
|
|
307
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
308
|
+
observations
|
|
309
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
310
|
+
`"both"` we will differentiate the observations. Default is
|
|
311
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
312
|
+
boundary_loss
|
|
313
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
314
|
+
`"both"` we will differentiate the boundary loss. Default is
|
|
315
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
316
|
+
norm_loss
|
|
317
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
318
|
+
`"both"` we will differentiate the normalization loss. Default is
|
|
319
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
320
|
+
"""
|
|
321
|
+
return DerivativeKeysPDEStatio(
|
|
322
|
+
dyn_loss=(
|
|
323
|
+
_get_masked_parameters(dyn_loss, params)
|
|
324
|
+
if isinstance(dyn_loss, str)
|
|
325
|
+
else dyn_loss
|
|
326
|
+
),
|
|
327
|
+
observations=(
|
|
328
|
+
_get_masked_parameters(observations, params)
|
|
329
|
+
if isinstance(observations, str)
|
|
330
|
+
else observations
|
|
331
|
+
),
|
|
332
|
+
boundary_loss=(
|
|
333
|
+
_get_masked_parameters(boundary_loss, params)
|
|
334
|
+
if isinstance(boundary_loss, str)
|
|
335
|
+
else boundary_loss
|
|
336
|
+
),
|
|
337
|
+
norm_loss=(
|
|
338
|
+
_get_masked_parameters(norm_loss, params)
|
|
339
|
+
if isinstance(norm_loss, str)
|
|
340
|
+
else norm_loss
|
|
341
|
+
),
|
|
342
|
+
)
|
|
41
343
|
|
|
42
344
|
|
|
43
345
|
class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
|
|
346
|
+
"""
|
|
347
|
+
See [jinns.parameters.DerivativeKeysODE][].
|
|
44
348
|
|
|
45
|
-
|
|
46
|
-
|
|
349
|
+
Parameters
|
|
350
|
+
----------
|
|
351
|
+
dyn_loss : Params | ParamsDict | None, default=None
|
|
352
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
353
|
+
dynamic loss. To do so, the fields of Params contain True (if
|
|
354
|
+
differentiation) or False (if no differentiation).
|
|
355
|
+
observations : Params | ParamsDict | None, default=None
|
|
356
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
357
|
+
observation loss. To do so, the fields of Params contain True (if
|
|
358
|
+
differentiation) or False (if no differentiation).
|
|
359
|
+
boundary_loss : Params | ParamsDict | None, default=None
|
|
360
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
361
|
+
boundary loss. To do so, the fields of Params contain True (if
|
|
362
|
+
differentiation) or False (if no differentiation).
|
|
363
|
+
norm_loss : Params | ParamsDict | None, default=None
|
|
364
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
365
|
+
normalization loss. To do so, the fields of Params contain True (if
|
|
366
|
+
differentiation) or False (if no differentiation).
|
|
367
|
+
initial_condition : Params | ParamsDict | None, default=None
|
|
368
|
+
Tell wrt which parameters among Params we will differentiate the
|
|
369
|
+
initial_condition loss. To do so, the fields of Params contain True (if
|
|
370
|
+
differentiation) or False (if no differentiation).
|
|
371
|
+
params : InitVar[Params | ParamsDict], default=None
|
|
372
|
+
The main Params object of the problem. It is required
|
|
373
|
+
if some terms are unspecified (None). This is because, jinns cannot infer the
|
|
374
|
+
content of `Params.eq_params`.
|
|
375
|
+
"""
|
|
376
|
+
|
|
377
|
+
initial_condition: Params | ParamsDict | None = eqx.field(
|
|
378
|
+
kw_only=True, default=None
|
|
47
379
|
)
|
|
48
380
|
|
|
381
|
+
def __post_init__(self, params=None):
|
|
382
|
+
super().__post_init__(params=params)
|
|
383
|
+
if self.initial_condition is None:
|
|
384
|
+
try:
|
|
385
|
+
self.initial_condition = _get_masked_parameters("nn_params", params)
|
|
386
|
+
except AttributeError:
|
|
387
|
+
raise ValueError(
|
|
388
|
+
"self.initial_condition is None, hence params should be passed"
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
@classmethod
|
|
392
|
+
def from_str(
|
|
393
|
+
cls,
|
|
394
|
+
params: Params | ParamsDict,
|
|
395
|
+
dyn_loss: (
|
|
396
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
397
|
+
) = "nn_params",
|
|
398
|
+
observations: (
|
|
399
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
400
|
+
) = "nn_params",
|
|
401
|
+
boundary_loss: (
|
|
402
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
403
|
+
) = "nn_params",
|
|
404
|
+
norm_loss: (
|
|
405
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
406
|
+
) = "nn_params",
|
|
407
|
+
initial_condition: (
|
|
408
|
+
Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
|
|
409
|
+
) = "nn_params",
|
|
410
|
+
):
|
|
411
|
+
"""
|
|
412
|
+
See [jinns.parameters.DerivativeKeysODE.from_str][].
|
|
413
|
+
|
|
414
|
+
Parameters
|
|
415
|
+
----------
|
|
416
|
+
params
|
|
417
|
+
The actual Params | ParamsDict object of the problem.
|
|
418
|
+
dyn_loss
|
|
419
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
420
|
+
`"both"` we will differentiate the dynamic loss. Default is
|
|
421
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
422
|
+
observations
|
|
423
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
424
|
+
`"both"` we will differentiate the observations. Default is
|
|
425
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
426
|
+
boundary_loss
|
|
427
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
428
|
+
`"both"` we will differentiate the boundary loss. Default is
|
|
429
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
430
|
+
norm_loss
|
|
431
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
432
|
+
`"both"` we will differentiate the normalization loss. Default is
|
|
433
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
434
|
+
initial_condition
|
|
435
|
+
Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
|
|
436
|
+
`"both"` we will differentiate the initial_condition loss. Default is
|
|
437
|
+
`"nn_params"`. Specifying a Params or ParamsDict is also possible.
|
|
438
|
+
"""
|
|
439
|
+
return DerivativeKeysPDENonStatio(
|
|
440
|
+
dyn_loss=(
|
|
441
|
+
_get_masked_parameters(dyn_loss, params)
|
|
442
|
+
if isinstance(dyn_loss, str)
|
|
443
|
+
else dyn_loss
|
|
444
|
+
),
|
|
445
|
+
observations=(
|
|
446
|
+
_get_masked_parameters(observations, params)
|
|
447
|
+
if isinstance(observations, str)
|
|
448
|
+
else observations
|
|
449
|
+
),
|
|
450
|
+
boundary_loss=(
|
|
451
|
+
_get_masked_parameters(boundary_loss, params)
|
|
452
|
+
if isinstance(boundary_loss, str)
|
|
453
|
+
else boundary_loss
|
|
454
|
+
),
|
|
455
|
+
norm_loss=(
|
|
456
|
+
_get_masked_parameters(norm_loss, params)
|
|
457
|
+
if isinstance(norm_loss, str)
|
|
458
|
+
else norm_loss
|
|
459
|
+
),
|
|
460
|
+
initial_condition=(
|
|
461
|
+
_get_masked_parameters(initial_condition, params)
|
|
462
|
+
if isinstance(initial_condition, str)
|
|
463
|
+
else initial_condition
|
|
464
|
+
),
|
|
465
|
+
)
|
|
466
|
+
|
|
49
467
|
|
|
50
468
|
def _set_derivatives(params, derivative_keys):
|
|
51
469
|
"""
|
|
@@ -53,42 +471,51 @@ def _set_derivatives(params, derivative_keys):
|
|
|
53
471
|
has a copy of the params with appropriate derivatives set
|
|
54
472
|
"""
|
|
55
473
|
|
|
56
|
-
def
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
474
|
+
def _set_derivatives_ParamsDict(params_, derivative_mask):
|
|
475
|
+
"""
|
|
476
|
+
The next lines put a stop_gradient around the fields that do not
|
|
477
|
+
differentiate the loss term
|
|
478
|
+
**Note:** **No granularity inside `ParamsDict.nn_params` is currently
|
|
479
|
+
supported.**
|
|
480
|
+
This means a typical Params specification is of the form:
|
|
481
|
+
`ParamsDict(nn_params=True | False, eq_params={"0":{"alpha":True | False,
|
|
482
|
+
"beta":True | False}}, "1":{"alpha":True | False, "beta":True | False}})`.
|
|
483
|
+
"""
|
|
484
|
+
# a ParamsDict object is reconstructed by hand since we do not want to
|
|
485
|
+
# traverse nn_params, for now...
|
|
486
|
+
return ParamsDict(
|
|
487
|
+
nn_params=jax.lax.cond(
|
|
488
|
+
derivative_mask.nn_params,
|
|
489
|
+
lambda p: p,
|
|
490
|
+
jax.lax.stop_gradient,
|
|
491
|
+
params_.nn_params,
|
|
492
|
+
),
|
|
493
|
+
eq_params=jax.tree.map(
|
|
494
|
+
lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
|
|
495
|
+
params_.eq_params,
|
|
496
|
+
derivative_mask.eq_params,
|
|
67
497
|
),
|
|
68
|
-
params,
|
|
69
|
-
replace_fn=jax.lax.stop_gradient,
|
|
70
498
|
)
|
|
71
499
|
|
|
72
|
-
def
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
|
|
500
|
+
def _set_derivatives_(params_, derivative_mask):
|
|
501
|
+
"""
|
|
502
|
+
The next lines put a stop_gradient around the fields that do not
|
|
503
|
+
differentiate the loss term
|
|
504
|
+
**Note:** **No granularity inside `Params.nn_params` is currently
|
|
505
|
+
supported.**
|
|
506
|
+
This means a typical Params specification is of the form:
|
|
507
|
+
`Params(nn_params=True | False, eq_params={"alpha":True | False,
|
|
508
|
+
"beta":True | False})`.
|
|
509
|
+
"""
|
|
510
|
+
return jax.tree.map(
|
|
511
|
+
lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
|
|
512
|
+
params_,
|
|
513
|
+
derivative_mask,
|
|
514
|
+
is_leaf=lambda x: isinstance(x, eqx.Module)
|
|
515
|
+
and not isinstance(x, Params), # do not travers nn_params, more
|
|
516
|
+
# granularity could be imagined here, in the future
|
|
517
|
+
)
|
|
90
518
|
|
|
91
|
-
if
|
|
92
|
-
return
|
|
93
|
-
|
|
94
|
-
return _set_derivatives_dict(derivative_keys)
|
|
519
|
+
if isinstance(params, ParamsDict):
|
|
520
|
+
return _set_derivatives_ParamsDict(params, derivative_keys)
|
|
521
|
+
return _set_derivatives_(params, derivative_keys)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.1.0
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python
|
|
|
13
13
|
Requires-Python: >=3.10
|
|
14
14
|
Description-Content-Type: text/markdown
|
|
15
15
|
License-File: LICENSE
|
|
16
|
+
License-File: AUTHORS
|
|
16
17
|
Requires-Dist: numpy
|
|
17
18
|
Requires-Dist: jax
|
|
18
19
|
Requires-Dist: jaxopt
|
|
@@ -5,16 +5,16 @@ jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
|
|
|
5
5
|
jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
|
|
6
6
|
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
7
7
|
jinns/loss/_DynamicLoss.py,sha256=WGbAuWnNfsbzUlWEiW_ARd4kI3jmHwdqPjxLC-wCA6s,25753
|
|
8
|
-
jinns/loss/_DynamicLossAbstract.py,sha256=
|
|
9
|
-
jinns/loss/_LossODE.py,sha256=
|
|
10
|
-
jinns/loss/_LossPDE.py,sha256=
|
|
8
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=Xyt28Oej_zlhcV3f6cw2vnAKyRJhXBiA63CsdL3PihU,13767
|
|
9
|
+
jinns/loss/_LossODE.py,sha256=ThWPse6Gn5crM3_tzwZCBx-usoD0xWu6y1n0GVl2dpI,23422
|
|
10
|
+
jinns/loss/_LossPDE.py,sha256=R9kNQiaFbFx2eCMdjB7ie7UJ9pJW7PmvHijioNgu-bs,49117
|
|
11
11
|
jinns/loss/__init__.py,sha256=Fm4QAHaVmp0CA7HSwb7KUctwdXnNZ9v5KmTqpeoYPaE,669
|
|
12
12
|
jinns/loss/_boundary_conditions.py,sha256=O0D8eWsFfvNNeO20PQ0rUKBI_MDqaBvqChfXaztZoL4,16679
|
|
13
|
-
jinns/loss/_loss_utils.py,sha256=
|
|
13
|
+
jinns/loss/_loss_utils.py,sha256=44J-VF6dxT_o5BcNWFOiLpY40c35YnAxxZkoNtdtcZc,13689
|
|
14
14
|
jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
|
|
15
15
|
jinns/loss/_operators.py,sha256=o-Ljp_9_HXB9Mhm-ANh6ouNw4_PsqLJAha7dFDGl_nQ,10781
|
|
16
16
|
jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
|
|
17
|
-
jinns/parameters/_derivative_keys.py,sha256=
|
|
17
|
+
jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
|
|
18
18
|
jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
|
|
19
19
|
jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
|
|
20
20
|
jinns/plot/_plot.py,sha256=ZGIJdGwEd3NlHRTq_2sOfEH_CtOkvPwdgCMct-nQlJE,11691
|
|
@@ -31,8 +31,9 @@ jinns/utils/_types.py,sha256=P_dS0odrHbyalYJ0FjS6q0tkXAGr-4GArsiyJYrB1ho,1878
|
|
|
31
31
|
jinns/utils/_utils.py,sha256=Ow8xB516E7yHDZatokVJHHFNPDu6fXr9-NmraUXjjyw,1819
|
|
32
32
|
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
33
33
|
jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
|
|
34
|
-
jinns-1.
|
|
35
|
-
jinns-1.
|
|
36
|
-
jinns-1.
|
|
37
|
-
jinns-1.
|
|
38
|
-
jinns-1.
|
|
34
|
+
jinns-1.1.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
35
|
+
jinns-1.1.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
36
|
+
jinns-1.1.0.dist-info/METADATA,sha256=3Qk885oguf6S_WPHd9KCFVIWP21nJqX9zFWoS9ZI-T0,2536
|
|
37
|
+
jinns-1.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
38
|
+
jinns-1.1.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
39
|
+
jinns-1.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|