CUQIpy 1.3.0.post0.dev298__py3-none-any.whl → 1.4.0.post0.dev61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/_distribution.py +24 -15
  5. cuqi/distribution/_joint_distribution.py +96 -11
  6. cuqi/distribution/_posterior.py +9 -0
  7. cuqi/experimental/__init__.py +1 -2
  8. cuqi/experimental/_recommender.py +4 -4
  9. cuqi/implicitprior/__init__.py +1 -1
  10. cuqi/implicitprior/_restorator.py +35 -1
  11. cuqi/legacy/__init__.py +2 -0
  12. cuqi/legacy/sampler/__init__.py +11 -0
  13. cuqi/legacy/sampler/_conjugate.py +55 -0
  14. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  15. cuqi/legacy/sampler/_cwmh.py +196 -0
  16. cuqi/legacy/sampler/_gibbs.py +231 -0
  17. cuqi/legacy/sampler/_hmc.py +335 -0
  18. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  19. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  20. cuqi/legacy/sampler/_mh.py +190 -0
  21. cuqi/legacy/sampler/_pcn.py +244 -0
  22. cuqi/legacy/sampler/_rto.py +284 -0
  23. cuqi/legacy/sampler/_sampler.py +182 -0
  24. cuqi/likelihood/_likelihood.py +1 -1
  25. cuqi/model/_model.py +212 -77
  26. cuqi/pde/__init__.py +4 -0
  27. cuqi/pde/_observation_map.py +36 -0
  28. cuqi/pde/_pde.py +52 -21
  29. cuqi/problem/_problem.py +87 -80
  30. cuqi/sampler/__init__.py +120 -8
  31. cuqi/sampler/_conjugate.py +376 -35
  32. cuqi/sampler/_conjugate_approx.py +40 -16
  33. cuqi/sampler/_cwmh.py +132 -138
  34. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  35. cuqi/sampler/_gibbs.py +269 -130
  36. cuqi/sampler/_hmc.py +328 -201
  37. cuqi/sampler/_langevin_algorithm.py +282 -98
  38. cuqi/sampler/_laplace_approximation.py +87 -117
  39. cuqi/sampler/_mh.py +47 -157
  40. cuqi/sampler/_pcn.py +56 -211
  41. cuqi/sampler/_rto.py +206 -140
  42. cuqi/sampler/_sampler.py +540 -135
  43. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/METADATA +1 -1
  44. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/RECORD +47 -45
  45. cuqi/experimental/mcmc/__init__.py +0 -122
  46. cuqi/experimental/mcmc/_conjugate.py +0 -396
  47. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  48. cuqi/experimental/mcmc/_cwmh.py +0 -190
  49. cuqi/experimental/mcmc/_gibbs.py +0 -374
  50. cuqi/experimental/mcmc/_hmc.py +0 -460
  51. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -382
  52. cuqi/experimental/mcmc/_laplace_approximation.py +0 -154
  53. cuqi/experimental/mcmc/_mh.py +0 -80
  54. cuqi/experimental/mcmc/_pcn.py +0 -89
  55. cuqi/experimental/mcmc/_rto.py +0 -306
  56. cuqi/experimental/mcmc/_sampler.py +0 -564
  57. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/WHEEL +0 -0
  58. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/licenses/LICENSE +0 -0
  59. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/top_level.txt +0 -0
cuqi/model/_model.py CHANGED
@@ -132,6 +132,10 @@ class Model(object):
132
132
  print(model(1, 1))
133
133
  print(model.gradient(np.array([1]), 1, 1))
134
134
  """
135
+
136
+ _supports_partial_eval = True
137
+ """Flag indicating that partial evaluation of Model objects is supported, i.e., calling the model object with only some of the inputs specified returns a model that can be called with the remaining inputs."""
138
+
135
139
  def __init__(self, forward, range_geometry, domain_geometry, gradient=None, jacobian=None):
136
140
 
137
141
  # Check if input is callable
@@ -311,7 +315,12 @@ class Model(object):
311
315
  "Gradient needs to be callable function or tuple of callable functions."
312
316
  )
313
317
 
314
- expected_func_non_default_args = self._non_default_args
318
+ expected_func_non_default_args = (
319
+ self._non_default_args
320
+ if not hasattr(self, "_original_non_default_args")
321
+ else self._original_non_default_args
322
+ )
323
+
315
324
  if func_type.lower() == "gradient":
316
325
  # prepend 'direction' to the expected gradient non default args
317
326
  expected_func_non_default_args = [
@@ -613,52 +622,43 @@ class Model(object):
613
622
  if non_default_args is None:
614
623
  non_default_args = self._non_default_args
615
624
 
616
- # If any args are given, add them to kwargs
625
+ # Either args or kwargs can be provided but not both
626
+ if len(args) > 0 and len(kwargs) > 0:
627
+ raise ValueError(
628
+ "The "
629
+ + map_name.lower()
630
+ + " input is specified both as positional and keyword arguments. This is not supported."
631
+ )
632
+
633
+ len_input = len(args) + len(kwargs)
634
+
635
+ # If partial evaluation, make sure input is not of type Samples
636
+ if len_input < len(non_default_args):
637
+ # If the argument is a Sample object, splitting or partial
638
+ # evaluation of the model is not supported
639
+ temp_args = args if len(args) > 0 else list(kwargs.values())
640
+ if any(isinstance(arg, Samples) for arg in temp_args):
641
+ raise ValueError(("When using Samples objects as input, the"
642
+ +" user should provide a Samples object for"
643
+ +f" each non_default_args {non_default_args}"
644
+ +" of the model. That is, partial evaluation"
645
+ +" or splitting is not supported for input"
646
+ +" of type Samples."))
647
+
648
+ # If args are given, add them to kwargs
617
649
  if len(args) > 0:
618
- if len(kwargs) > 0:
619
- raise ValueError(
620
- "The "
621
- + map_name.lower()
622
- + " input is specified both as positional and keyword arguments. This is not supported."
623
- )
624
650
 
625
- appending_error_message = ""
626
651
  # Check if the input is for multiple input case and is stacked,
627
652
  # then split it
628
- if len(args)==1 and len(non_default_args)>1:
629
- # If the argument is a Sample object, splitting is not supported
630
- if isinstance(args[0], Samples):
631
- raise ValueError(
632
- "The "
633
- + map_name.lower()
634
- + f" input is specified by a Samples object that cannot be split into multiple arguments corresponding to the non_default_args {non_default_args}."
635
- )
636
- split_succeeded, split_args = self._is_stacked_args(*args, is_par=is_par)
637
- if split_succeeded:
638
- args = split_args
639
- else:
640
- appending_error_message = (
641
- " Additionally, the "
642
- + map_name.lower()
643
- + f" input is specified by a single argument that cannot be split into multiple arguments matching the expected non_default_args {non_default_args}."
644
- )
645
-
646
- # Check if the number of args does not match the number of
647
- # non_default_args of the model
648
- if len(args) != len(non_default_args):
649
- raise ValueError(
650
- "The number of positional arguments does not match the number of non-default arguments of the "
651
- + map_name.lower()
652
- + "."
653
- + appending_error_message
654
- )
653
+ if len(args) < len(non_default_args):
654
+ args = self._split_in_case_of_stacked_args(*args, is_par=is_par)
655
655
 
656
656
  # Add args to kwargs following the order of non_default_args
657
657
  for idx, arg in enumerate(args):
658
658
  kwargs[non_default_args[idx]] = arg
659
-
659
+
660
660
  # Check kwargs matches non_default_args
661
- if set(list(kwargs.keys())) != set(non_default_args):
661
+ if not (set(list(kwargs.keys())) <= set(non_default_args)):
662
662
  if map_name == "gradient":
663
663
  error_msg = f"The gradient input is specified by a direction and keywords arguments {list(kwargs.keys())} that does not match the non_default_args of the model {non_default_args}."
664
664
  else:
@@ -673,53 +673,41 @@ class Model(object):
673
673
  raise ValueError(error_msg)
674
674
 
675
675
  # Make sure order of kwargs is the same as non_default_args
676
- kwargs = {k: kwargs[k] for k in non_default_args}
676
+ kwargs = {k: kwargs[k] for k in non_default_args if k in kwargs}
677
677
 
678
678
  return kwargs
679
679
 
680
- def _is_stacked_args(self, *args, is_par=True):
681
- """Private function that checks if the input arguments are stacked
682
- and splits them if they are."""
683
- # Length of args should be 1 if the input is stacked (no partial
684
- # stacking is supported)
685
- if len(args) > 1:
686
- return False, args
687
-
688
- # Type of args should be parameter
689
- if not is_par:
690
- return False, args
680
+ def _split_in_case_of_stacked_args(self, *args, is_par=True):
681
+ """Private function that checks if the input args is a stacked
682
+ CUQIarray or numpy array and splits it into multiple arguments based on
683
+ the domain geometry of the model. Otherwise, it returns the input args
684
+ unchanged."""
691
685
 
692
- # args[0] should be numpy array or CUQIarray
686
+ # Check conditions for splitting and split if all conditions are met
693
687
  is_CUQIarray = isinstance(args[0], CUQIarray)
694
688
  is_numpy_array = isinstance(args[0], np.ndarray)
695
- if not is_CUQIarray and not is_numpy_array:
696
- return False, args
697
689
 
698
- # Shape of args[0] should be (domain_dim,)
699
- if not args[0].shape == (self.domain_dim,):
700
- return False, args
701
-
702
- # Ensure domain geometry is _ProductGeometry
703
- if not isinstance(
704
- self.domain_geometry, cuqi.experimental.geometry._ProductGeometry
705
- ):
706
- return False, args
707
-
708
- # Split the stacked input
709
- split_args = np.split(args[0], self.domain_geometry.stacked_par_split_indices)
710
-
711
- # Covert split args to CUQIarray if input is CUQIarray
712
- if is_CUQIarray:
713
- split_args = [
714
- CUQIarray(arg, is_par=True, geometry=self.domain_geometry.geometries[i])
715
- for i, arg in enumerate(split_args)
716
- ]
690
+ if ((is_CUQIarray or is_numpy_array) and
691
+ is_par and
692
+ len(args) == 1 and
693
+ args[0].shape == (self.domain_dim,) and
694
+ isinstance(self.domain_geometry, cuqi.experimental.geometry._ProductGeometry)):
695
+ # Split the stacked input
696
+ split_args = np.split(args[0], self.domain_geometry.stacked_par_split_indices)
697
+ # Convert split args to CUQIarray if input is CUQIarray
698
+ if is_CUQIarray:
699
+ split_args = [
700
+ CUQIarray(arg, is_par=True, geometry=self.domain_geometry.geometries[i])
701
+ for i, arg in enumerate(split_args)
702
+ ]
703
+ return split_args
717
704
 
718
- return True, split_args
705
+ else:
706
+ return args
719
707
 
720
708
  def forward(self, *args, is_par=True, **kwargs):
721
709
  """ Forward function of the model.
722
-
710
+
723
711
  Forward converts the input to function values (if needed) using the domain geometry of the model. Then it applies the forward operator to the function values and converts the output to parameters using the range geometry of the model.
724
712
 
725
713
  Parameters
@@ -733,7 +721,7 @@ class Model(object):
733
721
  If True, the inputs in `args` or `kwargs` are assumed to be parameters.
734
722
  If False, the inputs in `args` or `kwargs` are assumed to be function values.
735
723
  If `is_par` is a tuple of bools, the inputs are assumed to be parameters or function values based on the corresponding boolean value in the tuple.
736
-
724
+
737
725
  **kwargs : keyword arguments
738
726
  keyword arguments for the forward operator. The forward operator input can be specified as either positional arguments or keyword arguments but not both.
739
727
 
@@ -750,19 +738,31 @@ class Model(object):
750
738
  kwargs = self._parse_args_add_to_kwargs(
751
739
  *args, **kwargs, is_par=is_par, map_name="model"
752
740
  )
753
-
754
- # extract args from kwargs
741
+ # Extract args from kwargs
755
742
  args = list(kwargs.values())
756
743
 
744
+ if len(kwargs) == 0:
745
+ return self
746
+
747
+ partial_arguments = len(kwargs) < len(self._non_default_args)
748
+
757
749
  # If input is a distribution, we simply change the parameter name of
758
750
  # model to match the distribution name
759
751
  if all(isinstance(x, cuqi.distribution.Distribution)
760
752
  for x in kwargs.values()):
753
+ if partial_arguments:
754
+ raise ValueError(
755
+ "Partial evaluation of the model is not supported for distributions."
756
+ )
761
757
  return self._handle_case_when_model_input_is_distributions(kwargs)
762
758
 
763
759
  # If input is a random variable, we handle it separately
764
760
  elif all(isinstance(x, cuqi.experimental.algebra.RandomVariable)
765
761
  for x in kwargs.values()):
762
+ if partial_arguments:
763
+ raise ValueError(
764
+ "Partial evaluation of the model is not supported for random variables."
765
+ )
766
766
  return self._handle_case_when_model_input_is_random_variables(kwargs)
767
767
 
768
768
  # If input is a Node from internal abstract syntax tree, we let the Node handle the operation
@@ -772,6 +772,21 @@ class Model(object):
772
772
  elif any(isinstance(args_i, cuqi.experimental.algebra.Node) for args_i in args):
773
773
  return NotImplemented
774
774
 
775
+ # if input is partial, we create a new model with the partial input
776
+ if partial_arguments:
777
+ # Create is_par_partial from the is_par to contain only the relevant parts
778
+ if isinstance(is_par, (list, tuple)):
779
+ is_par_partial = tuple(
780
+ is_par[i]
781
+ for i in range(self.number_of_inputs)
782
+ if self._non_default_args[i] in kwargs.keys()
783
+ )
784
+ else:
785
+ is_par_partial = is_par
786
+ # Build a partial model with the given kwargs
787
+ partial_model = self._build_partial_model(kwargs, is_par_partial)
788
+ return partial_model
789
+
775
790
  # Else we apply the forward operator
776
791
  # if model has _original_non_default_args, we use it to replace the
777
792
  # kwargs keys so that it matches self._forward_func signature
@@ -797,6 +812,126 @@ class Model(object):
797
812
  else:
798
813
  return False
799
814
 
815
+ def _build_partial_model(self, kwargs, is_par):
816
+ """Private function that builds a partial model substituting the given
817
+ keyword arguments with their values. The created partial model will have
818
+ as inputs the non-default arguments that are not in the kwargs."""
819
+
820
+ # Extract args from kwargs
821
+ args = list(kwargs.values())
822
+
823
+ # Define original_non_default_args which represents the complete list of
824
+ # non-default arguments of the forward function.
825
+ original_non_default_args = (
826
+ self._original_non_default_args
827
+ if hasattr(self, "_original_non_default_args")
828
+ else self._non_default_args
829
+ )
830
+
831
+ if hasattr(self, "_original_non_default_args"):
832
+ # Split the _original_non_default_args into two lists:
833
+ # 1. reduced_original_non_default_args: the _original_non_default_args
834
+ # corresponding to the _non_default_args that are not in kwargs
835
+ # 2. substituted_non_default_args: the _original_non_default_args
836
+ # corresponding to the _non_default_args that are in kwargs
837
+ reduced_original_non_default_args = [
838
+ original_non_default_args[i]
839
+ for i in range(self.number_of_inputs)
840
+ if self._non_default_args[i] not in kwargs.keys()
841
+ ]
842
+ substituted_non_default_args = [
843
+ original_non_default_args[i]
844
+ for i in range(self.number_of_inputs)
845
+ if self._non_default_args[i] in kwargs.keys()
846
+ ]
847
+ # Replace the keys in kwargs with the substituted_non_default_args
848
+ # so that the kwargs match the signature of the _forward_func
849
+ kwargs = {k: v for k, v in zip(substituted_non_default_args, args)}
850
+
851
+ # Create a partial domain geometry with the geometries corresponding
852
+ # to the non-default arguments that are not in kwargs (remaining
853
+ # unspecified inputs)
854
+ partial_domain_geometry = cuqi.experimental.geometry._ProductGeometry(
855
+ *[
856
+ self.domain_geometry.geometries[i]
857
+ for i in range(self.number_of_inputs)
858
+ if original_non_default_args[i] not in kwargs.keys()
859
+ ]
860
+ )
861
+
862
+ if len(partial_domain_geometry.geometries) == 1:
863
+ partial_domain_geometry = partial_domain_geometry.geometries[0]
864
+
865
+ # Create a domain geometry with the geometries corresponding to the
866
+ # non-default arguments that are specified
867
+ substituted_domain_geometry = cuqi.experimental.geometry._ProductGeometry(
868
+ *[
869
+ self.domain_geometry.geometries[i]
870
+ for i in range(self.number_of_inputs)
871
+ if original_non_default_args[i] in kwargs.keys()
872
+ ]
873
+ )
874
+
875
+ if len(substituted_domain_geometry.geometries) == 1:
876
+ substituted_domain_geometry = substituted_domain_geometry.geometries[0]
877
+
878
+ # Create new model with partial input
879
+ # First, we convert the input to function values
880
+ kwargs = self._2fun(geometry=substituted_domain_geometry, is_par=is_par, **kwargs)
881
+
882
+ # Second, we create a partial function for the forward operator
883
+ partial_forward = partial(self._forward_func, **kwargs)
884
+
885
+ # Third, if applicable, we create a partial function for the gradient
886
+ if isinstance(self._gradient_func, tuple):
887
+ # If gradient is a tuple, we create a partial function for each
888
+ # gradient function in the tuple
889
+ partial_gradient = tuple(
890
+ (
891
+ partial(self._gradient_func[i], **kwargs)
892
+ if self._gradient_func[i] is not None
893
+ else None
894
+ )
895
+ for i in range(self.number_of_inputs)
896
+ if original_non_default_args[i] not in kwargs.keys()
897
+ )
898
+ if len(partial_gradient) == 1:
899
+ partial_gradient = partial_gradient[0]
900
+
901
+ elif callable(self._gradient_func):
902
+ raise NotImplementedError(
903
+ "Partial forward model is only supported for gradient/jacobian functions that are tuples of callable functions."
904
+ )
905
+
906
+ else:
907
+ partial_gradient = None
908
+
909
+ # Lastly, we create the partial model with the partial forward
910
+ # operator (we set the gradient function later)
911
+ partial_model = Model(
912
+ forward=partial_forward,
913
+ range_geometry=self.range_geometry,
914
+ domain_geometry=partial_domain_geometry,
915
+ )
916
+
917
+ # Set the _original_non_default_args (if applicable) and
918
+ # _stored_non_default_args of the partial model
919
+ if hasattr(self, "_original_non_default_args"):
920
+ partial_model._original_non_default_args = reduced_original_non_default_args
921
+ partial_model._stored_non_default_args = [
922
+ self._non_default_args[i]
923
+ for i in range(self.number_of_inputs)
924
+ if original_non_default_args[i] not in kwargs.keys()
925
+ ]
926
+
927
+ # Set the gradient function of the partial model
928
+ partial_model._check_correct_gradient_jacobian_form(
929
+ partial_gradient, "gradient"
930
+ )
931
+ partial_model._gradient_func = partial_gradient
932
+
933
+ return partial_model
934
+
800
935
  def _handle_case_when_model_input_is_distributions(self, kwargs):
801
936
  """Private function that handles the case of the input being a
802
937
  distribution or multiple distributions."""
cuqi/pde/__init__.py CHANGED
@@ -4,3 +4,7 @@ from ._pde import (
4
4
  SteadyStateLinearPDE,
5
5
  TimeDependentLinearPDE
6
6
  )
7
+
8
+ from ._observation_map import (
9
+ FD_spatial_gradient
10
+ )
@@ -0,0 +1,36 @@
1
+ import scipy
2
+ import numpy as np
3
+ """
4
+ This module contains observation map examples for PDE problems. The map can
5
+ be passed to the `PDE` object initializer via the `observation_map` argument.
6
+
7
+ For example on how to use set observation maps in time dependent PDEs, see
8
+ `demos/howtos/TimeDependentLinearPDE.py`.
9
+ """
10
+
11
+ # 1. Steady State Observation Maps
12
+ # --------------------------------
13
+
14
+ # 2. Time-Dependent Observation Maps
15
+ # -----------------------------------
16
+ def FD_spatial_gradient(sol, grid, times):
17
+ """Time dependent observation map that computes the finite difference (FD) spatial gradient of a solution given at grid points (grid) and times (times). This map is supported for 1D spatial domains only.
18
+
19
+ Parameters
20
+ ----------
21
+ sol : np.ndarray
22
+ The solution array of shape (number of grid points, number of time steps).
23
+
24
+ grid : np.ndarray
25
+ The spatial grid points of shape (number of grid points,).
26
+
27
+ times : np.ndarray
28
+ The discretized time steps of shape (number of time steps,)."""
29
+
30
+ if len(grid.shape) != 1:
31
+ raise ValueError("FD_spatial_gradient only supports 1D spatial domains.")
32
+ observed_quantity = np.zeros((len(grid)-1, len(times)))
33
+ for i in range(observed_quantity.shape[0]):
34
+ observed_quantity[i, :] = ((sol[i, :] - sol[i+1, :])/
35
+ (grid[i] - grid[i+1]))
36
+ return observed_quantity
cuqi/pde/_pde.py CHANGED
@@ -15,14 +15,15 @@ class PDE(ABC):
15
15
  PDE_form : callable function
16
16
  Callable function which returns a tuple of the needed PDE components (expected components are explained in the subclasses)
17
17
 
18
- observation_map: a function handle
19
- A function that takes the PDE solution as input and the returns the observed solution. e.g. `observation_map=lambda u: u**2` or `observation_map=lambda u: u[0]`
20
-
21
18
  grid_sol: np.ndarray
22
19
  The grid on which solution is defined
23
20
 
24
21
  grid_obs: np.ndarray
25
- The grid on which the observed solution should be interpolated (currently only supported for 1D problems).
22
+ The grid on which the observed solution should be interpolated (currently only supported for 1D problems).
23
+
24
+ observation_map: a function handle
25
+ A function that takes the PDE solution, interpolated on `grid_obs`, as input and returns the observed solution. e.g., `observation_map=lambda u, grid_obs: u**2`.
26
+
26
27
  """
27
28
 
28
29
  def __init__(self, PDE_form, grid_sol=None, grid_obs=None, observation_map=None):
@@ -187,6 +188,10 @@ class LinearPDE(PDE):
187
188
  info = None
188
189
 
189
190
  return solution, info
191
+
192
+ def interpolate_on_observed_domain(self, solution):
193
+ """Interpolate solution on observed space domain."""
194
+ raise NotImplementedError("interpolate_on_observed_domain method is not implemented for LinearPDE base class.")
190
195
 
191
196
  class SteadyStateLinearPDE(LinearPDE):
192
197
  """Linear steady state PDE.
@@ -194,7 +199,10 @@ class SteadyStateLinearPDE(LinearPDE):
194
199
  Parameters
195
200
  -----------
196
201
  PDE_form : callable function
197
- Callable function with signature `PDE_form(parameter1, parameter2, ...)` where `parameter1`, `parameter2`, etc. are the Bayesian unknown parameters (the user can choose any names for these parameters, e.g. `a`, `b`, etc.). The function returns a tuple with the discretized differential operator A and right-hand-side b. The types of A and b are determined by what the method :meth:`linalg_solve` accepts as first and second parameters, respectively.
202
+ Callable function with signature `PDE_form(parameter1, parameter2, ...)` where `parameter1`, `parameter2`, etc. are the Bayesian unknown parameters (the user can choose any names for these parameters, e.g. `a`, `b`, etc.). The function returns a tuple with the discretized differential operator A and right-hand-side b. The types of A and b are determined by what the method :meth:`linalg_solve` accepts as first and second parameters, respectively.
203
+
204
+ observation_map: a function handle
205
+ A function that takes the PDE solution, interpolated on `grid_obs`, as input and returns the observed solution. e.g. `observation_map=lambda u, grid_obs: u**2`.
198
206
 
199
207
  kwargs:
200
208
  See :class:`~cuqi.pde.LinearPDE` for the remaining keyword arguments.
@@ -204,8 +212,8 @@ class SteadyStateLinearPDE(LinearPDE):
204
212
  See demo demos/demo24_fwd_poisson.py for an illustration on how to use SteadyStateLinearPDE with varying solver choices. And demos demos/demo25_fwd_poisson_2D.py and demos/demo26_fwd_poisson_mixedBC.py for examples with mixed (Dirichlet and Neumann) boundary conditions problems. demos/demo25_fwd_poisson_2D.py also illustrates how to observe on a specific boundary, for example.
205
213
  """
206
214
 
207
- def __init__(self, PDE_form, **kwargs):
208
- super().__init__(PDE_form, **kwargs)
215
+ def __init__(self, PDE_form, observation_map=None, **kwargs):
216
+ super().__init__(PDE_form, observation_map=observation_map, **kwargs)
209
217
 
210
218
  def assemble(self, *args, **kwargs):
211
219
  """Assembles differential operator and rhs according to PDE_form"""
@@ -221,17 +229,25 @@ class SteadyStateLinearPDE(LinearPDE):
221
229
 
222
230
  return self._solve_linear_system(self.diff_op, self.rhs, self._linalg_solve, self._linalg_solve_kwargs)
223
231
 
224
-
225
- def observe(self, solution):
226
-
232
+ def interpolate_on_observed_domain(self, solution):
233
+ """Interpolate solution on observed space grid."""
227
234
  if self.grids_equal:
228
235
  solution_obs = solution
229
236
  else:
230
237
  solution_obs = interp1d(self.grid_sol, solution, kind='quadratic')(self.grid_obs)
238
+ return solution_obs
239
+
240
+ def observe(self, solution):
241
+ """Apply observation operator to the solution. This includes
242
+ interpolation to observation points (if different from the
243
+ solution grid) then applying the observation map (if provided)."""
244
+
245
+ # Interpolate solution on observed domain
246
+ solution_obs = self.interpolate_on_observed_domain(solution)
231
247
 
232
248
  if self.observation_map is not None:
233
- solution_obs = self.observation_map(solution_obs)
234
-
249
+ solution_obs = self.observation_map(solution_obs, self.grid_obs)
250
+
235
251
  return solution_obs
236
252
 
237
253
  class TimeDependentLinearPDE(LinearPDE):
@@ -251,16 +267,20 @@ class TimeDependentLinearPDE(LinearPDE):
251
267
  method: str
252
268
  Time stepping method. Currently two options are available `forward_euler` and `backward_euler`.
253
269
 
270
+ observation_map: a function handle
271
+ A function that takes the PDE solution, interpolated on `grid_obs` and `time_obs`, as input and returns the observed solution. e.g. `observation_map=lambda u, grid_obs, time_obs: u**2`.
272
+
254
273
  kwargs:
255
274
  See :class:`~cuqi.pde.LinearPDE` for the remaining keyword arguments
256
275
 
257
276
  Example
258
277
  -----------
259
- See demos/demo34_TimeDependentLinearPDE.py for 1D heat and 1D wave equations.
278
+ See demos/howtos/TimeDependentLinearPDE.py for 1D heat and 1D wave equations examples. It demonstrates setting up `TimeDependentLinearPDE` objects, including the choice of time stepping methods, observation domain, and observation map.
260
279
  """
261
280
 
262
- def __init__(self, PDE_form, time_steps, time_obs='final', method='forward_euler', **kwargs):
263
- super().__init__(PDE_form, **kwargs)
281
+ def __init__(self, PDE_form, time_steps, time_obs='final',
282
+ method='forward_euler', observation_map=None, **kwargs):
283
+ super().__init__(PDE_form, observation_map=observation_map, **kwargs)
264
284
 
265
285
  self.time_steps = time_steps
266
286
  self.method = method
@@ -339,8 +359,8 @@ class TimeDependentLinearPDE(LinearPDE):
339
359
 
340
360
  return u, info
341
361
 
342
- def observe(self, solution):
343
-
362
+ def interpolate_on_observed_domain(self, solution):
363
+ """Interpolate solution on observed time and space points."""
344
364
  # If observation grid is the same as solution grid and observation time
345
365
  # is the final time step then no need to interpolate
346
366
  if self.grids_equal and np.all(self.time_steps[-1:] == self._time_obs):
@@ -361,15 +381,26 @@ class TimeDependentLinearPDE(LinearPDE):
361
381
  # Interpolate solution in space and time to the observation
362
382
  # time and space
363
383
  solution_obs = scipy.interpolate.RectBivariateSpline(
364
- self.grid_sol, self.time_steps, solution)(self.grid_obs,
365
- self._time_obs)
384
+ self.grid_sol, self.time_steps, solution
385
+ )(self.grid_obs, self._time_obs)
366
386
 
387
+ return solution_obs
388
+
389
+ def observe(self, solution):
390
+ """Apply observation operator to the solution. This includes
391
+ interpolation to observation points (if different from the
392
+ solution grid) then applying the observation map (if provided)."""
393
+
394
+ # Interpolate solution on observed domain
395
+ solution_obs = self.interpolate_on_observed_domain(solution)
396
+
367
397
  # Apply observation map
368
398
  if self.observation_map is not None:
369
- solution_obs = self.observation_map(solution_obs)
399
+ solution_obs = self.observation_map(solution_obs, self.grid_obs,
400
+ self._time_obs)
370
401
 
371
402
  # squeeze if only one time observation
372
403
  if len(self._time_obs) == 1:
373
404
  solution_obs = solution_obs.squeeze()
374
405
 
375
- return solution_obs
406
+ return solution_obs