CUQIpy 1.3.0__py3-none-any.whl → 1.4.0.post0.dev61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/__init__.py +1 -1
  5. cuqi/distribution/_beta.py +1 -1
  6. cuqi/distribution/_cauchy.py +2 -2
  7. cuqi/distribution/_distribution.py +24 -15
  8. cuqi/distribution/_joint_distribution.py +97 -12
  9. cuqi/distribution/_posterior.py +9 -0
  10. cuqi/distribution/_truncated_normal.py +3 -3
  11. cuqi/distribution/_uniform.py +36 -2
  12. cuqi/experimental/__init__.py +1 -1
  13. cuqi/experimental/_recommender.py +216 -0
  14. cuqi/experimental/geometry/_productgeometry.py +3 -3
  15. cuqi/geometry/_geometry.py +12 -1
  16. cuqi/implicitprior/__init__.py +1 -1
  17. cuqi/implicitprior/_regularizedGaussian.py +40 -4
  18. cuqi/implicitprior/_restorator.py +35 -1
  19. cuqi/legacy/__init__.py +2 -0
  20. cuqi/legacy/sampler/__init__.py +11 -0
  21. cuqi/legacy/sampler/_conjugate.py +55 -0
  22. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  23. cuqi/legacy/sampler/_cwmh.py +196 -0
  24. cuqi/legacy/sampler/_gibbs.py +231 -0
  25. cuqi/legacy/sampler/_hmc.py +335 -0
  26. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  27. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  28. cuqi/legacy/sampler/_mh.py +190 -0
  29. cuqi/legacy/sampler/_pcn.py +244 -0
  30. cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +134 -152
  31. cuqi/legacy/sampler/_sampler.py +182 -0
  32. cuqi/likelihood/_likelihood.py +1 -1
  33. cuqi/model/_model.py +1248 -357
  34. cuqi/pde/__init__.py +4 -0
  35. cuqi/pde/_observation_map.py +36 -0
  36. cuqi/pde/_pde.py +133 -32
  37. cuqi/problem/_problem.py +88 -82
  38. cuqi/sampler/__init__.py +120 -8
  39. cuqi/sampler/_conjugate.py +376 -35
  40. cuqi/sampler/_conjugate_approx.py +40 -16
  41. cuqi/sampler/_cwmh.py +132 -138
  42. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  43. cuqi/sampler/_gibbs.py +269 -130
  44. cuqi/sampler/_hmc.py +328 -201
  45. cuqi/sampler/_langevin_algorithm.py +282 -98
  46. cuqi/sampler/_laplace_approximation.py +87 -117
  47. cuqi/sampler/_mh.py +47 -157
  48. cuqi/sampler/_pcn.py +56 -211
  49. cuqi/sampler/_rto.py +206 -140
  50. cuqi/sampler/_sampler.py +540 -135
  51. cuqi/solver/_solver.py +6 -2
  52. cuqi/testproblem/_testproblem.py +2 -3
  53. cuqi/utilities/__init__.py +3 -1
  54. cuqi/utilities/_utilities.py +94 -12
  55. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/METADATA +6 -4
  56. cuqipy-1.4.0.post0.dev61.dist-info/RECORD +102 -0
  57. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/WHEEL +1 -1
  58. CUQIpy-1.3.0.dist-info/RECORD +0 -100
  59. cuqi/experimental/mcmc/__init__.py +0 -123
  60. cuqi/experimental/mcmc/_conjugate.py +0 -345
  61. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  62. cuqi/experimental/mcmc/_cwmh.py +0 -193
  63. cuqi/experimental/mcmc/_gibbs.py +0 -318
  64. cuqi/experimental/mcmc/_hmc.py +0 -464
  65. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -392
  66. cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
  67. cuqi/experimental/mcmc/_mh.py +0 -80
  68. cuqi/experimental/mcmc/_pcn.py +0 -89
  69. cuqi/experimental/mcmc/_sampler.py +0 -566
  70. cuqi/experimental/mcmc/_utilities.py +0 -17
  71. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info/licenses}/LICENSE +0 -0
  72. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/top_level.txt +0 -0
cuqi/model/_model.py CHANGED
@@ -5,48 +5,55 @@ from scipy.sparse import hstack
5
5
  from scipy.linalg import solve
6
6
  from cuqi.samples import Samples
7
7
  from cuqi.array import CUQIarray
8
- from cuqi.geometry import Geometry, _DefaultGeometry1D, _DefaultGeometry2D, _get_identity_geometries
8
+ from cuqi.geometry import Geometry, _DefaultGeometry1D, _DefaultGeometry2D,\
9
+ _get_identity_geometries
9
10
  import cuqi
10
11
  import matplotlib.pyplot as plt
11
12
  from copy import copy
13
+ from functools import partial
14
+ from cuqi.utilities import force_ndarray
12
15
 
13
16
  class Model(object):
14
17
  """Generic model defined by a forward operator.
15
18
 
16
19
  Parameters
17
20
  -----------
18
- forward : 2D ndarray or callable function.
19
- Forward operator.
21
+ forward : callable function
22
+ Forward operator of the model. It takes one or more inputs and returns the model output.
20
23
 
21
- range_geometry : integer or cuqi.geometry.Geometry
22
- If integer is given, a cuqi.geometry._DefaultGeometry is created with dimension of the integer.
24
+ range_geometry : integer, a 1D or 2D tuple of integers, cuqi.geometry.Geometry
25
+ If integer or 1D tuple of integers is given, a cuqi.geometry._DefaultGeometry1D is created with dimension of the integer.
26
+ If 2D tuple of integers is given, a cuqi.geometry._DefaultGeometry2D is created with dimensions of the tuple.
27
+ If cuqi.geometry.Geometry object is given, it is used as the range geometry of the model.
23
28
 
24
- domain_geometry : integer or cuqi.geometry.Geometry
25
- If integer is given, a cuqi.geometry._DefaultGeometry is created with dimension of the integer.
29
+ domain_geometry : integer, a 1D or 2D tuple of integers, cuqi.geometry.Geometry or a tuple with items of any of the listed types
30
+ If integer or 1D tuple of integers is given, a cuqi.geometry._DefaultGeometry1D is created with dimension of the integer.
31
+ If 2D tuple of integers is given (and the forward model has one input only), a cuqi.geometry._DefaultGeometry2D is created with dimensions of the tuple.
32
+ If cuqi.geometry.Geometry is given, it is used as the domain geometry.
33
+ If tuple of the above types is given, a cuqi.geometry._ProductGeometry is created based on the tuple entries. This is used for models with multiple inputs where each entry in the tuple represents the geometry of each input.
26
34
 
27
- gradient : callable function, optional
28
- The direction-Jacobian product of the forward operator Jacobian with
29
- respect to the forward operator input, evaluated at a point (`wrt`).
30
- The signature of the gradient function should be (`direction`, `wrt`),
31
- where `direction` is the direction by which the Jacobian matrix is
32
- multiplied and `wrt` is the point at which the Jacobian is computed.
35
+ gradient : callable function, a tuple of callable functions or None, optional
36
+ The direction-Jacobian product of the forward model Jacobian with respect to the model input, evaluated at the model input. For example, if the forward model inputs are `x` and `y`, the gradient callable signature should be (`direction`, `x`, `y`), in that order, where `direction` is the direction by which the Jacobian matrix is multiplied and `x` and `y` are the parameters at which the Jacobian is computed.
33
37
 
34
- jacobian : callable function, optional
35
- The Jacobian of the forward operator with respect to the forward operator input,
36
- evaluated at a point (`wrt`). The signature of the Jacobian function should be (`wrt`).
37
- The Jacobian function should return a 2D ndarray of shape (range_dim, domain_dim).
38
- The Jacobian function is used to specify the gradient function by computing the vector-Jacobian
39
- product (VJP), here we refer to the vector in the VJP as the `direction` since it is the direction at
40
- which the gradient is computed.
41
- automatically and thus the gradient function should not be specified when the Jacobian
42
- function is specified.
38
+ If the gradient function is a single callable function, it returns a 1D ndarray if the model has only one input. If the model has multiple inputs, this gradient function should return a tuple of 1D ndarrays, each representing the gradient with respect to each input.
39
+
40
+ If the gradient function is a tuple of callable functions, each callable function should return a 1D ndarray representing the gradient with respect to each input. The order of the callable functions in the tuple should match the order of the model inputs.
41
+
42
+ jacobian : callable function, a tuple of callable functions or None, optional
43
+ The Jacobian of the forward model with respect to the forward model input, evaluated at the model input. For example, if the forward model inputs are `x` and `y`, the jacobian signature should be (`x`, `y`), in that order, where `x` and `y` are the parameters at which the Jacobian is computed.
44
+
45
+ If the Jacobian function is a single callable function, it should return a 2D ndarray of shape (range_dim, domain_dim) if the model has only one input. If the model has multiple inputs, this Jacobian function should return a tuple of 2D ndarrays, each representing the Jacobian with respect to each input.
46
+
47
+ If the Jacobian function is a tuple of callable functions, each callable function should return a 2D ndarray representing the Jacobian with respect to each input. The order of the callable functions in the tuple should match the order of the model inputs.
48
+
49
+ The Jacobian function is used to specify the gradient function by computing the vector-Jacobian product (VJP), here we refer to the vector in the VJP as the `direction` since it is the direction at which the gradient is computed. Either the gradient or the Jacobian can be specified, but not both.
43
50
 
44
51
 
45
52
  :ivar range_geometry: The geometry representing the range.
46
53
  :ivar domain_geometry: The geometry representing the domain.
47
54
 
48
- Example
49
- -------
55
+ Example 1
56
+ ----------
50
57
 
51
58
  Consider a forward model :math:`F: \mathbb{R}^2 \\rightarrow \mathbb{R}` defined by the following forward operator:
52
59
 
@@ -75,6 +82,9 @@ class Model(object):
75
82
 
76
83
  model = Model(forward, range_geometry=1, domain_geometry=2, jacobian=jacobian)
77
84
 
85
+ print(model(np.array([1, 1])))
86
+ print(model.gradient(np.array([1]), np.array([1, 1])))
87
+
78
88
  Alternatively, the gradient information in the forward model can be defined by direction-Jacobian product using the gradient keyword argument.
79
89
 
80
90
  This may be more efficient if forming the Jacobian matrix is expensive.
@@ -87,64 +97,162 @@ class Model(object):
87
97
  def forward(x):
88
98
  return 10*x[1] - 10*x[0]**3 + 5*x[0]**2 + 6*x[0]
89
99
 
90
- def gradient(direction, wrt):
91
- # Direction-Jacobian product direction@jacobian(wrt)
92
- return direction@np.array([[-30*wrt[0]**2 + 10*wrt[0] + 6, 10]])
100
+ def gradient(direction, x):
101
+ # Direction-Jacobian product direction@jacobian(x)
102
+ return direction@np.array([[-30*x[0]**2 + 10*x[0] + 6, 10]])
93
103
 
94
104
  model = Model(forward, range_geometry=1, domain_geometry=2, gradient=gradient)
95
105
 
106
+ print(model(np.array([1, 1])))
107
+ print(model.gradient(np.array([1]), np.array([1, 1])))
108
+
109
+ Example 2
110
+ ----------
111
+ Alternatively, the example above can be defined as a model with multiple inputs: :math:`x` and :math:`y`:
112
+
113
+ .. code-block:: python
114
+
115
+ import numpy as np
116
+ from cuqi.model import Model
117
+ from cuqi.geometry import Discrete
118
+
119
+ def forward(x, y):
120
+ return 10 * y - 10 * x**3 + 5 * x**2 + 6 * x
121
+
122
+ def jacobian(x, y):
123
+ return (np.array([[-30 * x**2 + 10 * x + 6]]), np.array([[10]]))
124
+
125
+ model = Model(
126
+ forward,
127
+ range_geometry=1,
128
+ domain_geometry=(Discrete(1), Discrete(1)),
129
+ jacobian=jacobian,
130
+ )
131
+
132
+ print(model(1, 1))
133
+ print(model.gradient(np.array([1]), 1, 1))
96
134
  """
135
+
136
+ _supports_partial_eval = True
137
+ """Flag indicating that partial evaluation of Model objects is supported, i.e., calling the model object with only some of the inputs specified returns a model that can be called with the remaining inputs."""
138
+
97
139
  def __init__(self, forward, range_geometry, domain_geometry, gradient=None, jacobian=None):
98
140
 
99
- #Check if input is callable
141
+ # Check if input is callable
100
142
  if callable(forward) is not True:
101
143
  raise TypeError("Forward needs to be callable function.")
102
-
144
+
145
+ # Store forward func
146
+ self._forward_func = forward
147
+ self._stored_non_default_args = None
148
+
149
+ # Store range_geometry
150
+ self.range_geometry = range_geometry
151
+
152
+ # Store domain_geometry
153
+ self.domain_geometry = domain_geometry
154
+
155
+ # Additional checks for the forward operator
156
+ self._check_domain_geometry_consistent_with_forward()
157
+
103
158
  # Check if only one of gradient and jacobian is given
104
159
  if (gradient is not None) and (jacobian is not None):
105
160
  raise TypeError("Only one of gradient and jacobian should be specified")
106
-
107
- #Check if input is callable
108
- if (gradient is not None) and (callable(gradient) is not True):
109
- raise TypeError("Gradient needs to be callable function.")
110
-
111
- if (jacobian is not None) and (callable(jacobian) is not True):
112
- raise TypeError("Jacobian needs to be callable function.")
113
-
114
- # Use jacobian function to specify gradient function (vector-Jacobian product)
161
+
162
+ # Check correct gradient form (check type, signature, etc.)
163
+ self._check_correct_gradient_jacobian_form(gradient, "gradient")
164
+
165
+ # Check correct jacobian form (check type, signature, etc.)
166
+ self._check_correct_gradient_jacobian_form(jacobian, "jacobian")
167
+
168
+ # If jacobian is provided, use it to specify gradient function
169
+ # (vector-Jacobian product)
115
170
  if jacobian is not None:
116
- gradient = lambda direction, wrt: direction@jacobian(wrt)
117
-
118
- #Store forward func
119
- self._forward_func = forward
171
+ gradient = self._use_jacobian_to_specify_gradient(jacobian)
172
+
120
173
  self._gradient_func = gradient
121
-
122
- #Store range_geometry
123
- if isinstance(range_geometry, tuple) and len(range_geometry) == 2:
124
- self.range_geometry = _DefaultGeometry2D(range_geometry)
125
- elif isinstance(range_geometry, int):
126
- self.range_geometry = _DefaultGeometry1D(grid=range_geometry)
127
- elif isinstance(range_geometry, Geometry):
128
- self.range_geometry = range_geometry
129
- elif range_geometry is None:
130
- raise AttributeError("The parameter 'range_geometry' is not specified by the user and it connot be inferred from the attribute 'forward'.")
131
- else:
132
- raise TypeError("The parameter 'range_geometry' should be of type 'int', 2 dimensional 'tuple' or 'cuqi.geometry.Geometry'.")
133
-
134
- #Store domain_geometry
135
- if isinstance(domain_geometry, tuple) and len(domain_geometry) == 2:
136
- self.domain_geometry = _DefaultGeometry2D(domain_geometry)
137
- elif isinstance(domain_geometry, int):
138
- self.domain_geometry = _DefaultGeometry1D(grid=domain_geometry)
139
- elif isinstance(domain_geometry, Geometry):
140
- self.domain_geometry = domain_geometry
141
- elif domain_geometry is None:
142
- raise AttributeError("The parameter 'domain_geometry' is not specified by the user and it connot be inferred from the attribute 'forward'.")
174
+
175
+ # Set gradient output stacked flag to False
176
+ self._gradient_output_stacked = False
177
+
178
+ @property
179
+ def _non_default_args(self):
180
+ if self._stored_non_default_args is None:
181
+ # Store non_default_args of the forward operator for faster caching
182
+ # when checking for those arguments.
183
+ self._stored_non_default_args =\
184
+ cuqi.utilities.get_non_default_args(self._forward_func)
185
+ return self._stored_non_default_args
186
+
187
+ @property
188
+ def number_of_inputs(self):
189
+ """ The number of inputs of the model. """
190
+ return len(self._non_default_args)
191
+
192
+ @property
193
+ def range_geometry(self):
194
+ """ The geometry representing the range of the model. """
195
+ return self._range_geometry
196
+
197
+ @range_geometry.setter
198
+ def range_geometry(self, value):
199
+ """ Update the range geometry of the model. """
200
+ if isinstance(value, Geometry):
201
+ self._range_geometry = value
202
+ elif isinstance(value, int):
203
+ self._range_geometry = self._create_default_geometry(value)
204
+ elif isinstance(value, tuple):
205
+ self._range_geometry = self._create_default_geometry(value)
206
+ elif value is None:
207
+ raise AttributeError(
208
+ "The parameter 'range_geometry' is not specified by the user and it cannot be inferred from the attribute 'forward'."
209
+ )
143
210
  else:
144
- raise TypeError("The parameter 'domain_geometry' should be of type 'int', 2 dimensional 'tuple' or 'cuqi.geometry.Geometry'.")
211
+ raise TypeError(
212
+ " The allowed types for 'range_geometry' are: 'cuqi.geometry.Geometry', int, 1D tuple of int, or 2D tuple of int."
213
+ )
145
214
 
146
- # Store non_default_args of the forward operator for faster caching when checking for those arguments.
147
- self._non_default_args = cuqi.utilities.get_non_default_args(self._forward_func)
215
+ @property
216
+ def domain_geometry(self):
217
+ """ The geometry representing the domain of the model. """
218
+ return self._domain_geometry
219
+
220
+ @domain_geometry.setter
221
+ def domain_geometry(self, value):
222
+ """ Update the domain geometry of the model. """
223
+
224
+ if isinstance(value, Geometry):
225
+ self._domain_geometry = value
226
+ elif isinstance(value, int):
227
+ self._domain_geometry = self._create_default_geometry(value)
228
+ elif isinstance(value, tuple) and self.number_of_inputs == 1:
229
+ self._domain_geometry = self._create_default_geometry(value)
230
+ elif isinstance(value, tuple) and self.number_of_inputs > 1:
231
+ geometries = [item if isinstance(item, Geometry) else self._create_default_geometry(item) for item in value]
232
+ self._domain_geometry = cuqi.experimental.geometry._ProductGeometry(*geometries)
233
+ elif value is None:
234
+ raise AttributeError(
235
+ "The parameter 'domain_geometry' is not specified by the user and it cannot be inferred from the attribute 'forward'."
236
+ )
237
+ else:
238
+ raise TypeError(
239
+ "For forward model with 1 input, the allowed types for 'domain_geometry' are: 'cuqi.geometry.Geometry', int, 1D tuple of int, or 2D tuple of int. For forward model with multiple inputs, the 'domain_geometry' should be a tuple with items of any of the above types."
240
+ )
241
+
242
+ def _create_default_geometry(self, value):
243
+ """Private function that creates default geometries for the model."""
244
+ if isinstance(value, tuple) and len(value) == 1:
245
+ value = value[0]
246
+ if isinstance(value, Geometry):
247
+ return value
248
+ if isinstance(value, int):
249
+ return _DefaultGeometry1D(grid=value)
250
+ elif isinstance(value, tuple) and len(value) == 2:
251
+ return _DefaultGeometry2D(im_shape=value)
252
+ else:
253
+ raise ValueError(
254
+ "Default geometry creation can be specified by an integer or a 2D tuple of integers."
255
+ )
148
256
 
149
257
  @property
150
258
  def domain_dim(self):
@@ -160,341 +268,977 @@ class Model(object):
160
268
  """
161
269
  return self.range_geometry.par_dim
162
270
 
163
- def _2fun(self, x, geometry, is_par):
164
- """ Converts `x` to function values (if needed) using the appropriate
165
- geometry. For example, `x` can be the model input which need to be
166
- converted to function value before being passed to
167
- :class:`~cuqi.model.Model` operators (e.g. _forward_func, _adjoint_func,
168
- _gradient_func).
271
+ def _check_domain_geometry_consistent_with_forward(self):
272
+ """Private function that checks if the domain geometry of the model is
273
+ consistent with the forward operator."""
274
+ if (
275
+ not isinstance(
276
+ self.domain_geometry, cuqi.experimental.geometry._ProductGeometry
277
+ )
278
+ and self.number_of_inputs > 1
279
+ ):
280
+ raise ValueError(
281
+ "The forward operator input is specified by more than one argument. This is only supported for domain geometry of type tuple with items of type: cuqi.geometry.Geometry object, int, or 2D tuple of int."
282
+ )
283
+
284
+ def _check_correct_gradient_jacobian_form(self, func, func_type):
285
+ """Private function that checks if the gradient/jacobian parameter is
286
+ in the correct form. That is, check if the gradient/jacobian has the
287
+ correct type, signature, etc."""
288
+
289
+ if func is None:
290
+ return
291
+
292
+ # gradient/jacobian should be callable (for single input and multiple input case)
293
+ # or a tuple of callables (for multiple inputs case)
294
+ if isinstance(func, tuple):
295
+ # tuple length should be same as the number of inputs
296
+ if len(func) != self.number_of_inputs:
297
+ raise ValueError(
298
+ f"The "
299
+ + func_type.lower()
300
+ + f" tuple length should be {self.number_of_inputs} for model with inputs {self._non_default_args}"
301
+ )
302
+ # tuple items should be callables or None
303
+ if not all([callable(func_i) or func_i is None for func_i in func]):
304
+ raise TypeError(
305
+ func_type.capitalize()
306
+ + " tuple should contain callable functions or None."
307
+ )
308
+
309
+ elif callable(func):
310
+ # temporarily convert gradient/jacobian to tuple for checking only
311
+ func = (func,)
312
+
313
+ else:
314
+ raise TypeError(
315
+ "Gradient needs to be callable function or tuple of callable functions."
316
+ )
317
+
318
+ expected_func_non_default_args = (
319
+ self._non_default_args
320
+ if not hasattr(self, "_original_non_default_args")
321
+ else self._original_non_default_args
322
+ )
323
+
324
+ if func_type.lower() == "gradient":
325
+ # prepend 'direction' to the expected gradient non default args
326
+ expected_func_non_default_args = [
327
+ "direction"
328
+ ] + expected_func_non_default_args
329
+
330
+ for func_i in func:
331
+ # make sure the signature of the gradient/jacobian function is correct
332
+ # that is, the same as the expected_func_non_default_args
333
+ if func_i is not None:
334
+ func_non_default_args = cuqi.utilities.get_non_default_args(func_i)
335
+
336
+ if list(func_non_default_args) != list(expected_func_non_default_args):
337
+ raise ValueError(
338
+ func_type.capitalize()
339
+ + f" function signature should be {expected_func_non_default_args}"
340
+ )
341
+
342
+ def _use_jacobian_to_specify_gradient(self, jacobian):
343
+ """Private function that uses the jacobian function to specify the
344
+ gradient function."""
345
+ # if jacobian is a single function and model has multiple inputs
346
+ if callable(jacobian) and self.number_of_inputs > 1:
347
+ gradient = self._create_gradient_lambda_function_from_jacobian_with_correct_signature(
348
+ jacobian, form='one_callable_multiple_inputs'
349
+ )
350
+ # Elif jacobian is a single function and model has only one input
351
+ elif callable(jacobian):
352
+ gradient = self._create_gradient_lambda_function_from_jacobian_with_correct_signature(
353
+ jacobian, form='one_callable_one_input'
354
+ )
355
+ # Else, jacobian is a tuple of jacobian functions
356
+ else:
357
+ gradient = []
358
+ for jac in jacobian:
359
+ if jac is not None:
360
+ gradient.append(
361
+ self._create_gradient_lambda_function_from_jacobian_with_correct_signature(
362
+ jac, form='tuple_of_callables'
363
+ )
364
+ )
365
+ else:
366
+ gradient.append(None)
367
+ return tuple(gradient) if isinstance(gradient, list) else gradient
368
+
369
+ def _create_gradient_lambda_function_from_jacobian_with_correct_signature(
370
+ self, jacobian, form
371
+ ):
372
+ """Private function that creates gradient lambda function from the
373
+ jacobian function, with the correct signature (based on the model
374
+ non_default_args).
375
+ """
376
+ # create the string representation of the lambda function
377
+ # for different forms of jacobian
378
+ if form=='one_callable_multiple_inputs':
379
+ grad_fun_str = (
380
+ "lambda direction, "
381
+ + ", ".join(self._non_default_args)
382
+ + ", jacobian: tuple([direction@jacobian("
383
+ + ", ".join(self._non_default_args)
384
+ + ")[i] for i in range("+str(self.number_of_inputs)+")])"
385
+ )
386
+ elif form=='tuple_of_callables' or form=='one_callable_one_input':
387
+ grad_fun_str = (
388
+ "lambda direction, "
389
+ + ", ".join(self._non_default_args)
390
+ + ", jacobian: direction@jacobian("
391
+ + ", ".join(self._non_default_args)
392
+ + ")"
393
+ )
394
+ else:
395
+ raise ValueError("form should be either 'one_callable' or 'tuple_of_callables'.")
396
+
397
+ # create the lambda function from the string
398
+ grad_func = eval(grad_fun_str)
399
+
400
+ # create partial function from the lambda function with jacobian as a
401
+ # fixed argument
402
+ grad_func = partial(grad_func, jacobian=jacobian)
403
+
404
+ return grad_func
405
+
406
+ def _2fun(self, geometry=None, is_par=True, **kwargs):
407
+ """ Converts `kwargs` to function values (if needed) using the geometry. For example, `kwargs` can be the model input which need to be converted to function value before being passed to :class:`~cuqi.model.Model` operators (e.g. _forward_func, _adjoint_func, _gradient_func).
169
408
 
170
409
  Parameters
171
410
  ----------
172
- x : ndarray or cuqi.array.CUQIarray
173
- The value to be converted.
174
-
175
411
  geometry : cuqi.geometry.Geometry
176
- The geometry representing `x`.
412
+ The geometry representing the values in `kwargs`.
177
413
 
178
- is_par : bool
179
- If True, `x` is assumed to be parameters.
180
- If False, `x` is assumed to be function values.
414
+ is_par : bool or a tuple of bools
415
+ If `is_par` is True, the values in `kwargs` are assumed to be parameters.
416
+ If `is_par` is False, the values in `kwargs` are assumed to be function values.
417
+ If `is_par` is a tuple of bools, the values in `kwargs` are assumed to be parameters or function values based on the corresponding boolean value in the tuple.
418
+
419
+ **kwargs : keyword arguments to be converted to function values.
181
420
 
182
421
  Returns
183
422
  -------
184
- ndarray or cuqi.array.CUQIarray
185
- `x` represented as a function.
423
+ dict of the converted values
424
+ """
425
+ # Check kwargs and geometry are consistent and set up geometries list and
426
+ # is_par tuple
427
+ geometries, is_par = self._helper_pre_conversion_checks_and_processing(geometry, is_par, **kwargs)
428
+
429
+ # Convert to function values
430
+ for i, (k, v) in enumerate(kwargs.items()):
431
+ # Use CUQIarray funvals if geometry is consistent
432
+ if isinstance(v, CUQIarray) and v.geometry == geometries[i]:
433
+ kwargs[k] = v.funvals
434
+ # Else, if we still need to convert to function value (is_par[i] is True)
435
+ # we use the geometry par2fun method
436
+ elif is_par[i] and v is not None:
437
+ kwargs[k] = geometries[i].par2fun(v)
438
+ else:
439
+ # No need to convert
440
+ pass
441
+
442
+ return kwargs
443
+
444
+ def _helper_pre_conversion_checks_and_processing(self, geometry=None, is_par=True, **kwargs):
445
+ """ Helper function that checks if kwargs and geometry are consistent
446
+ and sets up geometries list and is_par tuple.
186
447
  """
187
- # Convert to function representation
188
- # if x is CUQIarray and geometry are consistent, we obtain funvals
189
- # directly
190
- if isinstance(x, CUQIarray) and x.geometry == geometry:
191
- x = x.funvals
192
- # Otherwise we use the geometry par2fun method
193
- elif is_par:
194
- x = geometry.par2fun(x)
195
-
196
- return x
197
-
198
- def _2par(self, val, geometry, to_CUQIarray=False, is_par=False):
199
- """ Converts val, normally output of :class:~`cuqi.model.Model`
200
- operators (e.g. _forward_func, _adjoint_func, _gradient_func), to
201
- parameters using the appropriate geometry.
448
+ # If len of kwargs is larger than 1, the geometry needs to be of type
449
+ # _ProductGeometry
450
+ if (
451
+ not isinstance(geometry, cuqi.experimental.geometry._ProductGeometry)
452
+ and len(kwargs) > 1
453
+ ):
454
+ raise ValueError(
455
+ "The input is specified by more than one argument. This is only "
456
+ + "supported for domain geometry of type "
457
+ + f"{cuqi.experimental.geometry._ProductGeometry.__name__}."
458
+ )
459
+
460
+ # If is_par is bool, make it a tuple of bools of the same length as
461
+ # kwargs
462
+ is_par = (is_par,) * len(kwargs) if isinstance(is_par, bool) else is_par
463
+
464
+ # Set up geometries list
465
+ geometries = (
466
+ geometry.geometries
467
+ if isinstance(geometry, cuqi.experimental.geometry._ProductGeometry)
468
+ else [geometry]
469
+ )
470
+
471
+ return geometries, is_par
472
+
473
+ def _2par(self, geometry=None, to_CUQIarray=False, is_par=False, **kwargs):
474
+ """ Converts `kwargs` to parameters using the geometry. For example, `kwargs` can be the output of :class:`~cuqi.model.Model` operators (e.g. _forward_func, _adjoint_func, _gradient_func) which need to be converted to parameters before being returned.
202
475
 
203
476
  Parameters
204
477
  ----------
205
- val : ndarray or cuqi.array.CUQIarray
206
- The value to be converted to parameters.
207
-
208
478
  geometry : cuqi.geometry.Geometry
209
- The geometry representing the argument `val`.
479
+ The geometry representing the values in `kwargs`.
210
480
 
211
- to_CUQIarray : bool
212
- If True, the returned value is wrapped as a cuqi.array.CUQIarray.
481
+ to_CUQIarray : bool or a tuple of bools
482
+ If `to_CUQIarray` is True, the values in `kwargs` will be wrapped in `CUQIarray`.
483
+ If `to_CUQIarray` is False, the values in `kwargs` will not be wrapped in `CUQIarray`.
484
+ If `to_CUQIarray` is a tuple of bools, the values in `kwargs` will be wrapped in `CUQIarray` or not based on the corresponding boolean value in the tuple.
213
485
 
214
- is_par : bool
215
- If True, `val` is assumed to be of parameter representation and
216
- hence no conversion to parameters is performed.
486
+ is_par : bool or a tuple of bools
487
+ If `is_par` is True, the values in `kwargs` are assumed to be parameters.
488
+ If `is_par` is False, the values in `kwargs` are assumed to be function values.
489
+ If `is_par` is a tuple of bools, the values in `kwargs` are assumed to be parameters or function values based on the corresponding boolean value in the tuple.
217
490
 
218
491
  Returns
219
492
  -------
220
- ndarray or cuqi.array.CUQIarray
221
- The value `val` represented as parameters.
493
+ dict of the converted values
222
494
  """
495
+ # Check kwargs and geometry are consistent and set up geometries list and
496
+ # is_par tuple
497
+ geometries, is_par = self._helper_pre_conversion_checks_and_processing(geometry, is_par, **kwargs)
498
+
499
+ # if to_CUQIarray is bool, make it a tuple of bools of the same length
500
+ # as kwargs
501
+ to_CUQIarray = (to_CUQIarray,) * len(kwargs) if isinstance(to_CUQIarray, bool) else to_CUQIarray
502
+
223
503
  # Convert to parameters
224
- # if val is CUQIarray and geometry are consistent, we obtain parameters
225
- # directly
226
- if isinstance(val, CUQIarray) and val.geometry == geometry:
227
- val = val.parameters
228
- # Otherwise we use the geometry fun2par method
229
- elif not is_par:
230
- val = geometry.fun2par(val)
231
-
232
- # Wrap val in CUQIarray if requested
233
- if to_CUQIarray:
234
- val = CUQIarray(val, is_par=True, geometry=geometry)
235
-
236
- # Return val
237
- return val
238
-
504
+ for i , (k, v) in enumerate(kwargs.items()):
505
+ # Use CUQIarray parameters if geometry is consistent
506
+ if isinstance(v, CUQIarray) and v.geometry == geometries[i]:
507
+ v = v.parameters
508
+ # Else, if we still need to convert to parameter value (is_par[i] is False)
509
+ # we use the geometry fun2par method
510
+ elif not is_par[i] and v is not None:
511
+ v = geometries[i].fun2par(v)
512
+ else:
513
+ # No need to convert
514
+ pass
515
+
516
+ # Wrap the value v in CUQIarray if requested
517
+ if to_CUQIarray[i] and v is not None:
518
+ v = CUQIarray(v, is_par=True, geometry=geometries[i])
519
+
520
+ kwargs[k] = v
239
521
 
240
- def _apply_func(self, func, func_range_geometry, func_domain_geometry, x, is_par, **kwargs):
241
- """ Private function that applies the given function `func` to the input value `x`. It converts the input to function values (if needed) using the given `func_domain_geometry` and converts the output function values to parameters using the given `func_range_geometry`. It additionally handles the case of applying the function `func` to the cuqi.samples.Samples object.
522
+ return kwargs
242
523
 
243
- kwargs are keyword arguments passed to the functions `func`.
524
+ def _apply_func(self, func=None, fwd=True, is_par=True, **kwargs):
525
+ """ Private function that applies the given function `func` to the input `kwargs`. It converts the input to function values (if needed) and converts the output to parameter values. It additionally handles the case of applying the function `func` to cuqi.samples.Samples objects.
244
526
 
245
527
  Parameters
246
528
  ----------
247
529
  func: function handler
248
530
  The function to be applied.
249
531
 
250
- func_range_geometry : cuqi.geometry.Geometry
251
- The geometry representing the function `func` range.
252
-
253
- func_domain_geometry : cuqi.geometry.Geometry
254
- The geometry representing the function `func` domain.
255
-
256
- x : ndarray or cuqi.array.CUQIarray
257
- The input value to the operator.
532
+ fwd : bool
533
+ Flag indicating the direction of the operator to determine the range and domain geometries of the function.
534
+ If True the function is a forward operator.
535
+ If False the function is an adjoint operator.
258
536
 
259
- is_par : bool
260
- If True the input is assumed to be parameters.
261
- If False the input is assumed to be function values.
537
+ is_par : bool or list of bool
538
+ If True, the inputs in `kwargs` are assumed to be parameters.
539
+ If False, the input in `kwargs` are assumed to be function values.
540
+ If `is_par` is a list of bools, the inputs are assumed to be parameters or function values based on the corresponding boolean value in the list.
262
541
 
263
542
  Returns
264
543
  -------
265
- ndarray or cuqi.array.CUQIarray
266
- The output of the function `func` converted to parameters.
544
+ ndarray or cuqi.array.CUQIarray or cuqi.samples.Samples object
545
+ The output of the function.
267
546
  """
547
+ # Specify the range and domain geometries of the function
548
+ # If forward operator, range geometry is the model range geometry and
549
+ # domain geometry is the model domain geometry
550
+ if fwd:
551
+ func_range_geometry = self.range_geometry
552
+ func_domain_geometry = self.domain_geometry
553
+ # If adjoint operator, range geometry is the model domain geometry and
554
+ # domain geometry is the model range geometry
555
+ else:
556
+ func_range_geometry = self.domain_geometry
557
+ func_domain_geometry = self.range_geometry
558
+
268
559
  # If input x is Samples we apply func for each sample
269
560
  # TODO: Check if this can be done all-at-once for computational speed-up
270
- if isinstance(x,Samples):
271
- out = np.zeros((func_range_geometry.par_dim, x.Ns))
272
- # Recursively apply func to each sample
273
- for idx, item in enumerate(x):
274
- out[:,idx] = self._apply_func(func,
275
- func_range_geometry,
276
- func_domain_geometry,
277
- item, is_par=True,
278
- **kwargs)
279
- return Samples(out, geometry=func_range_geometry)
280
-
281
- # store if input x is CUQIarray
282
- is_CUQIarray = type(x) is CUQIarray
561
+ if any(isinstance(x, Samples) for x in kwargs.values()):
562
+ return self._handle_case_when_model_input_is_samples(func, fwd, **kwargs)
563
+
564
+ # store if any input x is CUQIarray
565
+ is_CUQIarray = any(isinstance(x, CUQIarray) for x in kwargs.values())
283
566
 
284
- x = self._2fun(x, func_domain_geometry, is_par=is_par)
285
- out = func(x, **kwargs)
567
+ # Convert input to function values
568
+ kwargs = self._2fun(geometry=func_domain_geometry, is_par=is_par, **kwargs)
286
569
 
287
- # Return output as parameters
288
- # (and wrapped in CUQIarray if input was CUQIarray)
289
- return self._2par(out, func_range_geometry,
290
- to_CUQIarray=is_CUQIarray)
570
+ # Apply the function
571
+ out = func(**kwargs)
291
572
 
292
- def _parse_args_add_to_kwargs(self, *args, **kwargs):
293
- """ Private function that parses the input arguments of the model and adds them as keyword arguments matching the non default arguments of the forward function. """
573
+ # Return output as parameters
574
+ # (wrapped in CUQIarray if any input was CUQIarray)
575
+ return self._2par(
576
+ geometry=func_range_geometry, to_CUQIarray=is_CUQIarray, **{"out": out}
577
+ )["out"]
294
578
 
579
+ def _handle_case_when_model_input_is_samples(self, func=None, fwd=True, **kwargs):
580
+ """Private function that calls apply_func for samples in the
581
+ Samples object(s).
582
+ """
583
+ # All kwargs should be Samples objects
584
+ if not all(isinstance(x, Samples) for x in kwargs.values()):
585
+ raise TypeError(
586
+ "If applying the function to Samples, all inputs should be Samples."
587
+ )
588
+
589
+ # All Samples objects should have the same number of samples
590
+ Ns = list(kwargs.values())[0].Ns
591
+ if not all(x.Ns == Ns for x in kwargs.values()):
592
+ raise ValueError(
593
+ "If applying the function to Samples, all inputs should have the same number of samples."
594
+ )
595
+
596
+ # Specify the range dimension of the function
597
+ range_dim = self.range_dim if fwd else self.domain_dim
598
+
599
+ # Create empty array to store the output
600
+ out = np.zeros((range_dim, Ns))
601
+
602
+ # Recursively apply func to each sample
603
+ for i in range(Ns):
604
+ kwargs_i = {
605
+ k: CUQIarray(v.samples[..., i], is_par=v.is_par, geometry=v.geometry)
606
+ for k, v in kwargs.items()
607
+ }
608
+ out[:, i] = self._apply_func(func=func, fwd=fwd, **kwargs_i)
609
+ # Specify the range geometries of the function
610
+ func_range_geometry = self.range_geometry if fwd else self.domain_geometry
611
+ return Samples(out, geometry=func_range_geometry)
612
+
613
+ def _parse_args_add_to_kwargs(
614
+ self, *args, is_par=True, non_default_args=None, map_name="model", **kwargs
615
+ ):
616
+ """ Private function that parses the input arguments and adds them as
617
+ keyword arguments matching (the order of) the non default arguments of
618
+ the forward function or other specified non_default_args list.
619
+ """
620
+ # If non_default_args is not specified, use the non_default_args of the
621
+ # model
622
+ if non_default_args is None:
623
+ non_default_args = self._non_default_args
624
+
625
+ # Either args or kwargs can be provided but not both
626
+ if len(args) > 0 and len(kwargs) > 0:
627
+ raise ValueError(
628
+ "The "
629
+ + map_name.lower()
630
+ + " input is specified both as positional and keyword arguments. This is not supported."
631
+ )
632
+
633
+ len_input = len(args) + len(kwargs)
634
+
635
+ # If partial evaluation, make sure input is not of type Samples
636
+ if len_input < len(non_default_args):
637
+ # If the argument is a Sample object, splitting or partial
638
+ # evaluation of the model is not supported
639
+ temp_args = args if len(args) > 0 else list(kwargs.values())
640
+ if any(isinstance(arg, Samples) for arg in temp_args):
641
+ raise ValueError(("When using Samples objects as input, the"
642
+ +" user should provide a Samples object for"
643
+ +f" each non_default_args {non_default_args}"
644
+ +" of the model. That is, partial evaluation"
645
+ +" or splitting is not supported for input"
646
+ +" of type Samples."))
647
+
648
+ # If args are given, add them to kwargs
295
649
  if len(args) > 0:
296
650
 
297
- if len(kwargs) > 0:
298
- raise ValueError("The model input is specified both as positional and keyword arguments. This is not supported.")
299
-
300
- if len(args) != len(self._non_default_args):
301
- raise ValueError("The number of positional arguments does not match the number of non-default arguments of the model.")
302
-
651
+ # Check if the input is for multiple input case and is stacked,
652
+ # then split it
653
+ if len(args) < len(non_default_args):
654
+ args = self._split_in_case_of_stacked_args(*args, is_par=is_par)
655
+
303
656
  # Add args to kwargs following the order of non_default_args
304
657
  for idx, arg in enumerate(args):
305
- kwargs[self._non_default_args[idx]] = arg
658
+ kwargs[non_default_args[idx]] = arg
659
+
660
+ # Check kwargs matches non_default_args
661
+ if not (set(list(kwargs.keys())) <= set(non_default_args)):
662
+ if map_name == "gradient":
663
+ error_msg = f"The gradient input is specified by a direction and keywords arguments {list(kwargs.keys())} that does not match the non_default_args of the model {non_default_args}."
664
+ else:
665
+ error_msg = (
666
+ "The "
667
+ + map_name.lower()
668
+ + f" input is specified by keywords arguments {list(kwargs.keys())} that does not match the non_default_args of the "
669
+ + map_name
670
+ + f" {non_default_args}."
671
+ )
672
+
673
+ raise ValueError(error_msg)
674
+
675
+ # Make sure order of kwargs is the same as non_default_args
676
+ kwargs = {k: kwargs[k] for k in non_default_args if k in kwargs}
306
677
 
307
678
  return kwargs
308
-
679
+
680
+ def _split_in_case_of_stacked_args(self, *args, is_par=True):
681
+ """Private function that checks if the input args is a stacked
682
+ CUQIarray or numpy array and splits it into multiple arguments based on
683
+ the domain geometry of the model. Otherwise, it returns the input args
684
+ unchanged."""
685
+
686
+ # Check conditions for splitting and split if all conditions are met
687
+ is_CUQIarray = isinstance(args[0], CUQIarray)
688
+ is_numpy_array = isinstance(args[0], np.ndarray)
689
+
690
+ if ((is_CUQIarray or is_numpy_array) and
691
+ is_par and
692
+ len(args) == 1 and
693
+ args[0].shape == (self.domain_dim,) and
694
+ isinstance(self.domain_geometry, cuqi.experimental.geometry._ProductGeometry)):
695
+ # Split the stacked input
696
+ split_args = np.split(args[0], self.domain_geometry.stacked_par_split_indices)
697
+ # Convert split args to CUQIarray if input is CUQIarray
698
+ if is_CUQIarray:
699
+ split_args = [
700
+ CUQIarray(arg, is_par=True, geometry=self.domain_geometry.geometries[i])
701
+ for i, arg in enumerate(split_args)
702
+ ]
703
+ return split_args
704
+
705
+ else:
706
+ return args
707
+
309
708
  def forward(self, *args, is_par=True, **kwargs):
310
709
  """ Forward function of the model.
311
-
312
- Forward converts the input to function values (if needed) using the domain geometry of the model.
313
- Forward converts the output function values to parameters using the range geometry of the model.
710
+
711
+ Forward converts the input to function values (if needed) using the domain geometry of the model. Then it applies the forward operator to the function values and converts the output to parameters using the range geometry of the model.
314
712
 
315
713
  Parameters
316
714
  ----------
317
- *args : ndarray or cuqi.array.CUQIarray
318
- The model input.
715
+ *args : ndarrays or cuqi.array.CUQIarray objects or cuqi.samples.Samples objects
716
+ Positional arguments for the forward operator. The forward operator input can be specified as either positional arguments or keyword arguments but not both.
319
717
 
320
- is_par : bool
321
- If True the input is assumed to be parameters.
322
- If False the input is assumed to be function values.
323
-
324
- **kwargs : keyword arguments for model input.
325
- Keywords must match the names of the non_default_args of the model.
718
+ If the input is specified as positional arguments, the order of the arguments should match the non_default_args of the model.
719
+
720
+ is_par : bool or a tuple of bools
721
+ If True, the inputs in `args` or `kwargs` are assumed to be parameters.
722
+ If False, the inputs in `args` or `kwargs` are assumed to be function values.
723
+ If `is_par` is a tuple of bools, the inputs are assumed to be parameters or function values based on the corresponding boolean value in the tuple.
724
+
725
+ **kwargs : keyword arguments
726
+ keyword arguments for the forward operator. The forward operator input can be specified as either positional arguments or keyword arguments but not both.
727
+
728
+ If the input is specified as keyword arguments, the keys should match the non_default_args of the model.
326
729
 
327
730
  Returns
328
731
  -------
329
- ndarray or cuqi.array.CUQIarray
732
+ ndarray or cuqi.array.CUQIarray or cuqi.samples.Samples object
330
733
  The model output. Always returned as parameters.
331
734
  """
332
735
 
333
- kwargs = self._parse_args_add_to_kwargs(*args, **kwargs)
334
-
335
- # Check kwargs matches non_default_args
336
- if set(list(kwargs.keys())) != set(self._non_default_args):
337
- raise ValueError(f"The model input is specified by a keywords arguments {kwargs.keys()} that does not match the non_default_args of the model {self._non_default_args}.")
338
-
339
- # For now only support one input
340
- if len(kwargs) > 1:
341
- raise ValueError("The model input is specified by more than one argument. This is not supported.")
342
-
343
- # Get input matching the non_default_args
344
- x = kwargs[self._non_default_args[0]]
345
-
346
- # If input is a distribution, we simply change the parameter name of model to match the distribution name
347
- if isinstance(x, cuqi.distribution.Distribution):
348
- if x.dim != self.domain_dim:
349
- raise ValueError("Attempting to match parameter name of Model with given distribution, but distribution dimension does not match model domain dimension.")
350
- new_model = copy(self)
351
- new_model._non_default_args = [x.name] # Defaults to x if distribution had no name
352
- return new_model
736
+ # Add args to kwargs and ensure the order of the arguments matches the
737
+ # non_default_args of the forward function
738
+ kwargs = self._parse_args_add_to_kwargs(
739
+ *args, **kwargs, is_par=is_par, map_name="model"
740
+ )
741
+ # Extract args from kwargs
742
+ args = list(kwargs.values())
743
+
744
+ if len(kwargs) == 0:
745
+ return self
746
+
747
+ partial_arguments = len(kwargs) < len(self._non_default_args)
748
+
749
+ # If input is a distribution, we simply change the parameter name of
750
+ # model to match the distribution name
751
+ if all(isinstance(x, cuqi.distribution.Distribution)
752
+ for x in kwargs.values()):
753
+ if partial_arguments:
754
+ raise ValueError(
755
+ "Partial evaluation of the model is not supported for distributions."
756
+ )
757
+ return self._handle_case_when_model_input_is_distributions(kwargs)
353
758
 
354
759
  # If input is a random variable, we handle it separately
355
- if isinstance(x, cuqi.experimental.algebra.RandomVariable):
356
- return self._handle_random_variable(x)
357
-
760
+ elif all(isinstance(x, cuqi.experimental.algebra.RandomVariable)
761
+ for x in kwargs.values()):
762
+ if partial_arguments:
763
+ raise ValueError(
764
+ "Partial evaluation of the model is not supported for random variables."
765
+ )
766
+ return self._handle_case_when_model_input_is_random_variables(kwargs)
767
+
358
768
  # If input is a Node from internal abstract syntax tree, we let the Node handle the operation
359
769
  # We use NotImplemented to indicate that the operation is not supported from the Model class
360
770
  # in case of operations such as "@" that can be interpreted as both __matmul__ and __rmatmul__
361
771
  # the operation may be delegated to the Node class.
362
- if isinstance(x, cuqi.experimental.algebra.Node):
772
+ elif any(isinstance(args_i, cuqi.experimental.algebra.Node) for args_i in args):
363
773
  return NotImplemented
364
774
 
775
+ # if input is partial, we create a new model with the partial input
776
+ if partial_arguments:
777
+ # Create is_par_partial from the is_par to contain only the relevant parts
778
+ if isinstance(is_par, (list, tuple)):
779
+ is_par_partial = tuple(
780
+ is_par[i]
781
+ for i in range(self.number_of_inputs)
782
+ if self._non_default_args[i] in kwargs.keys()
783
+ )
784
+ else:
785
+ is_par_partial = is_par
786
+ # Build a partial model with the given kwargs
787
+ partial_model = self._build_partial_model(kwargs, is_par_partial)
788
+ return partial_model
789
+
365
790
  # Else we apply the forward operator
366
- return self._apply_func(self._forward_func,
367
- self.range_geometry,
368
- self.domain_geometry,
369
- x, is_par)
791
+ # if model has _original_non_default_args, we use it to replace the
792
+ # kwargs keys so that it matches self._forward_func signature
793
+ if hasattr(self, '_original_non_default_args'):
794
+ kwargs = {k:v for k,v in zip(self._original_non_default_args, args)}
795
+ return self._apply_func(func=self._forward_func,
796
+ fwd=True,
797
+ is_par=is_par,
798
+ **kwargs)
799
+
800
+ def _correct_distribution_dimension(self, distributions):
801
+ """Private function that checks if the dimension of the
802
+ distributions matches the domain dimension of the model."""
803
+ if len(distributions) == 1:
804
+ return list(distributions)[0].dim == self.domain_dim
805
+ elif len(distributions) > 1 and isinstance(
806
+ self.domain_geometry, cuqi.experimental.geometry._ProductGeometry
807
+ ):
808
+ return all(
809
+ d.dim == self.domain_geometry.par_dim_list[i]
810
+ for i, d in enumerate(distributions)
811
+ )
812
+ else:
813
+ return False
814
+
815
+ def _build_partial_model(self, kwargs, is_par):
816
+ """Private function that builds a partial model substituting the given
817
+ keyword arguments with their values. The created partial model will have
818
+ as inputs the non-default arguments that are not in the kwargs."""
819
+
820
+ # Extract args from kwargs
821
+ args = list(kwargs.values())
822
+
823
+ # Define original_non_default_args which represents the complete list of
824
+ # non-default arguments of the forward function.
825
+ original_non_default_args = (
826
+ self._original_non_default_args
827
+ if hasattr(self, "_original_non_default_args")
828
+ else self._non_default_args
829
+ )
830
+
831
+ if hasattr(self, "_original_non_default_args"):
832
+ # Split the _original_non_default_args into two lists:
833
+ # 1. reduced_original_non_default_args: the _original_non_default_args
834
+ # corresponding to the _non_default_args that are not in kwargs
835
+ # 2. substituted_non_default_args: the _original_non_default_args
836
+ # corresponding to the _non_default_args that are in kwargs
837
+ reduced_original_non_default_args = [
838
+ original_non_default_args[i]
839
+ for i in range(self.number_of_inputs)
840
+ if self._non_default_args[i] not in kwargs.keys()
841
+ ]
842
+ substituted_non_default_args = [
843
+ original_non_default_args[i]
844
+ for i in range(self.number_of_inputs)
845
+ if self._non_default_args[i] in kwargs.keys()
846
+ ]
847
+ # Replace the keys in kwargs with the substituted_non_default_args
848
+ # so that the kwargs match the signature of the _forward_func
849
+ kwargs = {k: v for k, v in zip(substituted_non_default_args, args)}
850
+
851
+ # Create a partial domain geometry with the geometries corresponding
852
+ # to the non-default arguments that are not in kwargs (remaining
853
+ # unspecified inputs)
854
+ partial_domain_geometry = cuqi.experimental.geometry._ProductGeometry(
855
+ *[
856
+ self.domain_geometry.geometries[i]
857
+ for i in range(self.number_of_inputs)
858
+ if original_non_default_args[i] not in kwargs.keys()
859
+ ]
860
+ )
861
+
862
+ if len(partial_domain_geometry.geometries) == 1:
863
+ partial_domain_geometry = partial_domain_geometry.geometries[0]
864
+
865
+ # Create a domain geometry with the geometries corresponding to the
866
+ # non-default arguments that are specified
867
+ substituted_domain_geometry = cuqi.experimental.geometry._ProductGeometry(
868
+ *[
869
+ self.domain_geometry.geometries[i]
870
+ for i in range(self.number_of_inputs)
871
+ if original_non_default_args[i] in kwargs.keys()
872
+ ]
873
+ )
874
+
875
+ if len(substituted_domain_geometry.geometries) == 1:
876
+ substituted_domain_geometry = substituted_domain_geometry.geometries[0]
877
+
878
+ # Create new model with partial input
879
+ # First, we convert the input to function values
880
+ kwargs = self._2fun(geometry=substituted_domain_geometry, is_par=is_par, **kwargs)
881
+
882
+ # Second, we create a partial function for the forward operator
883
+ partial_forward = partial(self._forward_func, **kwargs)
884
+
885
+ # Third, if applicable, we create a partial function for the gradient
886
+ if isinstance(self._gradient_func, tuple):
887
+ # If gradient is a tuple, we create a partial function for each
888
+ # gradient function in the tuple
889
+ partial_gradient = tuple(
890
+ (
891
+ partial(self._gradient_func[i], **kwargs)
892
+ if self._gradient_func[i] is not None
893
+ else None
894
+ )
895
+ for i in range(self.number_of_inputs)
896
+ if original_non_default_args[i] not in kwargs.keys()
897
+ )
898
+ if len(partial_gradient) == 1:
899
+ partial_gradient = partial_gradient[0]
900
+
901
+ elif callable(self._gradient_func):
902
+ raise NotImplementedError(
903
+ "Partial forward model is only supported for gradient/jacobian functions that are tuples of callable functions."
904
+ )
370
905
 
371
- def __call__(self, *args, **kwargs):
372
- return self.forward(*args, **kwargs)
906
+ else:
907
+ partial_gradient = None
908
+
909
+ # Lastly, we create the partial model with the partial forward
910
+ # operator (we set the gradient function later)
911
+ partial_model = Model(
912
+ forward=partial_forward,
913
+ range_geometry=self.range_geometry,
914
+ domain_geometry=partial_domain_geometry,
915
+ )
916
+
917
+ # Set the _original_non_default_args (if applicable) and
918
+ # _stored_non_default_args of the partial model
919
+ if hasattr(self, "_original_non_default_args"):
920
+ partial_model._original_non_default_args = reduced_original_non_default_args
921
+ partial_model._stored_non_default_args = [
922
+ self._non_default_args[i]
923
+ for i in range(self.number_of_inputs)
924
+ if original_non_default_args[i] not in kwargs.keys()
925
+ ]
926
+
927
+ # Set the gradient function of the partial model
928
+ partial_model._check_correct_gradient_jacobian_form(
929
+ partial_gradient, "gradient"
930
+ )
931
+ partial_model._gradient_func = partial_gradient
932
+
933
+ return partial_model
934
+
935
+ def _handle_case_when_model_input_is_distributions(self, kwargs):
936
+ """Private function that handles the case of the input being a
937
+ distribution or multiple distributions."""
938
+
939
+ if not self._correct_distribution_dimension(kwargs.values()):
940
+ raise ValueError(
941
+ "Attempting to match parameter name of Model with given distribution(s), but distribution(s) dimension(s) does not match model input dimension(s)."
942
+ )
943
+ new_model = copy(self)
944
+
945
+ # Store the original non_default_args of the model
946
+ new_model._original_non_default_args = (
947
+ self._original_non_default_args
948
+ if hasattr(self, "_original_non_default_args")
949
+ else self._non_default_args
950
+ )
951
+
952
+ # Update the non_default_args of the model to match the distribution
953
+ # names. Defaults to x in the case of only one distribution that has no
954
+ # name
955
+ new_model._stored_non_default_args = [x.name for x in kwargs.values()]
956
+
957
+ # If there is a repeated name, raise an error
958
+ if len(set(new_model._stored_non_default_args)) != len(
959
+ new_model._stored_non_default_args
960
+ ):
961
+ raise ValueError(
962
+ "Attempting to match parameter name of Model with given distributions, but distribution names are not unique. Please provide unique names for the distributions."
963
+ )
964
+
965
+ return new_model
966
+
967
+ def _handle_case_when_model_input_is_random_variables(self, kwargs):
968
+ """ Private function that handles the case of the input being a random variable. """
969
+ # If random variable is not a leaf-type node (e.g. internal node) we return NotImplemented
970
+ if any(not isinstance(x.tree, cuqi.experimental.algebra.VariableNode) for x in kwargs.values()):
971
+ return NotImplemented
972
+
973
+ # Extract the random variable distributions and check dimensions consistency with domain geometry
974
+ distributions = [value.distribution for value in kwargs.values()]
975
+ if not self._correct_distribution_dimension(distributions):
976
+ raise ValueError("Attempting to match parameter name of Model with given random variable(s), but random variable dimension(s) does not match model input dimension(s).")
977
+
978
+ new_model = copy(self)
373
979
 
374
- def gradient(self, direction, wrt, is_direction_par=True, is_wrt_par=True):
375
- """ Gradient of the forward operator (Direction-Jacobian product)
980
+ # Store the original non_default_args of the model
981
+ new_model._original_non_default_args = self._non_default_args
376
982
 
377
- For non-linear models the gradient is computed using the
378
- forward operator and the Jacobian of the forward operator.
983
+ # Update the non_default_args of the model to match the random variable
984
+ # names. Defaults to x in the case of only one random variable that has
985
+ # no name
986
+ new_model._stored_non_default_args = [x.name for x in distributions]
987
+
988
+ # If there is a repeated name, raise an error
989
+ if len(set(new_model._stored_non_default_args)) != len(
990
+ new_model._stored_non_default_args
991
+ ):
992
+ raise ValueError(
993
+ "Attempting to match parameter name of Model with given random variables, but random variables names are not unique. Please provide unique names for the random variables."
994
+ )
995
+
996
+ return new_model
997
+
998
+ def gradient(
999
+ self, direction, *args, is_direction_par=True, is_var_par=True, **kwargs
1000
+ ):
1001
+ """Gradient of the forward operator (Direction-Jacobian product)
1002
+
1003
+ The gradient computes the Vector-Jacobian product (VJP) of the forward operator evaluated at the given model input and the given vector (direction).
379
1004
 
380
1005
  Parameters
381
1006
  ----------
382
- direction : ndarray
383
- The direction to compute the gradient. The Jacobian is applied to this direction.
1007
+ direction : ndarray or cuqi.array.CUQIarray
1008
+ The direction at which to compute the gradient.
1009
+
1010
+ *args : ndarrays or cuqi.array.CUQIarray objects
1011
+ Positional arguments for the values at which to compute the gradient. The gradient operator input can be specified as either positional arguments or keyword arguments but not both.
384
1012
 
385
- wrt : ndarray
386
- The point to compute the Jacobian at. This is only used for non-linear models.
1013
+ If the input is specified as positional arguments, the order of the arguments should match the non_default_args of the model.
387
1014
 
388
1015
  is_direction_par : bool
389
1016
  If True, `direction` is assumed to be parameters.
390
1017
  If False, `direction` is assumed to be function values.
391
1018
 
392
- is_wrt_par : bool
393
- If True, `wrt` is assumed to be parameters.
394
- If False, `wrt` is assumed to be function values.
395
-
1019
+ is_var_par : bool or a tuple of bools
1020
+ If True, the inputs in `args` or `kwargs` are assumed to be parameters.
1021
+ If False, the inputs in `args` or `kwargs` are assumed to be function values.
1022
+ If `is_var_par` is a tuple of bools, the inputs in `args` or `kwargs` are assumed to be parameters or function values based on the corresponding boolean value in the tuple.
396
1023
  """
397
- # Obtain the parameters representation of wrt and raise an error if it
398
- # cannot be obtained
399
- error_message = \
400
- "For the gradient to be computed, is_wrt_par needs " +\
401
- "to be True and wrt needs to be parameter value, not function " +\
402
- "value. Alternatively, the model domain_geometry: "+\
403
- f"{self.domain_geometry} " +\
404
- "should have an implementation of the method fun2par"
1024
+ # Add args to kwargs and ensure the order of the arguments matches the
1025
+ # non_default_args of the forward function
1026
+ kwargs = self._parse_args_add_to_kwargs(
1027
+ *args, **kwargs, is_par=is_var_par, map_name="gradient"
1028
+ )
1029
+
1030
+ # Obtain the parameters representation of the variables and raise an
1031
+ # error if it cannot be obtained
1032
+ error_message = (
1033
+ "For the gradient to be computed, is_var_par needs to be True and the variables in kwargs needs to be parameter value, not function value. Alternatively, the model domain_geometry:"
1034
+ + f" {self.domain_geometry} "
1035
+ + "should have an implementation of the method fun2par"
1036
+ )
405
1037
  try:
406
- wrt_par = self._2par(wrt,
407
- geometry=self.domain_geometry,
408
- is_par=is_wrt_par,
409
- to_CUQIarray=False,
410
- )
1038
+ kwargs_par = self._2par(
1039
+ geometry=self.domain_geometry,
1040
+ is_par=is_var_par,
1041
+ to_CUQIarray=False,
1042
+ **kwargs,
1043
+ )
411
1044
  # NotImplementedError will be raised if fun2par of the geometry is not
412
1045
  # implemented and ValueError will be raised when imap is not set in
413
1046
  # MappedGeometry
414
1047
  except ValueError as e:
415
- raise ValueError(error_message +
416
- " ,including an implementation of imap for " +
417
- "MappedGeometry")
1048
+ raise ValueError(
1049
+ error_message
1050
+ + " ,including an implementation of imap for MappedGeometry"
1051
+ )
418
1052
  except NotImplementedError as e:
419
1053
  raise NotImplementedError(error_message)
420
-
421
- # Check for other errors that may prevent computing the gradient
422
- self._check_gradient_can_be_computed(direction, wrt)
423
-
424
- wrt = self._2fun(wrt, self.domain_geometry, is_par=is_wrt_par)
425
-
426
- # Store if the input direction is CUQIarray
427
- is_direction_CUQIarray = type(direction) is CUQIarray
428
-
429
- direction = self._2fun(direction,
430
- self.range_geometry,
431
- is_par=is_direction_par)
432
1054
 
433
- grad = self._gradient_func(direction, wrt)
434
- grad_is_par = False # Assume gradient is function values
435
-
436
- # If domain_geometry has gradient attribute, we apply it to the gradient
437
- # The gradient returned by the domain_geometry.gradient is assumed to be
438
- # parameters
439
- if hasattr(self.domain_geometry, 'gradient'):
440
- grad = self.domain_geometry.gradient(grad, wrt_par)
441
- grad_is_par = True # Gradient is parameters
442
-
443
- # we convert the computed gradient to parameters
444
- grad = self._2par(grad,
445
- self.domain_geometry,
446
- to_CUQIarray=is_direction_CUQIarray,
447
- is_par=grad_is_par)
1055
+ # Check for other errors that may prevent computing the gradient
1056
+ self._check_gradient_can_be_computed(direction, kwargs)
1057
+
1058
+ # Also obtain the function values representation of the variables
1059
+ kwargs_fun = self._2fun(
1060
+ geometry=self.domain_geometry, is_par=is_var_par, **kwargs
1061
+ )
1062
+
1063
+ # Store if any of the inputs is a CUQIarray
1064
+ to_CUQIarray = isinstance(direction, CUQIarray) or any(
1065
+ isinstance(x, CUQIarray) for x in kwargs_fun.values()
1066
+ )
1067
+
1068
+ # Turn to_CUQIarray to a tuple of bools of the same length as kwargs_fun
1069
+ to_CUQIarray = tuple([to_CUQIarray] * len(kwargs_fun))
1070
+
1071
+ # Convert direction to function value
1072
+ direction_fun = self._2fun(
1073
+ direction=direction, geometry=self.range_geometry, is_par=is_direction_par
1074
+ )
1075
+
1076
+ # If model has _original_non_default_args, we use it to replace the
1077
+ # kwargs keys so that it matches self._gradient_func signature
1078
+ if hasattr(self, '_original_non_default_args'):
1079
+ args_fun = list(kwargs_fun.values())
1080
+ kwargs_fun = {
1081
+ k: v for k, v in zip(self._original_non_default_args, args_fun)
1082
+ }
1083
+ # Append the direction to the kwargs_fun as first input
1084
+ kwargs_fun_grad_input = {**direction_fun, **kwargs_fun}
1085
+
1086
+ # Form 1 of gradient (callable)
1087
+ if callable(self._gradient_func):
1088
+ grad = self._gradient_func(**kwargs_fun_grad_input)
1089
+ grad_is_par = False # Assume gradient is function value
1090
+
1091
+ # Form 2 of gradient (tuple of callables)
1092
+ elif isinstance(self._gradient_func, tuple):
1093
+ grad = []
1094
+ for i, grad_func in enumerate(self._gradient_func):
1095
+ if grad_func is not None:
1096
+ grad.append(grad_func(**kwargs_fun_grad_input))
1097
+ else:
1098
+ grad.append(None)
1099
+ # set the ith item of to_CUQIarray tuple to False
1100
+ # because the ith gradient is None
1101
+ to_CUQIarray = to_CUQIarray[:i] + (False,) + to_CUQIarray[i + 1 :]
1102
+ grad_is_par = False # Assume gradient is function value
1103
+
1104
+ grad = self._apply_chain_rule_to_account_for_domain_geometry_gradient(
1105
+ kwargs_par, grad, grad_is_par, to_CUQIarray
1106
+ )
1107
+
1108
+ if len(grad) == 1:
1109
+ return list(grad.values())[0]
1110
+ elif self._gradient_output_stacked:
1111
+ return np.hstack(
1112
+ [
1113
+ (
1114
+ v.to_numpy()
1115
+ if isinstance(v, CUQIarray)
1116
+ else force_ndarray(v, flatten=True)
1117
+ )
1118
+ for v in list(grad.values())
1119
+ ]
1120
+ )
448
1121
 
449
1122
  return grad
450
-
451
- def _check_gradient_can_be_computed(self, direction, wrt):
452
- """ Private function that checks if the gradient can be computed. By
1123
+
1124
+ def _check_gradient_can_be_computed(self, direction, kwargs_dict):
1125
+ """Private function that checks if the gradient can be computed. By
453
1126
  raising an error for the cases where the gradient cannot be computed."""
454
1127
 
455
1128
  # Raise an error if _gradient_func function is not set
456
1129
  if self._gradient_func is None:
457
1130
  raise NotImplementedError("Gradient is not implemented for this model.")
458
-
459
- # Raise error if either the direction or wrt are Samples object
460
- if isinstance(direction, Samples) or isinstance(wrt, Samples):
461
- raise ValueError("cuqi.samples.Samples input values for arguments `direction` and `wrt` are not supported")
462
-
1131
+
1132
+ # Raise an error if either the direction or kwargs are Samples objects
1133
+ if isinstance(direction, Samples) or any(
1134
+ isinstance(x, Samples) for x in kwargs_dict.values()
1135
+ ):
1136
+ raise NotImplementedError(
1137
+ "Gradient is not implemented for input of type Samples."
1138
+ )
1139
+
463
1140
  # Raise an error if range_geometry is not in the list returned by
464
- # `_get_identity_geometries()`. i.e. The Jacobian of its
465
- # par2fun map is not identity.
466
- #TODO: Add range geometry gradient to the chain rule
1141
+ # `_get_identity_geometries()`. i.e. The Jacobian of its
1142
+ # par2fun map is not identity.
1143
+ # TODO: Add range geometry gradient to the chain rule
467
1144
  if not type(self.range_geometry) in _get_identity_geometries():
468
- raise NotImplementedError("Gradient not implemented for model {} with range geometry {}".format(self,self.range_geometry))
469
-
470
- # Raise an error if domain_geometry does not have gradient attribute and
471
- # is not in the list returned by `_get_identity_geometries()`. i.e. the
472
- # Jacobian of its par2fun map is not identity.
473
- if not hasattr(self.domain_geometry, 'gradient') and \
474
- not type(self.domain_geometry) in _get_identity_geometries():
475
- raise NotImplementedError("Gradient not implemented for model {} with domain geometry {}".format(self,self.domain_geometry))
476
-
477
- def _handle_random_variable(self, x):
478
- """ Private function that handles the case of the input being a random variable. """
479
- # If random variable is not a leaf-type node (e.g. internal node) we return NotImplemented
480
- if not isinstance(x.tree, cuqi.experimental.algebra.VariableNode):
481
- return NotImplemented
482
-
483
- # In leaf-type node case we simply change the parameter name of model to match the random variable name
484
- dist = x.distribution
485
- if dist.dim != self.domain_dim:
486
- raise ValueError("Attempting to match parameter name of Model with given random variable, but random variable dimension does not match model domain dimension.")
487
-
488
- new_model = copy(self)
489
- new_model._non_default_args = [dist.name]
490
- return new_model
1145
+ raise NotImplementedError(
1146
+ "Gradient is not implemented for model {} with range geometry {}. You can use one of the geometries in the list {}.".format(
1147
+ self,
1148
+ self.range_geometry,
1149
+ [i_g.__name__ for i_g in _get_identity_geometries()],
1150
+ )
1151
+ )
1152
+
1153
+ # Raise an error if domain_geometry (or its components in case of
1154
+ # _ProductGeometry) does not have gradient attribute and is not in the
1155
+ # list returned by `_get_identity_geometries()`. i.e. The Jacobian of its
1156
+ # par2fun map is not identity.
1157
+ domain_geometries = (
1158
+ self.domain_geometry.geometries
1159
+ if isinstance(
1160
+ self.domain_geometry, cuqi.experimental.geometry._ProductGeometry
1161
+ )
1162
+ else [self.domain_geometry]
1163
+ )
1164
+ for domain_geometry in domain_geometries:
1165
+ if (
1166
+ not hasattr(domain_geometry, "gradient")
1167
+ and not type(domain_geometry) in _get_identity_geometries()
1168
+ ):
1169
+ raise NotImplementedError(
1170
+ "Gradient is not implemented for model \n{}\nwith domain geometry (or domain geometry component) {}. The domain geometries should have gradient method or be from the geometries in the list {}.".format(
1171
+ self,
1172
+ domain_geometry,
1173
+ [i_g.__name__ for i_g in _get_identity_geometries()],
1174
+ )
1175
+ )
1176
+
1177
+ def _apply_chain_rule_to_account_for_domain_geometry_gradient(self,
1178
+ kwargs_par,
1179
+ grad,
1180
+ grad_is_par,
1181
+ to_CUQIarray):
1182
+ """ Private function that applies the chain rule to account for the
1183
+ gradient of the domain geometry. That is, it computes the gradient of
1184
+ the function values with respect to the parameters values."""
1185
+ # Create list of domain geometries
1186
+ geometries = (
1187
+ self.domain_geometry.geometries
1188
+ if isinstance(self.domain_geometry, cuqi.experimental.geometry._ProductGeometry)
1189
+ else [self.domain_geometry]
1190
+ )
1191
+
1192
+ # turn grad_is_par to a tuple of bools if it is not already
1193
+ if isinstance(grad_is_par, bool):
1194
+ grad_is_par = tuple([grad_is_par]*self.number_of_inputs)
1195
+
1196
+ # If the domain geometry is a _ProductGeometry and the gradient is
1197
+ # stacked, split it
1198
+ if (
1199
+ isinstance(
1200
+ self.domain_geometry, cuqi.experimental.geometry._ProductGeometry
1201
+ )
1202
+ and not isinstance(grad, (list, tuple))
1203
+ and isinstance(grad, np.ndarray)
1204
+ ):
1205
+ grad = np.split(grad, self.domain_geometry.stacked_par_split_indices)
1206
+
1207
+ # If the domain geometry is not a _ProductGeometry, turn grad into a
1208
+ # list of length 1, so that we can iterate over it
1209
+ if not isinstance(self.domain_geometry, cuqi.experimental.geometry._ProductGeometry):
1210
+ grad = [grad]
1211
+
1212
+ # apply the gradient of each geometry component
1213
+ grad_kwargs = {}
1214
+ for i, (k, v_par) in enumerate(kwargs_par.items()):
1215
+ if hasattr(geometries[i], 'gradient') and grad[i] is not None:
1216
+ grad_kwargs[k] = geometries[i].gradient(grad[i], v_par)
1217
+ # update the ith component of grad_is_par to True
1218
+ grad_is_par = grad_is_par[:i] + (True,) + grad_is_par[i+1:]
1219
+ else:
1220
+ grad_kwargs[k] = grad[i]
1221
+
1222
+ # convert the computed gradient to parameters
1223
+ grad = self._2par(geometry=self.domain_geometry,
1224
+ to_CUQIarray=to_CUQIarray,
1225
+ is_par=grad_is_par,
1226
+ **grad_kwargs)
1227
+
1228
+ return grad
1229
+
1230
+ def __call__(self, *args, **kwargs):
1231
+ return self.forward(*args, **kwargs)
491
1232
 
492
1233
  def __len__(self):
493
1234
  return self.range_dim
494
1235
 
495
1236
  def __repr__(self) -> str:
496
- return "CUQI {}: {} -> {}.\n Forward parameters: {}.".format(self.__class__.__name__,self.domain_geometry,self.range_geometry,cuqi.utilities.get_non_default_args(self))
497
-
1237
+ kwargs = {}
1238
+ if self.number_of_inputs > 1:
1239
+ pad = " " * len("CUQI {}: ".format(self.__class__.__name__))
1240
+ kwargs["pad"]=pad
1241
+ return "CUQI {}: {} -> {}.\n Forward parameters: {}.".format(self.__class__.__name__,self.domain_geometry.__repr__(**kwargs),self.range_geometry,self._non_default_args)
498
1242
 
499
1243
  class AffineModel(Model):
500
1244
  """ Model class representing an affine model, i.e. a linear operator with a fixed shift. For linear models, represented by a linear operator only, see :class:`~cuqi.model.LinearModel`.
@@ -533,7 +1277,7 @@ class AffineModel(Model):
533
1277
  if hasattr(linear_operator, '__matmul__') and hasattr(linear_operator, 'T'):
534
1278
  if linear_operator_adjoint is not None:
535
1279
  raise ValueError("Adjoint of linear operator should not be provided when linear operator is a matrix. If you want to provide an adjoint, use a callable function for the linear operator.")
536
-
1280
+
537
1281
  matrix = linear_operator
538
1282
 
539
1283
  linear_operator = lambda x: matrix@x
@@ -559,11 +1303,50 @@ class AffineModel(Model):
559
1303
  if linear_operator_adjoint is not None and not callable(linear_operator_adjoint):
560
1304
  raise TypeError("Linear operator adjoint must be defined as a callable function of some kind")
561
1305
 
1306
+ # If linear operator is of type Model, it needs to be a LinearModel
1307
+ if isinstance(linear_operator, Model) and not isinstance(
1308
+ linear_operator, LinearModel
1309
+ ):
1310
+ raise TypeError(
1311
+ "The linear operator should be a LinearModel object, a callable function or a matrix."
1312
+ )
1313
+
1314
+ # If the adjoint operator is of type Model, it needs to be a LinearModel
1315
+ if isinstance(linear_operator_adjoint, Model) and not isinstance(
1316
+ linear_operator_adjoint, LinearModel
1317
+ ):
1318
+ raise TypeError(
1319
+ "The adjoint linear operator should be a LinearModel object, a callable function or a matrix."
1320
+ )
1321
+
1322
+ # Additional checks if the linear_operator is not a LinearModel:
1323
+ if not isinstance(linear_operator, LinearModel):
1324
+ # Ensure the linear operator has exactly one input argument
1325
+ if len(cuqi.utilities.get_non_default_args(linear_operator)) != 1:
1326
+ raise ValueError(
1327
+ "The linear operator should have exactly one input argument."
1328
+ )
1329
+ # Ensure the adjoint linear operator has exactly one input argument
1330
+ if (
1331
+ linear_operator_adjoint is not None
1332
+ and len(cuqi.utilities.get_non_default_args(linear_operator_adjoint))
1333
+ != 1
1334
+ ):
1335
+ raise ValueError(
1336
+ "The adjoint linear operator should have exactly one input argument."
1337
+ )
1338
+
562
1339
  # Check size of shift and match against range_geometry
563
1340
  if not np.isscalar(shift):
564
1341
  if len(shift) != range_geometry.par_dim:
565
1342
  raise ValueError("The shift should have the same dimension as the range geometry.")
566
1343
 
1344
+ # Store linear operator privately
1345
+ # Note: we need to set the _linear_operator before calling the
1346
+ # super().__init__() because it is needed when calling the property
1347
+ # _non_default_args within the super().__init__()
1348
+ self._linear_operator = linear_operator
1349
+
567
1350
  # Initialize Model class
568
1351
  super().__init__(linear_operator, range_geometry, domain_geometry)
569
1352
 
@@ -573,20 +1356,27 @@ class AffineModel(Model):
573
1356
  # Store shift as private attribute
574
1357
  self._shift = shift
575
1358
 
576
- # Store linear operator privately
577
- self._linear_operator = linear_operator
578
1359
 
579
1360
  # Store adjoint function
580
1361
  self._linear_operator_adjoint = linear_operator_adjoint
581
1362
 
582
1363
  # Define gradient
583
- self._gradient_func = lambda direction, wrt: linear_operator_adjoint(direction)
1364
+ self._gradient_func = lambda direction, *args, **kwargs: linear_operator_adjoint(direction)
584
1365
 
585
1366
  # Update forward function to include shift (overwriting the one from Model class)
586
1367
  self._forward_func = lambda *args, **kwargs: linear_operator(*args, **kwargs) + shift
587
1368
 
588
- # Use arguments from user's callable linear operator (overwriting those found by Model class)
589
- self._non_default_args = cuqi.utilities.get_non_default_args(linear_operator)
1369
+ # Set stored_non_default_args to None
1370
+ self._stored_non_default_args = None
1371
+
1372
+ @property
1373
+ def _non_default_args(self):
1374
+ if self._stored_non_default_args is None:
1375
+ # Use arguments from user's callable linear operator
1376
+ self._stored_non_default_args = cuqi.utilities.get_non_default_args(
1377
+ self._linear_operator
1378
+ )
1379
+ return self._stored_non_default_args
590
1380
 
591
1381
  @property
592
1382
  def shift(self):
@@ -599,19 +1389,35 @@ class AffineModel(Model):
599
1389
  self._shift = value
600
1390
  self._forward_func = lambda *args, **kwargs: self._linear_operator(*args, **kwargs) + value
601
1391
 
602
- def _forward_func_no_shift(self, x, is_par=True):
603
- """ Helper function for computing the forward operator without the shift. """
604
- return self._apply_func(self._linear_operator,
605
- self.range_geometry,
606
- self.domain_geometry,
607
- x, is_par)
1392
+ def _forward_func_no_shift(self, *args, is_par=True, **kwargs):
1393
+ """Helper function for computing the forward operator without the shift."""
1394
+ # convert args to kwargs
1395
+ kwargs = self._parse_args_add_to_kwargs(
1396
+ *args, **kwargs, map_name="model", is_par=is_par
1397
+ )
1398
+ args = list(kwargs.values())
1399
+ # if model has _original_non_default_args, we use it to replace the
1400
+ # kwargs keys so that it matches self._linear_operator signature
1401
+ if hasattr(self, '_original_non_default_args'):
1402
+ kwargs = {k:v for k,v in zip(self._original_non_default_args, args)}
1403
+ return self._apply_func(self._linear_operator, **kwargs, is_par=is_par)
1404
+
1405
+ def _adjoint_func_no_shift(self, *args, is_par=True, **kwargs):
1406
+ """Helper function for computing the adjoint operator without the shift."""
1407
+ # convert args to kwargs
1408
+ kwargs = self._parse_args_add_to_kwargs(
1409
+ *args,
1410
+ **kwargs,
1411
+ map_name='adjoint',
1412
+ is_par=is_par,
1413
+ non_default_args=cuqi.utilities.get_non_default_args(
1414
+ self._linear_operator_adjoint
1415
+ ),
1416
+ )
1417
+ return self._apply_func(
1418
+ self._linear_operator_adjoint, **kwargs, is_par=is_par, fwd=False
1419
+ )
608
1420
 
609
- def _adjoint_func_no_shift(self, y, is_par=True):
610
- """ Helper function for computing the adjoint operator without the shift. """
611
- return self._apply_func(self._linear_operator_adjoint,
612
- self.domain_geometry,
613
- self.range_geometry,
614
- y, is_par)
615
1421
 
616
1422
  class LinearModel(AffineModel):
617
1423
  """Model based on a Linear forward operator.
@@ -677,50 +1483,67 @@ class LinearModel(AffineModel):
677
1483
  Note that you would need to specify the range and domain geometries in this
678
1484
  case as they cannot be inferred from the forward and adjoint functions.
679
1485
  """
680
-
1486
+
681
1487
  def __init__(self, forward, adjoint=None, range_geometry=None, domain_geometry=None):
682
1488
 
683
- #Initialize as AffineModel with shift=0
1489
+ # Initialize as AffineModel with shift=0
684
1490
  super().__init__(forward, 0, adjoint, range_geometry, domain_geometry)
685
1491
 
686
- def adjoint(self, y, is_par=True):
1492
+ def adjoint(self, *args, is_par=True, **kwargs):
687
1493
  """ Adjoint of the model.
688
1494
 
689
- Adjoint converts the input to function values (if needed) using the range geometry of the model.
690
- Adjoint converts the output function values to parameters using the range geometry of the model.
1495
+ Adjoint converts the input to function values (if needed) using the range geometry of the model then applies the adjoint operator to the function values and converts the output function values to parameters using the domain geometry of the model.
691
1496
 
692
1497
  Parameters
693
1498
  ----------
694
- y : ndarray or cuqi.array.CUQIarray
695
- The adjoint model input.
1499
+ *args : ndarrays or cuqi.array.CUQIarray object
1500
+ Positional arguments for the adjoint operator ( maximum one argument). The adjoint operator input can be specified as either positional arguments or keyword arguments but not both.
1501
+
1502
+ **kwargs : keyword arguments
1503
+ keyword arguments for the adjoint operator (maximum one argument). The adjoint operator input can be specified as either positional arguments or keyword arguments but not both.
1504
+
1505
+ If the input is specified as keyword arguments, the keys should match the non_default_args of the model.
696
1506
 
697
1507
  Returns
698
1508
  -------
699
1509
  ndarray or cuqi.array.CUQIarray
700
1510
  The adjoint model output. Always returned as parameters.
701
1511
  """
1512
+ kwargs = self._parse_args_add_to_kwargs(
1513
+ *args,
1514
+ **kwargs,
1515
+ map_name='adjoint',
1516
+ is_par=is_par,
1517
+ non_default_args=cuqi.utilities.get_non_default_args(
1518
+ self._linear_operator_adjoint
1519
+ ),
1520
+ )
1521
+
1522
+ # length of kwargs should be 1
1523
+ if len(kwargs) > 1:
1524
+ raise ValueError(
1525
+ "The adjoint operator input is specified by more than one argument. This is not supported."
1526
+ )
702
1527
  if self._linear_operator_adjoint is None:
703
1528
  raise ValueError("No adjoint operator was provided for this model.")
704
- return self._apply_func(self._linear_operator_adjoint,
705
- self.domain_geometry,
706
- self.range_geometry,
707
- y, is_par)
1529
+ return self._apply_func(
1530
+ self._linear_operator_adjoint, **kwargs, is_par=is_par, fwd=False
1531
+ )
1532
+
1533
+ def __matmul__(self, *args, **kwargs):
1534
+ return self.forward(*args, **kwargs)
708
1535
 
709
- def __matmul__(self, x):
710
- return self.forward(x)
711
-
712
1536
  def get_matrix(self):
713
1537
  """
714
1538
  Returns an ndarray with the matrix representing the forward operator.
715
1539
  """
716
-
717
1540
  if self._matrix is not None: #Matrix exists so return it
718
1541
  return self._matrix
719
1542
  else:
720
- #TODO: Can we compute this faster while still in sparse format?
1543
+ # TODO: Can we compute this faster while still in sparse format?
721
1544
  mat = csc_matrix((self.range_dim,0)) #Sparse (m x 1 matrix)
722
1545
  e = np.zeros(self.domain_dim)
723
-
1546
+
724
1547
  # Stacks sparse matrices on csc matrix
725
1548
  for i in range(self.domain_dim):
726
1549
  e[i] = 1
@@ -728,7 +1551,7 @@ class LinearModel(AffineModel):
728
1551
  mat = hstack((mat,col_vec[:,None])) #mat[:,i] = self.forward(e)
729
1552
  e[i] = 0
730
1553
 
731
- #Store matrix for future use
1554
+ # Store matrix for future use
732
1555
  self._matrix = mat
733
1556
 
734
1557
  return self._matrix
@@ -736,61 +1559,129 @@ class LinearModel(AffineModel):
736
1559
  @property
737
1560
  def T(self):
738
1561
  """Transpose of linear model. Returns a new linear model acting as the transpose."""
739
- transpose = LinearModel(self.adjoint, self.forward, self.domain_geometry, self.range_geometry)
1562
+ transpose = LinearModel(
1563
+ self._linear_operator_adjoint,
1564
+ self._linear_operator,
1565
+ self.domain_geometry,
1566
+ self.range_geometry,
1567
+ )
740
1568
  if self._matrix is not None:
741
1569
  transpose._matrix = self._matrix.T
742
1570
  return transpose
743
-
1571
+
744
1572
 
745
1573
  class PDEModel(Model):
746
1574
  """
747
1575
  Model based on an underlying cuqi.pde.PDE.
748
- In the forward operation the PDE is assembled, solved and observed.
1576
+ In the forward method the PDE is assembled, solved and observed.
749
1577
 
750
1578
  Parameters
751
1579
  -----------
752
- forward : 2D ndarray or callable function.
753
- Forward operator assembling, solving and observing the pde.
1580
+ PDE : cuqi.pde.PDE
1581
+ The PDE that specifies the forward operator.
754
1582
 
755
- range_geometry : integer or cuqi.geometry.Geometry (optional)
1583
+ range_geometry : integer or cuqi.geometry.Geometry, optional
756
1584
  If integer is given, a cuqi.geometry._DefaultGeometry is created with dimension of the integer.
757
1585
 
758
- domain_geometry : integer or cuqi.geometry.Geometry (optional)
1586
+ domain_geometry : integer or cuqi.geometry.Geometry, optional
759
1587
  If integer is given, a cuqi.geometry._DefaultGeometry is created with dimension of the integer.
760
1588
 
761
1589
 
762
1590
  :ivar range_geometry: The geometry representing the range.
763
1591
  :ivar domain_geometry: The geometry representing the domain.
764
1592
  """
765
- def __init__(self, PDE: cuqi.pde.PDE, range_geometry, domain_geometry):
1593
+ def __init__(self, PDE: cuqi.pde.PDE, range_geometry, domain_geometry, **kwargs):
766
1594
 
767
1595
  if not isinstance(PDE, cuqi.pde.PDE):
768
1596
  raise ValueError("PDE needs to be a cuqi PDE.")
1597
+ # PDE needs to be set before calling super().__init__
1598
+ # for the property _non_default_args to work
1599
+ self.pde = PDE
1600
+ self._stored_non_default_args = None
1601
+
1602
+ # If gradient or jacobian is not provided, we create it from the PDE
1603
+ if not np.any([k in kwargs.keys() for k in ["gradient", "jacobian"]]):
1604
+ # Create gradient or jacobian function to pass to the Model based on
1605
+ # the PDE object. The dictionary derivative_kwarg contains the
1606
+ # created function along with the function type (either "gradient"
1607
+ # or "jacobian")
1608
+ derivative_kwarg = self._create_derivative_function()
1609
+ # append derivative_kwarg to kwargs
1610
+ kwargs.update(derivative_kwarg)
1611
+
1612
+ super().__init__(forward=self._forward_func_pde,
1613
+ range_geometry=range_geometry,
1614
+ domain_geometry=domain_geometry,
1615
+ **kwargs)
1616
+
1617
+ @property
1618
+ def _non_default_args(self):
1619
+ if self._stored_non_default_args is None:
1620
+ # extract the non-default arguments of the PDE
1621
+ self._stored_non_default_args = self.pde._non_default_args
769
1622
 
770
- super().__init__(self._forward_func, range_geometry, domain_geometry, gradient=self._gradient_func)
1623
+ return self._stored_non_default_args
771
1624
 
772
- self.pde = PDE
1625
+ def _forward_func_pde(self, **kwargs):
773
1626
 
774
- def _forward_func(self, x):
775
-
776
- self.pde.assemble(parameter=x)
1627
+ self.pde.assemble(**kwargs)
777
1628
 
778
1629
  sol, info = self.pde.solve()
779
1630
 
780
1631
  obs = self.pde.observe(sol)
781
1632
 
782
1633
  return obs
783
-
784
- def _gradient_func(self, direction, wrt):
785
- """ Compute direction-Jacobian product (gradient) of the model. """
1634
+
1635
+ def _create_derivative_function(self):
1636
+ """Private function that creates the derivative function (gradient or
1637
+ jacobian) based on the PDE object. The derivative function is created as
1638
+ a lambda function that takes the direction and the parameters as input
1639
+ and returns the gradient or jacobian of the PDE. This private function
1640
+ returns a dictionary with the created function and the function type
1641
+ (either "gradient" or "jacobian")."""
1642
+
786
1643
  if hasattr(self.pde, "gradient_wrt_parameter"):
787
- return self.pde.gradient_wrt_parameter(direction, wrt)
1644
+ # Build the string that will be used to create the lambda function
1645
+ function_str = (
1646
+ "lambda direction, "
1647
+ + ", ".join(self._non_default_args)
1648
+ + ", pde_func: pde_func(direction, "
1649
+ + ", ".join(self._non_default_args)
1650
+ + ")"
1651
+ )
1652
+
1653
+ # create the lambda function from the string
1654
+ function = eval(function_str)
1655
+
1656
+ # create partial function from the lambda function with gradient_wrt_parameter
1657
+ # as the first argument
1658
+ grad_func = partial(function, pde_func=self.pde.gradient_wrt_parameter)
1659
+
1660
+ # Return the gradient function
1661
+ return {"gradient": grad_func}
1662
+
788
1663
  elif hasattr(self.pde, "jacobian_wrt_parameter"):
789
- return direction@self.pde.jacobian_wrt_parameter(wrt)
1664
+ # Build the string that will be used to create the lambda function
1665
+ function_str = (
1666
+ "lambda "
1667
+ + ", ".join(self._non_default_args)
1668
+ + ", pde_func: pde_func( "
1669
+ + ", ".join(self._non_default_args)
1670
+ + ")"
1671
+ )
1672
+
1673
+ # create the lambda function from the string
1674
+ function = eval(function_str)
1675
+
1676
+ # create partial function from the lambda function with jacobian_wrt_parameter
1677
+ # as the first argument
1678
+ jacobian_func = partial(function, pde_func=self.pde.jacobian_wrt_parameter)
1679
+
1680
+ # Return the jacobian function
1681
+ return {"jacobian": jacobian_func}
790
1682
  else:
791
- raise NotImplementedError("Gradient is not implemented for this model.")
1683
+ return {} # empty dictionary if no gradient or jacobian is found
792
1684
 
793
1685
  # Add the underlying PDE class name to the repr.
794
1686
  def __repr__(self) -> str:
795
1687
  return super().__repr__()+"\n PDE: {}.".format(self.pde.__class__.__name__)
796
-