jinns-1.0.0-py3-none-any.whl → jinns-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_DynamicLossAbstract.py CHANGED
@@ -73,9 +73,11 @@ def _decorator_heteregeneous_params(evaluate, eq_type):
 class DynamicLoss(eqx.Module):
     r"""
     Abstract base class for dynamic losses. Implements the physical term:
+
     $$
     \mathcal{N}[u](t, x) = 0
     $$
+
     for **one** point $t$, $x$ or $(t, x)$, depending on the context.
 
     Parameters
jinns/loss/_LossODE.py CHANGED
@@ -56,6 +56,9 @@ class _LossODEAbstract(eqx.Module):
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instantiate the
+        DerivativeKeysODE if the latter is not specified.
     """
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
@@ -66,13 +69,21 @@ class _LossODEAbstract(eqx.Module):
     initial_condition: tuple | None = eqx.field(kw_only=True, default=None)
     obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
 
-    def __post_init__(self):
+    params: InitVar[Params] = eqx.field(default=None, kw_only=True)
+
+    def __post_init__(self, params=None):
         if self.loss_weights is None:
             self.loss_weights = LossWeightsODE()
 
         if self.derivative_keys is None:
-            # be default we only take gradient wrt nn_params
-            self.derivative_keys = DerivativeKeysODE()
+            try:
+                # by default we only take gradient wrt nn_params
+                self.derivative_keys = DerivativeKeysODE(params=params)
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at self.derivative_keys initialization, "
+                    f"received {self.derivative_keys=} and {params=}"
+                ) from exc
         if self.initial_condition is None:
             warnings.warn(
                 "Initial condition wasn't provided. Be sure to cover for that"
@@ -131,6 +142,9 @@ class LossODE(_LossODEAbstract):
         slice of u output(s) that is observed. This is useful for
         multidimensional PINN, with partially observed outputs.
         Default is None (whole output is observed).
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instantiate the
+        DerivativeKeysODE if the latter is not specified.
     u : eqx.Module
         the PINN
     dynamic_loss : DynamicLoss
@@ -152,8 +166,10 @@ class LossODE(_LossODEAbstract):
 
     vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self):
-        super().__post_init__()  # because __init__ or __post_init__ of Base
+    def __post_init__(self, params=None):
+        super().__post_init__(
+            params=params
+        )  # because __init__ or __post_init__ of Base
         # class is not automatically called
 
         self.vmap_in_axes = (0,)
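The two hunks above change the construction pattern: when no `derivative_keys` is supplied, the main `Params` object must now reach `__post_init__` so the default `DerivativeKeysODE` (gradient wrt `nn_params` only) can be built. A minimal sketch of the resulting call site; `u`, `ode_loss`, `t0`, `u0` and `params` are illustrative names assumed to be built as in a usual jinns setup, and the import path is an assumption:

```python
from jinns.loss import LossODE  # assumed export path

# `u` (the PINN), `ode_loss` (a DynamicLoss) and `params` (the main
# Params object) are assumed to exist already.
loss = LossODE(
    u=u,
    dynamic_loss=ode_loss,
    initial_condition=(t0, u0),
    params=params,  # new in 1.1.0: without it (and without an explicit
    # derivative_keys) __post_init__ raises the ValueError shown above
)
```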
@@ -300,6 +316,9 @@ class SystemLossODE(eqx.Module):
         PINNs. Default is None. But if a value is given, all the entries of
         `u_dict` must be represented here with default value `jnp.s_[...]`
         if no particular slice is to be given.
+    params_dict : InitVar[ParamsDict], default=None
+        The main ParamsDict object of the problem needed to instantiate the
+        DerivativeKeysODE if the latter is not specified.
 
     Raises
     ------
@@ -332,14 +351,16 @@ class SystemLossODE(eqx.Module):
     loss_weights: InitVar[LossWeightsODEDict | None] = eqx.field(
         kw_only=True, default=None
     )
+    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
+
     u_constraints_dict: Dict[str, LossODE] = eqx.field(init=False)
-    derivative_keys_dyn_loss_dict: Dict[str, DerivativeKeysODE] = eqx.field(init=False)
+    derivative_keys_dyn_loss: DerivativeKeysODE = eqx.field(init=False)
 
     u_dict_with_none: Dict[str, None] = eqx.field(init=False)
     # internally the loss weights are handled with a dictionary
     _loss_weights: Dict[str, dict] = eqx.field(init=False)
 
-    def __post_init__(self, loss_weights):
+    def __post_init__(self, loss_weights=None, params_dict=None):
         # a dictionary that will be useful at different places
         self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
         if self.initial_condition_dict is None:
@@ -369,14 +390,14 @@ class SystemLossODE(eqx.Module):
         # iterating on dynamic_loss_dict. So each time we will require dome
         # derivative_keys_dict
 
-        # but then if the user did not provide anything, we must at least have
-        # a default value for the dynamic_loss_dict keys entries in
-        # self.derivative_keys_dict since the computation of dynamic losses is
-        # made without create a lossODE object that would provide the
-        # default values
-        for k in self.dynamic_loss_dict.keys():
+        # derivative keys for the u_constraints. Note that we create missing
+        # DerivativeKeysODE around a Params object and not a ParamsDict:
+        # this works because u_dict.keys() == params_dict.nn_params.keys()
+        for k in self.u_dict.keys():
             if self.derivative_keys_dict[k] is None:
-                self.derivative_keys_dict[k] = DerivativeKeysODE()
+                self.derivative_keys_dict[k] = DerivativeKeysODE(
+                    params=params_dict.extract_params(k)
+                )
 
         self._loss_weights = self.set_loss_weights(loss_weights)
 
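The comment above relies on `ParamsDict.extract_params(k)` returning a per-network `Params`. That method is not shown in this diff; given the `eq_params` broadcasting added in `_loss_utils.py` further down, a plausible reading of its semantics is the following sketch (an assumption, not the actual implementation):

```python
def extract_params_sketch(params_dict, k):
    # Per-network view of a ParamsDict: the k-th nn_params, plus either
    # the k-th eq_params entry or the shared eq_params dict.
    eq = params_dict.eq_params
    return Params(
        nn_params=params_dict.nn_params[k],
        eq_params=eq[k] if eq.keys() == params_dict.nn_params.keys() else eq,
    )
```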
@@ -397,12 +418,11 @@ class SystemLossODE(eqx.Module):
                 obs_slice=self.obs_slice_dict[i],
             )
 
-        # for convenience in the tree_map of evaluate
-        self.derivative_keys_dyn_loss_dict = {
-            k: self.derivative_keys_dict[k]
-            for k in self.dynamic_loss_dict.keys()  # & self.derivative_keys_dict.keys()
-            # comment because intersection is neceserily fulfilled right?
-        }
+        # derivative keys for the dynamic loss. Note that we create a
+        # DerivativeKeysODE around a ParamsDict object because a whole
+        # params_dict is fed to DynamicLoss.evaluate functions (extract_params
+        # happens inside them)
+        self.derivative_keys_dyn_loss = DerivativeKeysODE(params=params_dict)
 
     def set_loss_weights(self, loss_weights_init):
         """
@@ -497,13 +517,13 @@ class SystemLossODE(eqx.Module):
             batch.param_batch_dict, params_dict
         )
 
-        def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
+        def dyn_loss_for_one_key(dyn_loss, loss_weight):
            """This function is used in tree_map"""
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
                 (temporal_batch,),
-                _set_derivatives(params_dict, derivative_key.dyn_loss),
+                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
                 vmap_in_axes_t + vmap_in_axes_params,
                 loss_weight,
                 u_type=PINN,
@@ -512,7 +532,6 @@
         dyn_loss_mse_dict = jax.tree_util.tree_map(
             dyn_loss_for_one_key,
             self.dynamic_loss_dict,
-            self.derivative_keys_dyn_loss_dict,
             self._loss_weights["dyn_loss"],
             is_leaf=lambda x: isinstance(x, ODE),  # before when dynamic losses
             # where plain (unregister pytree) node classes, we could not traverse
jinns/loss/_LossPDE.py CHANGED
@@ -103,6 +103,9 @@ class _LossPDEAbstract(eqx.Module):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instantiate the
+        DerivativeKeysPDEStatio/DerivativeKeysPDENonStatio if the latter is not specified.
     """
 
     # NOTE static=True only for leaf attributes that are not valid JAX types
@@ -129,18 +132,26 @@ class _LossPDEAbstract(eqx.Module):
     norm_int_length: float | None = eqx.field(kw_only=True, default=None)
     obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
 
-    def __post_init__(self):
+    params: InitVar[Params] = eqx.field(kw_only=True, default=None)
+
+    def __post_init__(self, params=None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at
         """
         if self.derivative_keys is None:
             # be default we only take gradient wrt nn_params
-            self.derivative_keys = (
-                DerivativeKeysPDENonStatio()
-                if isinstance(self, LossPDENonStatio)
-                else DerivativeKeysPDEStatio()
-            )
+            try:
+                self.derivative_keys = (
+                    DerivativeKeysPDENonStatio(params=params)
+                    if isinstance(self, LossPDENonStatio)
+                    else DerivativeKeysPDEStatio(params=params)
+                )
+            except ValueError as exc:
+                raise ValueError(
+                    "Problem at self.derivative_keys initialization, "
+                    f"received {self.derivative_keys=} and {params=}"
+                ) from exc
 
         if self.loss_weights is None:
             self.loss_weights = (
@@ -324,6 +335,9 @@ class LossPDEStatio(_LossPDEAbstract):
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instantiate the
+        DerivativeKeysPDEStatio if the latter is not specified.
 
 
     Raises
@@ -342,12 +356,14 @@ class LossPDEStatio(_LossPDEAbstract):
 
     vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
 
-    def __post_init__(self):
+    def __post_init__(self, params=None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at!
         """
-        super().__post_init__()  # because __init__ or __post_init__ of Base
+        super().__post_init__(
+            params=params
+        )  # because __init__ or __post_init__ of Base
         # class is not automatically called
 
         self.vmap_in_axes = (0,)  # for x only here
@@ -547,6 +563,9 @@ class LossPDENonStatio(LossPDEStatio):
     initial_condition_fun : Callable, default=None
         A function representing the temporal initial condition. If None
         (default) then no initial condition is applied
+    params : InitVar[Params], default=None
+        The main Params object of the problem needed to instantiate the
+        DerivativeKeysPDENonStatio if the latter is not specified.
 
     """
 
@@ -556,12 +575,14 @@ class LossPDENonStatio(LossPDEStatio):
         kw_only=True, default=None, static=True
     )
 
-    def __post_init__(self):
+    def __post_init__(self, params=None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
         Module with eqx.tree_at!
         """
-        super().__post_init__()  # because __init__ or __post_init__ of Base
+        super().__post_init__(
+            params=params
+        )  # because __init__ or __post_init__ of Base
         # class is not automatically called
 
         self.vmap_in_axes = (0, 0)  # for t and x
@@ -616,7 +637,6 @@ class LossPDENonStatio(LossPDEStatio):
         of parameters (eg. for metamodeling) and an optional additional batch of observed
         inputs/outputs/parameters
         """
-
         omega_batch = batch.times_x_inside_batch[:, 1:]
 
         # Retrieve the optional eq_params_batch
@@ -725,6 +745,9 @@ class SystemLossPDE(eqx.Module):
         PINNs. Default is None. But if a value is given, all the entries of
         `u_dict` must be represented here with default value `jnp.s_[...]`
         if no particular slice is to be given
+    params_dict : InitVar[ParamsDict], default=None
+        The main ParamsDict object of the problem needed to instantiate the
+        DerivativeKeys objects if the latter are not specified.
 
 
     """
@@ -763,22 +786,20 @@ class SystemLossPDE(eqx.Module):
     loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
         kw_only=True, default=None
     )
+    params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
 
     # following have init=False and are set in the __post_init__
     u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
         init=False
     )
-    derivative_keys_u_dict: Dict[
-        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
-    ] = eqx.field(init=False)
-    derivative_keys_dyn_loss_dict: Dict[
-        str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
-    ] = eqx.field(init=False)
+    derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
+        eqx.field(init=False)
+    )
     u_dict_with_none: Dict[str, None] = eqx.field(init=False)
     # internally the loss weights are handled with a dictionary
     _loss_weights: Dict[str, dict] = eqx.field(init=False)
 
-    def __post_init__(self, loss_weights):
+    def __post_init__(self, loss_weights=None, params_dict=None):
         # a dictionary that will be useful at different places
         self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
         # First, for all the optional dict,
@@ -818,24 +839,19 @@ class SystemLossPDE(eqx.Module):
         # iterating on dynamic_loss_dict. So each time we will require dome
         # derivative_keys_dict
 
-        # but then if the user did not provide anything, we must at least have
-        # a default value for the dynamic_loss_dict keys entries in
-        # self.derivative_keys_dict since the computation of dynamic losses is
-        # made without create a loss object that would provide the
-        # default values
-        for k in self.dynamic_loss_dict.keys():
+        # derivative keys for the u_constraints. Note that we create missing
+        # DerivativeKeysPDEStatio/NonStatio around a Params object and not a
+        # ParamsDict: this works because u_dict.keys() == params_dict.nn_params.keys()
+        for k in self.u_dict.keys():
             if self.derivative_keys_dict[k] is None:
-                try:
-                    if self.u_dict[k].eq_type == "statio_PDE":
-                        self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
-                    else:
-                        self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
-                except KeyError:  # We are in a key that is not in u_dict but in
-                    # dynamic_loss_dict
-                    if isinstance(self.dynamic_loss_dict[k], PDEStatio):
-                        self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
-                    else:
-                        self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
+                if self.u_dict[k].eq_type == "statio_PDE":
+                    self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
+                        params=params_dict.extract_params(k)
+                    )
+                else:
+                    self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
+                        params=params_dict.extract_params(k)
+                    )
 
         # Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
         if (
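Taken together, a system loss is now built with the problem's `ParamsDict` passed up front, so that each per-network default `DerivativeKeys` can be derived from it. A hedged sketch of the new call site (field names come from the docstrings above; the surrounding objects are assumed):

```python
system_loss = SystemLossPDE(
    u_dict=u_dict,                        # dict of PINNs, assumed built
    dynamic_loss_dict=dynamic_loss_dict,  # dict of DynamicLoss, assumed built
    params_dict=params_dict,              # new in 1.1.0
)
```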
@@ -904,16 +920,11 @@ class SystemLossPDE(eqx.Module):
                 f"got {self.u_dict[i].eq_type[i]}"
             )
 
-        # for convenience in the tree_map of evaluate,
-        # we separate the two derivative keys dict
-        self.derivative_keys_dyn_loss_dict = {
-            k: self.derivative_keys_dict[k]
-            for k in self.dynamic_loss_dict.keys() & self.derivative_keys_dict.keys()
-        }
-        self.derivative_keys_u_dict = {
-            k: self.derivative_keys_dict[k]
-            for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
-        }
+        # derivative keys for the dynamic loss. Note that we create a
+        # DerivativeKeysPDENonStatio around a ParamsDict object because a whole
+        # params_dict is fed to DynamicLoss.evaluate functions (extract_params
+        # happens inside them)
+        self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
 
         # also make sure we only have PINNs or SPINNs
         if not (
@@ -1034,13 +1045,13 @@ class SystemLossPDE(eqx.Module):
             batch.param_batch_dict, params_dict
         )
 
-        def dyn_loss_for_one_key(dyn_loss, derivative_key, loss_weight):
+        def dyn_loss_for_one_key(dyn_loss, loss_weight):
             """The function used in tree_map"""
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
                 batches,
-                _set_derivatives(params_dict, derivative_key.dyn_loss),
+                _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
                 vmap_in_axes_x_or_x_t + vmap_in_axes_params,
                 loss_weight,
                 u_type=type(list(self.u_dict.values())[0]),
@@ -1049,7 +1060,6 @@
         dyn_loss_mse_dict = jax.tree_util.tree_map(
             dyn_loss_for_one_key,
             self.dynamic_loss_dict,
-            self.derivative_keys_dyn_loss_dict,
             self._loss_weights["dyn_loss"],
             is_leaf=lambda x: isinstance(
                 x, (PDEStatio, PDENonStatio)
jinns/loss/_loss_utils.py CHANGED
@@ -297,12 +297,12 @@ def constraints_system_loss_apply(
     if isinstance(params_dict.nn_params, dict):
 
         def apply_u_constraint(
-            u_constraint, nn_params, loss_weights_for_u, obs_batch_u
+            u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
         ):
             res_dict_for_u = u_constraint.evaluate(
                 Params(
                     nn_params=nn_params,
-                    eq_params=params_dict.eq_params,
+                    eq_params=eq_params,
                 ),
                 append_obs_batch(batch, obs_batch_u),
             )[1]
@@ -319,6 +319,11 @@ def constraints_system_loss_apply(
             apply_u_constraint,
             u_constraints_dict,
             params_dict.nn_params,
+            (
+                params_dict.eq_params
+                if params_dict.eq_params.keys() == params_dict.nn_params.keys()
+                else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
+            ),  # needed because eq_params may not share nn_params' structure in a ParamsDict
             loss_weights_T,
             batch.obs_batch_dict,
             is_leaf=lambda x: (
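The inserted `tree_map` argument implements a small broadcasting rule: in a `ParamsDict`, `eq_params` may either mirror the per-network key structure of `nn_params` or be a single dict shared by all networks. In isolation the rule reads:

```python
# Shared eq_params are replicated once per network key so that tree_map
# sees one eq_params entry per u_constraint.
eq_params_per_net = (
    params_dict.eq_params
    if params_dict.eq_params.keys() == params_dict.nn_params.keys()
    else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
)
```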
jinns/parameters/_derivative_keys.py CHANGED
@@ -2,50 +2,468 @@
 Formalize the data structure for the derivative keys
 """
 
-from dataclasses import fields
+from functools import partial
+from dataclasses import fields, InitVar
 from typing import Literal
 import jax
 import equinox as eqx
 
-from jinns.parameters._params import Params
+from jinns.parameters._params import Params, ParamsDict
+
+
+def _get_masked_parameters(
+    derivative_mask_str: str, params: Params | ParamsDict
+) -> Params | ParamsDict:
+    """
+    Creates the Params object with True values where we want to differentiate
+    """
+    if isinstance(params, Params):
+        # start with a params object with True everywhere. We will update to False
+        # for parameters wrt which we do not want to differentiate the loss
+        diff_params = jax.tree.map(
+            lambda x: True,
+            params,
+            is_leaf=lambda x: isinstance(x, eqx.Module)
+            and not isinstance(x, Params),  # do not traverse nn_params, more
+            # granularity could be imagined here, in the future
+        )
+        if derivative_mask_str == "both":
+            return diff_params
+        if derivative_mask_str == "eq_params":
+            return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
+        if derivative_mask_str == "nn_params":
+            return eqx.tree_at(
+                lambda p: p.eq_params,
+                diff_params,
+                jax.tree.map(lambda x: False, params.eq_params),
+            )
+        raise ValueError(
+            "Bad value for DerivativeKeys. Got "
+            f'{derivative_mask_str}, expected "both", "nn_params" or '
+            '"eq_params"'
+        )
+    elif isinstance(params, ParamsDict):
+        # do not traverse nn_params, more
+        # granularity could be imagined here, in the future
+        diff_params = ParamsDict(
+            nn_params=True, eq_params=jax.tree.map(lambda x: True, params.eq_params)
+        )
+        if derivative_mask_str == "both":
+            return diff_params
+        if derivative_mask_str == "eq_params":
+            return eqx.tree_at(lambda p: p.nn_params, diff_params, False)
+        if derivative_mask_str == "nn_params":
+            return eqx.tree_at(
+                lambda p: p.eq_params,
+                diff_params,
+                jax.tree.map(lambda x: False, params.eq_params),
+            )
+        raise ValueError(
+            "Bad value for DerivativeKeys. Got "
+            f'{derivative_mask_str}, expected "both", "nn_params" or '
+            '"eq_params"'
+        )
+
+    else:
+        raise ValueError(
+            f"Bad value for params. Got {type(params)}, expected Params "
+            "or ParamsDict"
+        )
 
 
 class DerivativeKeysODE(eqx.Module):
-    # we use static = True because all fields are string, hence should be
-    # invisible by JAX transforms (JIT, etc.)
-    dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
-    )
-    observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
+    """
+    A class that specifies with respect to which parameter(s) each term of the
+    loss is differentiated. For example, you can specify that the
+    [`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
+    respect to the neural network parameters *and* the equation parameters, or
+    only some of them.
+
+    To do so, the user can either use strings or a `Params` object with a
+    PyTree structure matching the parameters of the problem at hand, and
+    booleans indicating whether the gradient is to be taken or not. Internally,
+    a `jax.lax.stop_gradient()` is applied to each `False` node when
+    computing each loss term.
+
+    !!! note
+
+        1. For unspecified loss terms, the default is to differentiate with
+        respect to `"nn_params"` only.
+        2. No granularity inside `Params.nn_params` is currently supported.
+        3. The main Params or ParamsDict object of the problem is mandatory
+        for initialization via `from_str()`.
+
+    A typical specification is of the form:
+    ```python
+    Params(
+        nn_params=True | False,
+        eq_params={
+            "alpha": True | False,
+            "beta": True | False,
+            ...
+        }
     )
-    initial_condition: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
+    ```
+
+    Parameters
+    ----------
+    dyn_loss : Params | ParamsDict | None, default=None
+        Tell wrt which nodes of `Params` we will differentiate the
+        dynamic loss. To do so, the fields of `Params` contain True (if
+        differentiation) or False (if no differentiation).
+    observations : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        observation loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    initial_condition : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        initial condition loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    params : InitVar[Params | ParamsDict], default=None
+        The main Params object of the problem. It is required
+        if some terms are unspecified (None). This is because jinns cannot
+        infer the content of `Params.eq_params`.
+    """
+
+    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+    initial_condition: Params | ParamsDict | None = eqx.field(
+        kw_only=True, default=None
     )
 
+    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
+
+    def __post_init__(self, params=None):
+        if self.dyn_loss is None:
+            try:
+                self.dyn_loss = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.dyn_loss is None, hence params should be passed"
+                )
+        if self.observations is None:
+            try:
+                self.observations = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.observations is None, hence params should be passed"
+                )
+        if self.initial_condition is None:
+            try:
+                self.initial_condition = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.initial_condition is None, hence params should be passed"
+                )
+
+    @classmethod
+    def from_str(
+        cls,
+        params: Params | ParamsDict,
+        dyn_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        observations: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        initial_condition: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+    ):
+        """
+        Construct the DerivativeKeysODE from strings. For each term of the
+        loss, specify whether to differentiate wrt the neural network
+        parameters, the equation parameters or both. The `Params` object,
+        which contains the actual arrays of parameters, must be passed to
+        construct the fields with the appropriate PyTree structure.
+
+        !!! note
+            You can mix strings and `Params` if you need granularity.
+
+        Parameters
+        ----------
+        params
+            The actual Params or ParamsDict object of the problem.
+        dyn_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the dynamic loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        observations
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the observations. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        initial_condition
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the initial condition. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        """
+        return DerivativeKeysODE(
+            dyn_loss=(
+                _get_masked_parameters(dyn_loss, params)
+                if isinstance(dyn_loss, str)
+                else dyn_loss
+            ),
+            observations=(
+                _get_masked_parameters(observations, params)
+                if isinstance(observations, str)
+                else observations
+            ),
+            initial_condition=(
+                _get_masked_parameters(initial_condition, params)
+                if isinstance(initial_condition, str)
+                else initial_condition
+            ),
+        )
+
 
 class DerivativeKeysPDEStatio(eqx.Module):
+    """
+    See [jinns.parameters.DerivativeKeysODE][].
 
-    dyn_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
-    )
-    observations: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
-    )
-    boundary_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
-    )
-    norm_loss: Literal["nn_params", "eq_params", "both"] | None = eqx.field(
-        kw_only=True, default="nn_params", static=True
-    )
+    Parameters
+    ----------
+    dyn_loss : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        dynamic loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    observations : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        observation loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    boundary_loss : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        boundary loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    norm_loss : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        normalization loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    params : InitVar[Params | ParamsDict], default=None
+        The main Params object of the problem. It is required
+        if some terms are unspecified (None). This is because jinns cannot
+        infer the content of `Params.eq_params`.
+    """
+
+    dyn_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+    observations: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+    boundary_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+    norm_loss: Params | ParamsDict | None = eqx.field(kw_only=True, default=None)
+
+    params: InitVar[Params | ParamsDict] = eqx.field(kw_only=True, default=None)
+
+    def __post_init__(self, params=None):
+        if self.dyn_loss is None:
+            try:
+                self.dyn_loss = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError("self.dyn_loss is None, hence params should be passed")
+        if self.observations is None:
+            try:
+                self.observations = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.observations is None, hence params should be passed"
+                )
+        if self.boundary_loss is None:
+            try:
+                self.boundary_loss = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.boundary_loss is None, hence params should be passed"
+                )
+        if self.norm_loss is None:
+            try:
+                self.norm_loss = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.norm_loss is None, hence params should be passed"
+                )
+
+    @classmethod
+    def from_str(
+        cls,
+        params: Params | ParamsDict,
+        dyn_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        observations: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        boundary_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        norm_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+    ):
+        """
+        See [jinns.parameters.DerivativeKeysODE.from_str][].
+
+        Parameters
+        ----------
+        params
+            The actual Params or ParamsDict object of the problem.
+        dyn_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the dynamic loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        observations
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the observations. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        boundary_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the boundary loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        norm_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the normalization loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        """
+        return DerivativeKeysPDEStatio(
+            dyn_loss=(
+                _get_masked_parameters(dyn_loss, params)
+                if isinstance(dyn_loss, str)
+                else dyn_loss
+            ),
+            observations=(
+                _get_masked_parameters(observations, params)
+                if isinstance(observations, str)
+                else observations
+            ),
+            boundary_loss=(
+                _get_masked_parameters(boundary_loss, params)
+                if isinstance(boundary_loss, str)
+                else boundary_loss
+            ),
+            norm_loss=(
+                _get_masked_parameters(norm_loss, params)
+                if isinstance(norm_loss, str)
+                else norm_loss
+            ),
+        )
 
 
 class DerivativeKeysPDENonStatio(DerivativeKeysPDEStatio):
+    """
+    See [jinns.parameters.DerivativeKeysODE][].
 
-    initial_condition: Literal["nn_params", "eq_params", "both"] = eqx.field(
-        kw_only=True, default="nn_params", static=True
+    Parameters
+    ----------
+    dyn_loss : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        dynamic loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    observations : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        observation loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    boundary_loss : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        boundary loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    norm_loss : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        normalization loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    initial_condition : Params | ParamsDict | None, default=None
+        Tell wrt which parameters among Params we will differentiate the
+        initial condition loss. To do so, the fields of Params contain True (if
+        differentiation) or False (if no differentiation).
+    params : InitVar[Params | ParamsDict], default=None
+        The main Params object of the problem. It is required
+        if some terms are unspecified (None). This is because jinns cannot
+        infer the content of `Params.eq_params`.
+    """
+
+    initial_condition: Params | ParamsDict | None = eqx.field(
+        kw_only=True, default=None
     )
 
+    def __post_init__(self, params=None):
+        super().__post_init__(params=params)
+        if self.initial_condition is None:
+            try:
+                self.initial_condition = _get_masked_parameters("nn_params", params)
+            except AttributeError:
+                raise ValueError(
+                    "self.initial_condition is None, hence params should be passed"
+                )
+
+    @classmethod
+    def from_str(
+        cls,
+        params: Params | ParamsDict,
+        dyn_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        observations: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        boundary_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        norm_loss: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+        initial_condition: (
+            Literal["nn_params", "eq_params", "both"] | Params | ParamsDict
+        ) = "nn_params",
+    ):
+        """
+        See [jinns.parameters.DerivativeKeysODE.from_str][].
+
+        Parameters
+        ----------
+        params
+            The actual Params or ParamsDict object of the problem.
+        dyn_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the dynamic loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        observations
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the observations. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        boundary_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the boundary loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        norm_loss
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the normalization loss. Default is
+            `"nn_params"`. Specifying a Params or ParamsDict is also possible.
+        initial_condition
+            Tell wrt which parameters among `"nn_params"`, `"eq_params"` or
+            `"both"` we will differentiate the initial condition loss. Default
+            is `"nn_params"`. Specifying a Params or ParamsDict is also
+            possible.
+        """
+        return DerivativeKeysPDENonStatio(
+            dyn_loss=(
+                _get_masked_parameters(dyn_loss, params)
+                if isinstance(dyn_loss, str)
+                else dyn_loss
+            ),
+            observations=(
+                _get_masked_parameters(observations, params)
+                if isinstance(observations, str)
+                else observations
+            ),
+            boundary_loss=(
+                _get_masked_parameters(boundary_loss, params)
+                if isinstance(boundary_loss, str)
+                else boundary_loss
+            ),
+            norm_loss=(
+                _get_masked_parameters(norm_loss, params)
+                if isinstance(norm_loss, str)
+                else norm_loss
+            ),
+            initial_condition=(
+                _get_masked_parameters(initial_condition, params)
+                if isinstance(initial_condition, str)
+                else initial_condition
+            ),
+        )
+
 
 def _set_derivatives(params, derivative_keys):
     """
@@ -53,42 +471,51 @@ def _set_derivatives(params, derivative_keys):
     has a copy of the params with appropriate derivatives set
     """
 
-    def _set_derivatives_(loss_term_derivative):
-        if loss_term_derivative == "both":
-            return params
-        # the next line put a stop_gradient around the fields that do not
-        # appear in loss_term_derivative. Currently there are only two possible
-        # values nn_params and eq_params but there might be more in the future
-        return eqx.tree_at(
-            lambda p: tuple(
-                getattr(p, f.name)
-                for f in fields(Params)
-                if f.name != loss_term_derivative
+    def _set_derivatives_ParamsDict(params_, derivative_mask):
+        """
+        The next lines put a stop_gradient around the fields wrt which the
+        loss term is not differentiated.
+        **Note:** **No granularity inside `ParamsDict.nn_params` is currently
+        supported.**
+        This means a typical specification is of the form:
+        `ParamsDict(nn_params=True | False, eq_params={"0": {"alpha": True | False,
+        "beta": True | False}, "1": {"alpha": True | False, "beta": True | False}})`.
+        """
+        # a ParamsDict object is reconstructed by hand since we do not want to
+        # traverse nn_params, for now...
+        return ParamsDict(
+            nn_params=jax.lax.cond(
+                derivative_mask.nn_params,
+                lambda p: p,
+                jax.lax.stop_gradient,
+                params_.nn_params,
+            ),
+            eq_params=jax.tree.map(
+                lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
+                params_.eq_params,
+                derivative_mask.eq_params,
             ),
-            params,
-            replace_fn=jax.lax.stop_gradient,
         )
 
-    def _set_derivatives_dict(loss_term_derivative):
-        if loss_term_derivative == "both":
-            return params
-        # the next line put a stop_gradient around the fields that do not
-        # appear in loss_term_derivative. Currently there are only two possible
-        # values nn_params and eq_params but there might be more in the future
-        return {
-            k: eqx.tree_at(
-                lambda p: tuple(
-                    getattr(p, f.name)
-                    for f in fields(Params)
-                    if f.name != loss_term_derivative
-                ),
-                params_,
-                replace_fn=jax.lax.stop_gradient,
-            )
-            for k, params_ in params
-        }
+    def _set_derivatives_(params_, derivative_mask):
+        """
+        The next lines put a stop_gradient around the fields wrt which the
+        loss term is not differentiated.
+        **Note:** **No granularity inside `Params.nn_params` is currently
+        supported.**
+        This means a typical Params specification is of the form:
+        `Params(nn_params=True | False, eq_params={"alpha": True | False,
+        "beta": True | False})`.
+        """
+        return jax.tree.map(
+            lambda p, d: jax.lax.cond(d, lambda p: p, jax.lax.stop_gradient, p),
+            params_,
+            derivative_mask,
+            is_leaf=lambda x: isinstance(x, eqx.Module)
+            and not isinstance(x, Params),  # do not traverse nn_params, more
+            # granularity could be imagined here, in the future
+        )
 
-    if not isinstance(params, dict):
-        return _set_derivatives_(derivative_keys)
-    else:
-        return _set_derivatives_dict(derivative_keys)
+    if isinstance(params, ParamsDict):
+        return _set_derivatives_ParamsDict(params, derivative_keys)
+    return _set_derivatives_(params, derivative_keys)
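`_set_derivatives` turns one of these boolean masks into `stop_gradient` wrappers on the frozen leaves before a loss term is evaluated. A sketch of the effect (names from the diff; the surrounding `grad` over the loss happens elsewhere in jinns):

```python
# Freeze everything except eq_params["alpha"] in a given loss term:
mask = Params(nn_params=False, eq_params={"alpha": True, "beta": False})
masked_params = _set_derivatives(params, mask)
# Gradients through masked_params.nn_params and
# masked_params.eq_params["beta"] are stopped; only eq_params["alpha"]
# receives a gradient from a loss evaluated on masked_params.
```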
jinns-1.1.0.dist-info/AUTHORS ADDED
@@ -0,0 +1,2 @@
+Hugo Gangloff <hugo.gangloff@inrae.fr>
+Nicolas Jouvin <nicolas.jouvin@inrae.fr>
jinns-1.0.0.dist-info/METADATA → jinns-1.1.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jinns
-Version: 1.0.0
+Version: 1.1.0
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
+License-File: AUTHORS
 Requires-Dist: numpy
 Requires-Dist: jax
 Requires-Dist: jaxopt
jinns-1.0.0.dist-info/RECORD → jinns-1.1.0.dist-info/RECORD CHANGED
@@ -5,16 +5,16 @@ jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
 jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
 jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
 jinns/loss/_DynamicLoss.py,sha256=WGbAuWnNfsbzUlWEiW_ARd4kI3jmHwdqPjxLC-wCA6s,25753
-jinns/loss/_DynamicLossAbstract.py,sha256=UGFAZWCxbk7yTLdngNasaVawz51xZhu7LfCg4MTqj74,13765
-jinns/loss/_LossODE.py,sha256=4Tcl_Fpn7htoJvggreL3vv5xfQw2VR1tshLGVX7BUx4,22480
-jinns/loss/_LossPDE.py,sha256=zmnLblA6tK-KUKHFbwoo_N01CU1KwCdZt5lUCo-_ecM,48528
+jinns/loss/_DynamicLossAbstract.py,sha256=Xyt28Oej_zlhcV3f6cw2vnAKyRJhXBiA63CsdL3PihU,13767
+jinns/loss/_LossODE.py,sha256=ThWPse6Gn5crM3_tzwZCBx-usoD0xWu6y1n0GVl2dpI,23422
+jinns/loss/_LossPDE.py,sha256=R9kNQiaFbFx2eCMdjB7ie7UJ9pJW7PmvHijioNgu-bs,49117
 jinns/loss/__init__.py,sha256=Fm4QAHaVmp0CA7HSwb7KUctwdXnNZ9v5KmTqpeoYPaE,669
 jinns/loss/_boundary_conditions.py,sha256=O0D8eWsFfvNNeO20PQ0rUKBI_MDqaBvqChfXaztZoL4,16679
-jinns/loss/_loss_utils.py,sha256=DI9KifH1aObrFau86BgAjtBNS8ecSRYB-_H6fduzGNg,13339
+jinns/loss/_loss_utils.py,sha256=44J-VF6dxT_o5BcNWFOiLpY40c35YnAxxZkoNtdtcZc,13689
 jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
 jinns/loss/_operators.py,sha256=o-Ljp_9_HXB9Mhm-ANh6ouNw4_PsqLJAha7dFDGl_nQ,10781
 jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
-jinns/parameters/_derivative_keys.py,sha256=MbbyywuemjvgX-QVG2SSqlEreB7bMDCLnXmRZuGFkjE,3294
+jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
 jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
 jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
 jinns/plot/_plot.py,sha256=ZGIJdGwEd3NlHRTq_2sOfEH_CtOkvPwdgCMct-nQlJE,11691
@@ -31,8 +31,9 @@ jinns/utils/_types.py,sha256=P_dS0odrHbyalYJ0FjS6q0tkXAGr-4GArsiyJYrB1ho,1878
 jinns/utils/_utils.py,sha256=Ow8xB516E7yHDZatokVJHHFNPDu6fXr9-NmraUXjjyw,1819
 jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
 jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
-jinns-1.0.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
-jinns-1.0.0.dist-info/METADATA,sha256=XD9Az0b4tUbBNS1xReORk0j1HMkonZGs9_B0FfHFq34,2514
-jinns-1.0.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-jinns-1.0.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
-jinns-1.0.0.dist-info/RECORD,,
+jinns-1.1.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
+jinns-1.1.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+jinns-1.1.0.dist-info/METADATA,sha256=3Qk885oguf6S_WPHd9KCFVIWP21nJqX9zFWoS9ZI-T0,2536
+jinns-1.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+jinns-1.1.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+jinns-1.1.0.dist-info/RECORD,,