jinns 1.6.1-py3-none-any.whl → 1.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_LossPDE.py CHANGED
@@ -68,11 +68,14 @@ B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
 C = TypeVar(
     "C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
 )
-D = TypeVar("D", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
+DKPDE = TypeVar("DKPDE", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
 Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
 
 
-class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
+class _LossPDEAbstract(
+    AbstractLoss[L, B, C, DKPDE],
+    Generic[L, B, C, DKPDE, Y],
+):
     r"""
     Parameters
     ----------
@@ -119,7 +122,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
         made broadcastable to `norm_samples`.
         These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
         $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : EllipsisType | slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
     key : Key | None
@@ -133,7 +136,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
     # NOTE static=True only for leaf attributes that are not valid JAX types
     # (ie. jax.Array cannot be static) and that we do not expect to change
     u: eqx.AbstractVar[AbstractPINN]
-    dynamic_loss: eqx.AbstractVar[Y]
+    dynamic_loss: tuple[eqx.AbstractVar[Y] | None, ...]
     omega_boundary_fun: (
         BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
     ) = eqx.field(static=True)
@@ -143,7 +146,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
     omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
     norm_samples: Float[Array, " nb_norm_samples dimension"] | None
     norm_weights: Float[Array, " nb_norm_samples"] | None
-    obs_slice: EllipsisType | slice = eqx.field(static=True)
+    obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
     key: PRNGKeyArray | None
 
     def __init__(
@@ -156,14 +159,34 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
         omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
         norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
         norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
-        obs_slice: EllipsisType | slice | None = None,
+        obs_slice: tuple[EllipsisType | slice, ...]
+        | EllipsisType
+        | slice
+        | None = None,
         key: PRNGKeyArray | None = None,
+        derivative_keys: DKPDE,
         **kwargs: Any,  # for arguments for super()
     ):
-        super().__init__(loss_weights=self.loss_weights, **kwargs)
+        super().__init__(
+            loss_weights=self.loss_weights,
+            derivative_keys=derivative_keys,
+            **kwargs,
+        )
+
+        if self.update_weight_method is not None and jnp.any(
+            jnp.array(jax.tree.leaves(self.loss_weights)) == 0
+        ):
+            warnings.warn(
+                "self.update_weight_method is activated while some loss "
+                "weights are zero. The update weight method will likely "
+                "update the zero weight to some non-zero value. Check that "
+                "this is the desired behaviour."
+            )
 
         if obs_slice is None:
-            self.obs_slice = jnp.s_[...]
+            self.obs_slice = (jnp.s_[...],)
+        elif not isinstance(obs_slice, tuple):
+            self.obs_slice = (obs_slice,)
         else:
             self.obs_slice = obs_slice
 
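The hunk above changes two behaviours: `obs_slice` is now always stored as a tuple (a bare slice or `...` is wrapped in a 1-tuple, and `None` becomes `(jnp.s_[...],)`), and a warning is emitted when a weight-update method is active while some loss weights are zero. A minimal standalone sketch of the slice normalisation, based only on the diff (the helper name `normalize_obs_slice` is illustrative, not jinns API):

import jax.numpy as jnp
from types import EllipsisType


def normalize_obs_slice(
    obs_slice: tuple[EllipsisType | slice, ...] | EllipsisType | slice | None,
) -> tuple[EllipsisType | slice, ...]:
    # None -> observe the full PINN output; a bare slice/Ellipsis -> 1-tuple
    if obs_slice is None:
        return (jnp.s_[...],)
    if not isinstance(obs_slice, tuple):
        return (obs_slice,)
    return obs_slice


# a single slice and a tuple of slices now follow the same code path
assert normalize_obs_slice(jnp.s_[0:2]) == (jnp.s_[0:2],)
assert normalize_obs_slice((jnp.s_[...], jnp.s_[1:3])) == (jnp.s_[...], jnp.s_[1:3])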
@@ -303,16 +326,23 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
 
     def _get_dyn_loss_fun(
         self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
-    ) -> Callable[[Params[Array]], Array] | None:
-        if self.dynamic_loss is not None:
-            dyn_loss_eval = self.dynamic_loss.evaluate
-            dyn_loss_fun: Callable[[Params[Array]], Array] | None = (
-                lambda p: dynamic_loss_apply(
-                    dyn_loss_eval,
-                    self.u,
-                    self._get_dynamic_loss_batch(batch),
-                    _set_derivatives(p, self.derivative_keys.dyn_loss),
-                    self.vmap_in_axes + vmap_in_axes_params,
+    ) -> tuple[Callable[[Params[Array]], Array], ...] | None:
+        # Note, for the record, multiple dynamic losses
+        # have been introduced in MR 92
+        if self.dynamic_loss != (None,):
+            dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d: lambda p: dynamic_loss_apply(
+                        d.evaluate,
+                        self.u,
+                        self._get_dynamic_loss_batch(batch),
+                        _set_derivatives(p, self.derivative_keys.dyn_loss),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                    ),
+                    self.dynamic_loss,
+                    is_leaf=lambda x: isinstance(
+                        x, (PDEStatio, PDENonStatio)
+                    ),  # do not traverse further than first level
                 )
             )
         else:
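The rewritten `_get_dyn_loss_fun` maps over a tuple of dynamic-loss objects instead of a single one, using `is_leaf` so that `jax.tree.map` treats each loss object as a unit rather than recursing into its fields. A self-contained sketch of the same pattern with a placeholder `eqx.Module` (the `MyLoss` class and its fields are illustrative only, not jinns code):

import equinox as eqx
import jax


class MyLoss(eqx.Module):
    # eqx.Module instances are pytrees, so without is_leaf jax.tree.map
    # would descend into their fields instead of treating them as units
    scale: float

    def evaluate(self, x, params):
        return self.scale * (x - params["theta"]) ** 2


losses = (MyLoss(1.0), MyLoss(0.5))

# Map each loss object to a closure over the parameters.
residual_fns = jax.tree.map(
    lambda d: lambda params: d.evaluate(2.0, params),
    losses,
    is_leaf=lambda x: isinstance(x, MyLoss),  # stop at the loss objects
)

params = {"theta": 1.0}
print([f(params) for f in residual_fns])  # [1.0, 0.5]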
@@ -322,9 +352,13 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
 
     def _get_norm_loss_fun(
         self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
-    ) -> Callable[[Params[Array]], Array] | None:
+    ) -> tuple[Callable[[Params[Array]], Array], ...] | None:
+        # Note that since MR 92
+        # norm_loss_fun is formed as a tuple for
+        # consistency with dynamic and observation
+        # losses and more modularity for later
         if self.norm_samples is not None:
-            norm_loss_fun: Callable[[Params[Array]], Array] | None = (
+            norm_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
                 lambda p: normalization_loss_apply(
                     self.u,
                     cast(
@@ -333,7 +367,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
                     _set_derivatives(p, self.derivative_keys.norm_loss),
                     vmap_in_axes_params,
                     self.norm_weights,  # type: ignore -> can't get the __post_init__ narrowing here
-                )
+                ),
             )
         else:
             norm_loss_fun = None
@@ -341,20 +375,24 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
 
     def _get_boundary_loss_fun(
         self, batch: B
-    ) -> Callable[[Params[Array]], Array] | None:
+    ) -> tuple[Callable[[Params[Array]], Array], ...] | None:
+        # Note that since MR 92
+        # boundary_loss_fun is formed as a tuple for
+        # consistency with dynamic and observation
+        # losses and more modularity for later
         if (
             self.omega_boundary_condition is not None
             and self.omega_boundary_fun is not None
         ):
-            boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
+            boundary_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
                 lambda p: boundary_condition_apply(
                     self.u,
                     batch,
                     _set_derivatives(p, self.derivative_keys.boundary_loss),
                     self.omega_boundary_fun,  # type: ignore (we are in lambda)
                     self.omega_boundary_condition,  # type: ignore
-                    self.omega_boundary_dim,  # type: ignore
-                )
+                    self.omega_boundary_dim,
+                ),
             )
         else:
             boundary_loss_fun = None
@@ -366,24 +404,37 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
         batch: B,
         vmap_in_axes_params: tuple[Params[int | None] | None],
         params: Params[Array],
-    ) -> tuple[Params[Array] | None, Callable[[Params[Array]], Array] | None]:
+    ) -> tuple[
+        Params[Array] | None, tuple[Callable[[Params[Array]], Array], ...] | None
+    ]:
+        # Note, for the record, multiple DGObs
+        # (leading to batch.obs_batch_dict being tuple | None)
+        # have been introduced in MR 92
         if batch.obs_batch_dict is not None:
-            # update params with the batches of observed params
-            params_obs = update_eq_params(params, batch.obs_batch_dict["eq_params"])
-
-            pinn_in, val = (
-                batch.obs_batch_dict["pinn_in"],
-                batch.obs_batch_dict["val"],
+            if len(batch.obs_batch_dict) != len(self.obs_slice):
+                raise ValueError(
+                    "There must be the same number of "
+                    "observation datasets as the number of "
+                    "obs_slice"
+                )
+            params_obs = jax.tree.map(
+                lambda d: update_eq_params(params, d["eq_params"]),
+                batch.obs_batch_dict,
+                is_leaf=lambda x: isinstance(x, dict),
             )
-
-            obs_loss_fun: Callable[[Params[Array]], Array] | None = (
-                lambda po: observations_loss_apply(
-                    self.u,
-                    pinn_in,
-                    _set_derivatives(po, self.derivative_keys.observations),
-                    self.vmap_in_axes + vmap_in_axes_params,
-                    val,
+            obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
+                jax.tree.map(
+                    lambda d, slice_: lambda po: observations_loss_apply(
+                        self.u,
+                        d["pinn_in"],
+                        _set_derivatives(po, self.derivative_keys.observations),
+                        self.vmap_in_axes + vmap_in_axes_params,
+                        d["val"],
+                        slice_,
+                    ),
+                    batch.obs_batch_dict,
                     self.obs_slice,
+                    is_leaf=lambda x: isinstance(x, dict),
                 )
             )
         else:
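Here the observation term is built by zipping each observation batch dict with its matching `obs_slice` through a two-tree `jax.tree.map`, after checking that both tuples have the same length. A toy sketch of that zipping pattern with made-up data (the arrays and slices below are placeholders, not jinns structures):

import jax
import jax.numpy as jnp

# Two hypothetical observation datasets, each paired with one slice of the PINN output.
obs_batches = ({"val": jnp.ones((4, 3))}, {"val": jnp.zeros((4, 3))})
obs_slices = (jnp.s_[0:1], jnp.s_[1:3])

if len(obs_batches) != len(obs_slices):
    raise ValueError("one obs_slice is required per observation dataset")

# Zip the two tuples entry by entry: is_leaf stops at each batch dict, so the
# lambda receives (dict, slice) pairs rather than individual array leaves.
sliced = jax.tree.map(
    lambda d, s: d["val"][:, s],
    obs_batches,
    obs_slices,
    is_leaf=lambda x: isinstance(x, dict),
)
print([a.shape for a in sliced])  # [(4, 1), (4, 2)]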
@@ -416,7 +467,7 @@ class LossPDEStatio(
     ----------
     u : AbstractPINN
         the PINN
-    dynamic_loss : PDEStatio | None
+    dynamic_loss : tuple[PDEStatio, ...] | PDEStatio | None
        the stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](x)$. Should implement a method
        `dynamic_loss.evaluate(x, u, params)`.
@@ -479,7 +530,7 @@ class LossPDEStatio(
         Alternatively, the user can pass a float or an integer.
         These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
         $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
 
@@ -494,10 +545,9 @@ class LossPDEStatio(
     # (ie. jax.Array cannot be static) and that we do not expect to change
 
     u: AbstractPINN
-    dynamic_loss: PDEStatio | None
+    dynamic_loss: tuple[PDEStatio | None, ...]
     loss_weights: LossWeightsPDEStatio
     derivative_keys: DerivativeKeysPDEStatio
-    vmap_in_axes: tuple[int] = eqx.field(static=True)
 
     params: InitVar[Params[Array] | None]
 
@@ -505,7 +555,7 @@ class LossPDEStatio(
         self,
         *,
         u: AbstractPINN,
-        dynamic_loss: PDEStatio | None,
+        dynamic_loss: tuple[PDEStatio, ...] | PDEStatio | None,
         loss_weights: LossWeightsPDEStatio | None = None,
         derivative_keys: DerivativeKeysPDEStatio | None = None,
         params: Params[Array] | None = None,
@@ -516,25 +566,26 @@ class LossPDEStatio(
             self.loss_weights = LossWeightsPDEStatio()
         else:
             self.loss_weights = loss_weights
-        self.dynamic_loss = dynamic_loss
-
-        super().__init__(
-            **kwargs,
-        )
 
         if derivative_keys is None:
             # be default we only take gradient wrt nn_params
             try:
-                self.derivative_keys = DerivativeKeysPDEStatio(params=params)
+                derivative_keys = DerivativeKeysPDEStatio(params=params)
             except ValueError as exc:
                 raise ValueError(
                     "Problem at derivative_keys initialization "
                     f"received {derivative_keys=} and {params=}"
                 ) from exc
-        else:
-            self.derivative_keys = derivative_keys
 
-        self.vmap_in_axes = (0,)
+        super().__init__(
+            derivative_keys=derivative_keys,
+            vmap_in_axes=(0,),
+            **kwargs,
+        )
+        if not isinstance(dynamic_loss, tuple):
+            self.dynamic_loss = (dynamic_loss,)
+        else:
+            self.dynamic_loss = dynamic_loss
 
     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
@@ -549,11 +600,12 @@ class LossPDEStatio(
         # we could have used typing.cast though
 
     def evaluate_by_terms(
-        self, params: Params[Array], batch: PDEStatioBatch
-    ) -> tuple[
-        PDEStatioComponents[Float[Array, ""] | None],
-        PDEStatioComponents[Float[Array, ""] | None],
-    ]:
+        self,
+        opt_params: Params[Array],
+        batch: PDEStatioBatch,
+        *,
+        non_opt_params: Params[Array] | None = None,
+    ) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
         """
         Evaluate the loss function at a batch of points for given parameters.
 
@@ -561,15 +613,22 @@ class LossPDEStatio(
 
         Parameters
         ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
         batch
             Composed of a batch of points in the
             domain, a batch of points in the domain
             border and an optional additional batch of parameters (eg. for
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are non optimized, at which the loss is evaluated
         """
+        if non_opt_params is not None:
+            params = eqx.combine(opt_params, non_opt_params)
+        else:
+            params = opt_params
+
         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
         # and update vmap_in_axes
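The new `evaluate_by_terms` signature splits the parameters into `opt_params` and an optional `non_opt_params`, recombined with `eqx.combine` before evaluation, which allows differentiating the loss with respect to only part of the parameter pytree. A standalone sketch of that partition/combine round trip on a plain dict pytree (not the jinns `Params` container):

import equinox as eqx
import jax.numpy as jnp

params = {"nn": jnp.ones(3), "eq": {"nu": jnp.array(0.1)}}

# Split the pytree: optimise the network weights, freeze the equation parameter.
filter_spec = {"nn": True, "eq": {"nu": False}}
opt_params, non_opt_params = eqx.partition(params, filter_spec)

# The two halves hold None at complementary positions; eqx.combine reassembles
# the full pytree, which is what evaluate_by_terms does internally.
recombined = eqx.combine(opt_params, non_opt_params)
assert jnp.allclose(recombined["eq"]["nu"], params["eq"]["nu"])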
@@ -594,29 +653,44 @@ class LossPDEStatio(
         )
 
         # get the unweighted mses for each loss term as well as the gradients
-        all_funs: PDEStatioComponents[Callable[[Params[Array]], Array] | None] = (
+        all_funs: PDEStatioComponents[
+            tuple[Callable[[Params[Array]], Array], ...] | None
+        ] = PDEStatioComponents(
+            dyn_loss_fun,
+            norm_loss_fun,
+            boundary_loss_fun,
+            obs_loss_fun,
+        )
+        all_params: PDEStatioComponents[tuple[Params[Array], ...] | None] = (
             PDEStatioComponents(
-                dyn_loss_fun, norm_loss_fun, boundary_loss_fun, obs_loss_fun
+                jax.tree.map(lambda l: params, dyn_loss_fun),
+                jax.tree.map(lambda l: params, norm_loss_fun),
+                jax.tree.map(lambda l: params, boundary_loss_fun),
+                params_obs,
             )
         )
-        all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
-            params, params, params, params_obs
-        )
         mses_grads = jax.tree.map(
             self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )
+        # NOTE the is_leaf below is more complex since it must pass possible the tuple
+        # of dyn_loss and then stops (but also account it should not stop when
+        # the tuple of dyn_loss is of length 2)
         mses = jax.tree.map(
-            lambda leaf: leaf[0],  # type: ignore
+            lambda leaf: leaf[0],
             mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1],  # type: ignore
+            lambda leaf: leaf[1],
             mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
 
         return mses, grads
@@ -648,7 +722,7 @@ class LossPDENonStatio(
     ----------
     u : AbstractPINN
         the PINN
-    dynamic_loss : PDENonStatio
+    dynamic_loss : tuple[PDENonStatio, ...] | PDENonStatio | None
        the non stationary PDE dynamic part of the loss, basically the differential
        operator $\mathcal{N}[u](t, x)$. Should implement a method
        `dynamic_loss.evaluate(t, x, u, params)`.
@@ -725,14 +799,13 @@ class LossPDENonStatio(
         Alternatively, the user can pass a float or an integer.
         These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
         $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
-    obs_slice : slice, default=None
+    obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
-
     """
 
     u: AbstractPINN
-    dynamic_loss: PDENonStatio | None
+    dynamic_loss: tuple[PDENonStatio | None, ...]
     loss_weights: LossWeightsPDENonStatio
     derivative_keys: DerivativeKeysPDENonStatio
     params: InitVar[Params[Array] | None]
@@ -740,7 +813,6 @@ class LossPDENonStatio(
     initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
         eqx.field(static=True)
     )
-    vmap_in_axes: tuple[int] = eqx.field(static=True)
     max_norm_samples_omega: int = eqx.field(static=True)
     max_norm_time_slices: int = eqx.field(static=True)
 
@@ -750,7 +822,7 @@ class LossPDENonStatio(
         self,
         *,
         u: AbstractPINN,
-        dynamic_loss: PDENonStatio | None,
+        dynamic_loss: tuple[PDENonStatio, ...] | PDENonStatio | None,
         loss_weights: LossWeightsPDENonStatio | None = None,
         derivative_keys: DerivativeKeysPDENonStatio | None = None,
         initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
@@ -766,25 +838,27 @@ class LossPDENonStatio(
             self.loss_weights = LossWeightsPDENonStatio()
         else:
             self.loss_weights = loss_weights
-        self.dynamic_loss = dynamic_loss
-
-        super().__init__(
-            **kwargs,
-        )
 
         if derivative_keys is None:
             # be default we only take gradient wrt nn_params
             try:
-                self.derivative_keys = DerivativeKeysPDENonStatio(params=params)
+                derivative_keys = DerivativeKeysPDENonStatio(params=params)
             except ValueError as exc:
                 raise ValueError(
                     "Problem at derivative_keys initialization "
                     f"received {derivative_keys=} and {params=}"
                 ) from exc
-        else:
-            self.derivative_keys = derivative_keys
 
-        self.vmap_in_axes = (0,)  # for t_x
+        super().__init__(
+            derivative_keys=derivative_keys,
+            vmap_in_axes=(0,),  # for t_x
+            **kwargs,
+        )
+
+        if not isinstance(dynamic_loss, tuple):
+            self.dynamic_loss = (dynamic_loss,)
+        else:
+            self.dynamic_loss = dynamic_loss
 
         if initial_condition_fun is None:
             warnings.warn(
@@ -820,7 +894,11 @@ class LossPDENonStatio(
         )
 
     def evaluate_by_terms(
-        self, params: Params[Array], batch: PDENonStatioBatch
+        self,
+        opt_params: Params[Array],
+        batch: PDENonStatioBatch,
+        *,
+        non_opt_params: Params[Array] | None = None,
     ) -> tuple[
         PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
     ]:
@@ -831,15 +909,22 @@ class LossPDENonStatio(
 
         Parameters
         ---------
-        params
-            Parameters at which the loss is evaluated
+        opt_params
+            Parameters, which are optimized, at which the loss is evaluated
         batch
             Composed of a batch of points in the
             domain, a batch of points in the domain
             border and an optional additional batch of parameters (eg. for
             metamodeling) and an optional additional batch of observed
             inputs/outputs/parameters
+        non_opt_params
+            Parameters, which are non optimized, at which the loss is evaluated
         """
+        if non_opt_params is not None:
+            params = eqx.combine(opt_params, non_opt_params)
+        else:
+            params = opt_params
+
         omega_initial_batch = batch.initial_batch
         assert omega_initial_batch is not None
 
@@ -868,7 +953,13 @@ class LossPDENonStatio(
 
         # initial condition
         if self.initial_condition_fun is not None:
-            mse_initial_condition_fun: Callable[[Params[Array]], Array] | None = (
+            # Note that since MR 92
+            # initial_condition_fun is formed as a tuple for
+            # consistency with dynamic and observation
+            # losses and more modularity for later
+            initial_condition_fun: (
+                tuple[Callable[[Params[Array]], Array], ...] | None
+            ) = (
                 lambda p: initial_condition_apply(
                     self.u,
                     omega_initial_batch,
@@ -876,39 +967,52 @@ class LossPDENonStatio(
                     (0,) + vmap_in_axes_params,
                     self.initial_condition_fun,  # type: ignore
                     self.t0,
-                )
+                ),
             )
         else:
-            mse_initial_condition_fun = None
+            initial_condition_fun = None
 
         # get the unweighted mses for each loss term as well as the gradients
-        all_funs: PDENonStatioComponents[Callable[[Params[Array]], Array] | None] = (
+        all_funs: PDENonStatioComponents[
+            tuple[Callable[[Params[Array]], Array], ...] | None
+        ] = PDENonStatioComponents(
+            dyn_loss_fun,
+            norm_loss_fun,
+            boundary_loss_fun,
+            obs_loss_fun,
+            initial_condition_fun,
+        )
+        all_params: PDENonStatioComponents[tuple[Params[Array], ...] | None] = (
             PDENonStatioComponents(
-                dyn_loss_fun,
-                norm_loss_fun,
-                boundary_loss_fun,
-                obs_loss_fun,
-                mse_initial_condition_fun,
+                jax.tree.map(lambda l: params, dyn_loss_fun),
+                jax.tree.map(lambda l: params, norm_loss_fun),
+                jax.tree.map(lambda l: params, boundary_loss_fun),
+                params_obs,
+                jax.tree.map(lambda l: params, initial_condition_fun),
             )
         )
-        all_params: PDENonStatioComponents[Params[Array] | None] = (
-            PDENonStatioComponents(params, params, params, params_obs, params)
-        )
         mses_grads = jax.tree.map(
             self.get_gradients,
             all_funs,
             all_params,
             is_leaf=lambda x: x is None,
         )
+        # NOTE the is_leaf below is more complex since it must pass possible the tuple
+        # of dyn_loss and then stops (but also account it should not stop when
+        # the tuple of dyn_loss is of length 2)
         mses = jax.tree.map(
-            lambda leaf: leaf[0],  # type: ignore
+            lambda leaf: leaf[0],
            mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
         )
         grads = jax.tree.map(
-            lambda leaf: leaf[1],  # type: ignore
+            lambda leaf: leaf[1],
            mses_grads,
-            is_leaf=lambda x: isinstance(x, tuple),
+            is_leaf=lambda x: isinstance(x, tuple)
+            and len(x) == 2
+            and isinstance(x[1], Params),
        )
 
         return mses, grads
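After `get_gradients`, each non-`None` component of `mses_grads` holds either a single `(mse, grads)` pair or, for `dyn_loss`, a tuple of such pairs, so a plain `isinstance(x, tuple)` test would stop too early on the outer tuple; the added `len(x) == 2 and isinstance(x[1], Params)` check singles out the actual pairs. A toy illustration of that predicate with a stand-in for `Params` (the names below are hypothetical):

import jax


class FakeParams:
    """Stand-in for the jinns Params container in this illustration."""


pair_a = (0.5, FakeParams())  # (mse, grads) for a single-term component
pair_b = ((0.1, FakeParams()), (0.2, FakeParams()))  # one pair per dynamic loss

tree = {"dyn_loss": pair_b, "norm_loss": pair_a, "observations": None}


def is_pair(x):
    # A real (mse, grads) pair: a 2-tuple whose second element is the params-like object
    return isinstance(x, tuple) and len(x) == 2 and isinstance(x[1], FakeParams)


mses = jax.tree.map(lambda leaf: leaf[0], tree, is_leaf=is_pair)
print(mses)  # {'dyn_loss': (0.1, 0.2), 'norm_loss': 0.5, 'observations': None}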