CUQIpy 1.2.0.post0.dev400__py3-none-any.whl → 1.2.0.post0.dev444__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the package's registry page for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev400
3
+ Version: 1.2.0.post0.dev444
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
2
2
  cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
3
- cuqi/_version.py,sha256=073mgTeoid6lu7OZMsn2gA0lMQNXn1tiqWjv2Vv7SXs,510
3
+ cuqi/_version.py,sha256=Fc1E3Z65WGvlyMnL6AoW6ky07TJ_RDfj_2x-NUvGRqE,510
4
4
  cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
5
5
  cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
6
6
  cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
@@ -36,9 +36,9 @@ cuqi/distribution/_truncated_normal.py,sha256=sZkLYgnkGOyS_3ZxY7iw6L62t-Jh6shzsw
36
36
  cuqi/distribution/_uniform.py,sha256=KA8yQ6ZS3nQGS4PYJ4hpDg6Eq8EQKQvPsIpYfR8fj2w,1967
37
37
  cuqi/experimental/__init__.py,sha256=6PFlmAkWuxWhzVrZz2g10tBBDuH5542G02nIRQfQCNg,128
38
38
  cuqi/experimental/algebra/__init__.py,sha256=btRAWG58ZfdtK0afXKOg60AX7d76KMBjlZa4AWBCCgU,81
39
- cuqi/experimental/algebra/_ast.py,sha256=iJ_umDzTct4O9tZM-ep2NNkrdxR4_PTIEOrxZiwvlc0,9257
40
- cuqi/experimental/algebra/_orderedset.py,sha256=8SxktP1333ByTldtqzU9xLQ5SAFU0V9B-i6U1prVBYk,2019
41
- cuqi/experimental/algebra/_randomvariable.py,sha256=lwOTy9KApqXwJ57VBYUBMrCwpbA7cR91OlT0ZUtiIaE,15841
39
+ cuqi/experimental/algebra/_ast.py,sha256=PdPz19cJMjvnMx4KEzhn4gvxIZX_UViE33Mbttj_5Xw,9873
40
+ cuqi/experimental/algebra/_orderedset.py,sha256=fKysh4pmI4xF7Y5Z6O86ABzg20o4uBs-v8jmLBMrdpo,2849
41
+ cuqi/experimental/algebra/_randomvariable.py,sha256=1VwJjsF5PPmkchGa7mbNCcAgnt19olkvMeHCRAvEVtk,18911
42
42
  cuqi/experimental/geometry/__init__.py,sha256=kgoKegfz3Jhr7fpORB_l55z9zLZRtloTLyXFDh1oF2o,47
43
43
  cuqi/experimental/geometry/_productgeometry.py,sha256=G-hIYnfLiRS5IWD2EPXORNBKNP2zSaCCHAeBlDC_R3I,7177
44
44
  cuqi/experimental/mcmc/__init__.py,sha256=zSqLZmxOqQ-F94C9-gPv7g89TX1XxlrlNm071Eb167I,4487
@@ -52,18 +52,18 @@ cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=NIoCLKL5x89Bxm-JLDLR_NTunRE
52
52
  cuqi/experimental/mcmc/_laplace_approximation.py,sha256=XcGIa2wl9nCSTtAFurejYYOKkDVAJ22q75xQKsyu2nI,5803
53
53
  cuqi/experimental/mcmc/_mh.py,sha256=MXo0ahXP4KGFkaY4HtvcBE-TMQzsMlTmLKzSvpz7drU,2941
54
54
  cuqi/experimental/mcmc/_pcn.py,sha256=wqJBZLuRFSwxihaI53tumAg6AWVuceLMOmXssTetd1A,3374
55
- cuqi/experimental/mcmc/_rto.py,sha256=lzfeUuV8jUiWG-80KQ4if6toVcX7bMv-a0chBZq0vZ4,12021
55
+ cuqi/experimental/mcmc/_rto.py,sha256=pFFzKaEDT2yLp-kstGs0tPdoq20gh3axj3XAszj9xFQ,13905
56
56
  cuqi/experimental/mcmc/_sampler.py,sha256=BZHnpB6s-YSddd46wQSds0vNF61RA58Nc9ZU05WngdU,20184
57
57
  cuqi/experimental/mcmc/_utilities.py,sha256=kUzHbhIS3HYZRbneNBK41IogUYX5dS_bJxqEGm7TQBI,525
58
58
  cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
59
59
  cuqi/geometry/_geometry.py,sha256=5ZNrw6LivxEEw0vrk1eCxKIw-8mkAh7930voRVywDbY,47089
60
60
  cuqi/implicitprior/__init__.py,sha256=6z3lvw-tWDyjZSpB3pYzvijSMK9Zlf1IYqOVTtMD2h4,309
61
61
  cuqi/implicitprior/_regularizedGMRF.py,sha256=rr3R2C1aheuu_KD35MureZKfOwY8O1pkVDHvuaFnFFU,6300
62
- cuqi/implicitprior/_regularizedGaussian.py,sha256=mzaAHq0yz73FZo-OB2iqFMd2i2NNzVv4mjd9-ger8a0,15435
62
+ cuqi/implicitprior/_regularizedGaussian.py,sha256=btpjKUG1byLSu7S3J8N1MZZBuskCEIdmatMASwQEHtE,15656
63
63
  cuqi/implicitprior/_regularizedUnboundedUniform.py,sha256=H2fTOSqYTlDiLxQ7Ya6wnpCUIkpO4qKrkTOsOPnBBeU,3483
64
64
  cuqi/implicitprior/_restorator.py,sha256=Z350XUJEt7N59Qw-SIUaBljQNDJk4Zb0i_KRFrt2DCg,10087
65
65
  cuqi/likelihood/__init__.py,sha256=QXif382iwZ5bT3ZUqmMs_n70JVbbjxbqMrlQYbMn4Zo,1776
66
- cuqi/likelihood/_likelihood.py,sha256=z3AXAbIrv_DjOYh4jy3iDHemuIFUUJu6wdvJ5e2dgW0,6913
66
+ cuqi/likelihood/_likelihood.py,sha256=PuW8ufRefLt6w40JQWqNnEh3YCLxu4pz0h0PcpT8inc,7075
67
67
  cuqi/model/__init__.py,sha256=jgY2-jyxEMC79vkyH9BpfowW7_DbMRjqedOtO5fykXQ,62
68
68
  cuqi/model/_model.py,sha256=LqeMwOSb1oIGpT7g1cmItP_2Q4dmgg8eNPNo0joPUyg,32905
69
69
  cuqi/operator/__init__.py,sha256=0pc9p-KPyl7KtPV0noB0ddI0CP2iYEHw5rbw49D8Njk,136
@@ -86,15 +86,15 @@ cuqi/sampler/_rto.py,sha256=KIs0cDEoYK5I35RwO9fr5eKWeINLsmTLSVBnLdZmzzM,11921
86
86
  cuqi/sampler/_sampler.py,sha256=TkZ_WAS-5Q43oICa-Elc2gftsRTBd7PEDUMDZ9tTGmU,5712
87
87
  cuqi/samples/__init__.py,sha256=vCs6lVk-pi8RBqa6cIN5wyn6u-K9oEf1Na4k1ZMrYv8,44
88
88
  cuqi/samples/_samples.py,sha256=hUc8OnCF9CTCuDTrGHwwzv3wp8mG_6vsJAFvuQ-x0uA,35832
89
- cuqi/solver/__init__.py,sha256=KuNlGxjPphG9tV-46YrmbcSQNhi0HMyhDd_v6V5sRaQ,209
90
- cuqi/solver/_solver.py,sha256=2mil7Gq7InHlQaOfnuTRDB5nw6dqJk9gR8DLFxpw_6g,29481
89
+ cuqi/solver/__init__.py,sha256=KYgAi_8VoAwljTB3S2I87YnJkRtedskLee7hQp_-zp8,220
90
+ cuqi/solver/_solver.py,sha256=hGOFEJ74s0qHvXwt8h0JBdnqE5LGa_yZzNB9cd8zPAs,30661
91
91
  cuqi/testproblem/__init__.py,sha256=DWTOcyuNHMbhEuuWlY5CkYkNDSAqhvsKmJXBLivyblU,202
92
92
  cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7sS0Q,52540
93
93
  cuqi/utilities/__init__.py,sha256=H7xpJe2UinjZftKvE2JuXtTi4DqtkR6uIezStAXwfGg,428
94
94
  cuqi/utilities/_get_python_variable_name.py,sha256=wxpCaj9f3ZtBNqlGmmuGiITgBaTsY-r94lUIlK6UAU4,2043
95
95
  cuqi/utilities/_utilities.py,sha256=Jc4knn80vLoA7kgw9FzXwKVFGaNBOXiA9kgvltZU3Ao,11777
96
- CUQIpy-1.2.0.post0.dev400.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
97
- CUQIpy-1.2.0.post0.dev400.dist-info/METADATA,sha256=Tof8IyOdtibJyUT46nATVE0MdmsP-YDEAKP6qWVAc10,18529
98
- CUQIpy-1.2.0.post0.dev400.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
99
- CUQIpy-1.2.0.post0.dev400.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
100
- CUQIpy-1.2.0.post0.dev400.dist-info/RECORD,,
96
+ CUQIpy-1.2.0.post0.dev444.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
97
+ CUQIpy-1.2.0.post0.dev444.dist-info/METADATA,sha256=F7jDyAUca9QDe7vKeYEH8B8IdBXrdI8QWIhnGhWtqP0,18529
98
+ CUQIpy-1.2.0.post0.dev444.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
99
+ CUQIpy-1.2.0.post0.dev444.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
100
+ CUQIpy-1.2.0.post0.dev444.dist-info/RECORD,,
cuqi/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-01-10T07:03:38+0300",
11
+ "date": "2025-01-17T09:21:17+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "4ef35e31b8a7b894a879a6e8c22487ab96989b95",
15
- "version": "1.2.0.post0.dev400"
14
+ "full-revisionid": "0b7a72688059f892d521f7e332065b2ccdc456bd",
15
+ "version": "1.2.0.post0.dev444"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
cuqi/experimental/algebra/_ast.py CHANGED
@@ -51,6 +51,11 @@ class Node(ABC):
51
51
  """Evaluate node at a given parameter value. This will traverse the sub-tree originated at this node and evaluate it given the recorded operations."""
52
52
  pass
53
53
 
54
+ @abstractmethod
55
+ def condition(self, **kwargs):
56
+ """ Conditions the tree by replacing any VariableNode with a ValueNode if the variable is in the kwargs dictionary. """
57
+ pass
58
+
54
59
  @abstractmethod
55
60
  def __repr__(self):
56
61
  """String representation of the node. Used for printing the AST."""
@@ -129,6 +134,9 @@ class UnaryNode(Node, ABC):
129
134
  def __init__(self, child: Node):
130
135
  self.child = child
131
136
 
137
+ def condition(self, **kwargs):
138
+ return self.__class__(self.child.condition(**kwargs))
139
+
132
140
 
133
141
  class BinaryNode(Node, ABC):
134
142
  """Base class for all binary nodes in the abstract syntax tree.
@@ -155,6 +163,9 @@ class BinaryNode(Node, ABC):
155
163
  self.left = left
156
164
  self.right = right
157
165
 
166
+ def condition(self, **kwargs):
167
+ return self.__class__(self.left.condition(**kwargs), self.right.condition(**kwargs))
168
+
158
169
  def __repr__(self):
159
170
  return f"{self.left} {self.op_symbol} {self.right}"
160
171
 
@@ -205,6 +216,11 @@ class VariableNode(Node):
205
216
  )
206
217
  return kwargs[self.name]
207
218
 
219
+ def condition(self, **kwargs):
220
+ if self.name in kwargs:
221
+ return ValueNode(kwargs[self.name])
222
+ return self
223
+
208
224
  def __repr__(self):
209
225
  return self.name
210
226
 
@@ -226,6 +242,9 @@ class ValueNode(Node):
226
242
  """Returns the value of the node."""
227
243
  return self.value
228
244
 
245
+ def condition(self, **kwargs):
246
+ return self
247
+
229
248
  def __repr__(self):
230
249
  return str(self.value)
231
250
 
cuqi/experimental/algebra/_orderedset.py CHANGED
@@ -19,6 +19,13 @@ class _OrderedSet:
19
19
  """
20
20
  self.dict[item] = None
21
21
 
22
+ def remove(self, item):
23
+ """Remove an item from the set.
24
+
25
+ If the item is not in the set, it raises a KeyError.
26
+ """
27
+ del self.dict[item]
28
+
22
29
  def __contains__(self, item):
23
30
  """Check if an item is in the set.
24
31
 
@@ -47,6 +54,18 @@ class _OrderedSet:
47
54
  for item in other:
48
55
  self.add(item)
49
56
 
57
+ def replace(self, old_item, new_item):
58
+ """Replace old_item with new_item at the same position, preserving order."""
59
+ if old_item not in self.dict:
60
+ raise KeyError(f"{old_item} not in set")
61
+
62
+ items = list(self.dict.keys()) # Preserve order
63
+ index = items.index(old_item) # Find position
64
+ items[index] = new_item # Replace at the same position
65
+
66
+ # Reconstruct the ordered set with the new item in place
67
+ self.dict = dict.fromkeys(items)
68
+
50
69
  def __or__(self, other):
51
70
  """Return a new set that is the union of this set and another set.
52
71
 
@@ -57,3 +76,7 @@ class _OrderedSet:
57
76
  new_set = _OrderedSet(self.dict.keys())
58
77
  new_set.extend(other)
59
78
  return new_set
79
+
80
+ def __repr__(self):
81
+ """Return a string representation of the set."""
82
+ return "_OrderedSet({})".format(list(self.dict.keys()))
cuqi/experimental/algebra/_randomvariable.py CHANGED
@@ -5,8 +5,7 @@ from ._orderedset import _OrderedSet
5
5
  import operator
6
6
  import cuqi
7
7
  from cuqi.distribution import Distribution
8
- from copy import copy
9
-
8
+ from copy import copy, deepcopy
10
9
 
11
10
  class RandomVariable:
12
11
  """ Random variable defined by a distribution with the option to apply algebraic operations on it.
@@ -210,7 +209,7 @@ class RandomVariable:
210
209
  def parameter_names(self) -> str:
211
210
  """ Name of the parameter that the random variable can be evaluated at. """
212
211
  self._inject_name_into_distribution()
213
- return [distribution.name for distribution in self.distributions] # Consider renaming .name to .par_name for distributions
212
+ return [distribution._name for distribution in self.distributions] # Consider renaming .name to .par_name for distributions
214
213
 
215
214
  @property
216
215
  def dim(self):
@@ -239,7 +238,57 @@ class RandomVariable:
239
238
  def is_transformed(self):
240
239
  """ Returns True if the random variable is transformed. """
241
240
  return not isinstance(self.tree, VariableNode)
242
-
241
+
242
+ @property
243
+ def is_cond(self):
244
+ """ Returns True if the random variable is a conditional random variable. """
245
+ return any(dist.is_cond for dist in self.distributions)
246
+
247
+ def condition(self, *args, **kwargs):
248
+ """Condition the random variable on a given value. Only one of either positional or keyword arguments can be passed.
249
+
250
+ Parameters
251
+ ----------
252
+ *args : Any
253
+ Positional arguments to condition the random variable on. The order of the arguments must match the order of the parameter names.
254
+
255
+ **kwargs : Any
256
+ Keyword arguments to condition the random variable on. The keys must match the parameter names.
257
+
258
+ """
259
+
260
+ # Before conditioning, capture repr to ensure all variable names are injected
261
+ self.__repr__()
262
+
263
+ if args and kwargs:
264
+ raise ValueError("Cannot pass both positional and keyword arguments to RandomVariable")
265
+
266
+ if args:
267
+ kwargs = self._parse_args_add_to_kwargs(args, kwargs)
268
+
269
+ # Create a deep copy of the random variable to ensure the original tree is not modified
270
+ new_variable = self._make_copy(deep=True)
271
+
272
+ for kwargs_name in list(kwargs.keys()):
273
+ value = kwargs.pop(kwargs_name)
274
+
275
+ # Condition the tree turning the variable into a constant
276
+ if kwargs_name in self.parameter_names:
277
+ new_variable._tree = new_variable.tree.condition(**{kwargs_name: value})
278
+
279
+ # Condition the random variable on both the distribution parameter name and distribution conditioning variables
280
+ for dist in self.distributions:
281
+ if kwargs_name == dist.name:
282
+ new_variable._remove_distribution(dist.name)
283
+ elif kwargs_name in dist.get_conditioning_variables():
284
+ new_variable._replace_distribution(dist.name, dist(**{kwargs_name: value}))
285
+
286
+ # Check if any kwargs are left unprocessed
287
+ if kwargs:
288
+ raise ValueError(f"Conditioning variables {list(kwargs.keys())} not found in the random variable {self}")
289
+
290
+ return new_variable
291
+
243
292
  @property
244
293
  def _non_default_args(self) -> List[str]:
245
294
  """List of non-default arguments to distribution. This is used to return the correct
@@ -247,13 +296,31 @@ class RandomVariable:
247
296
  """
248
297
  return self.parameter_names
249
298
 
299
+ def _replace_distribution(self, name, new_distribution):
300
+ """ Replace distribution with a given name with a new distribution in the same position of the ordered set. """
301
+ for dist in self.distributions:
302
+ if dist._name == name:
303
+ self._distributions.replace(dist, new_distribution)
304
+ break
305
+
306
+ def _remove_distribution(self, name):
307
+ """ Remove distribution with a given name from the set of distributions. """
308
+ for dist in self.distributions:
309
+ if dist._name == name:
310
+ self._distributions.remove(dist)
311
+ break
312
+
250
313
  def _inject_name_into_distribution(self, name=None):
251
314
  if len(self._distributions) == 1:
252
315
  dist = next(iter(self._distributions))
316
+
317
+ if dist._is_copy:
318
+ dist = dist._original_density
319
+
253
320
  if dist._name is None:
254
321
  if name is None:
255
322
  name = self.name
256
- dist._name = name
323
+ dist.name = name # Inject using setter
257
324
 
258
325
  def _parse_args_add_to_kwargs(self, args, kwargs) -> dict:
259
326
  """ Parse args and add to kwargs if any. Arguments follow self.parameter_names order. """
@@ -293,8 +360,12 @@ class RandomVariable:
293
360
  """ Returns True if this is a copy of another random variable, e.g. by conditioning. """
294
361
  return hasattr(self, '_original_variable') and self._original_variable is not None
295
362
 
296
- def _make_copy(self):
297
- """ Returns a shallow copy of the density keeping a pointer to the original. """
363
+ def _make_copy(self, deep=False) -> 'RandomVariable':
364
+ """ Returns a copy of the density keeping a pointer to the original. """
365
+ if deep:
366
+ new_variable = deepcopy(self)
367
+ new_variable._original_variable = self
368
+ return new_variable
298
369
  new_variable = copy(self)
299
370
  new_variable._distributions = copy(self.distributions)
300
371
  new_variable._tree = copy(self._tree)
cuqi/experimental/mcmc/_rto.py CHANGED
@@ -3,7 +3,7 @@ from scipy.linalg.interpolative import estimate_spectral_norm
3
3
  from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
4
4
  import numpy as np
5
5
  import cuqi
6
- from cuqi.solver import CGLS, FISTA, ADMM
6
+ from cuqi.solver import CGLS, FISTA, ADMM, ScipyLinearLSQ
7
7
  from cuqi.experimental.mcmc import Sampler
8
8
 
9
9
 
@@ -168,6 +168,7 @@ class RegularizedLinearRTO(LinearRTO):
168
168
  Used when prior.proximal is callable.
169
169
  ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers."Foundations and Trends® in Machine learning, 2011.
170
170
  Used when prior.proximal is a list of penalty terms.
171
+ ScipyLinearLSQ: Wrapper for Scipy's lsq_linear for the Trust Region Reflective algorithm. Optionally used when the constraint is either "nonnegativity" or "box".
171
172
 
172
173
  Parameters
173
174
  ------------
@@ -178,7 +179,7 @@ class RegularizedLinearRTO(LinearRTO):
178
179
  Initial point for the sampler. *Optional*.
179
180
 
180
181
  maxit : int
181
- Maximum number of iterations of the inner FISTA/ADMM solver. *Optional*.
182
+ Maximum number of iterations of the FISTA/ADMM/ScipyLinearLSQ solver. *Optional*.
182
183
 
183
184
  inner_max_it : int
184
185
  Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
@@ -188,14 +189,20 @@ class RegularizedLinearRTO(LinearRTO):
188
189
  If stepsize is a float, then this stepsize is used.
189
190
 
190
191
  penalty_parameter : int
191
- Penalty parameter of the inner ADMM solver. *Optional*.
192
+ Penalty parameter of the ADMM solver. *Optional*.
192
193
  See [2] or `cuqi.solver.ADMM`
193
194
 
194
195
  abstol : float
195
- Absolute tolerance of the inner FISTA solver. *Optional*.
196
+ Absolute tolerance of the FISTA/ScipyLinearLSQ solver. *Optional*.
197
+
198
+ inner_abstol : float
199
+ Tolerance parameter for ScipyLinearLSQ's inner solve of the unbounded least-squares problem. *Optional*.
196
200
 
197
201
  adaptive : bool
198
- If True, FISTA is used as inner solver, otherwise ISTA is used. *Optional*.
202
+ If True, FISTA is used as solver, otherwise ISTA is used. *Optional*.
203
+
204
+ solver : string
205
+ If set to "ScipyLinearLSQ", solver is set to cuqi.solver.ScipyLinearLSQ, otherwise FISTA/ISTA or ADMM is used. Note "ScipyLinearLSQ" can only be used with `RegularizedGaussian` of `box` or `nonnegativity` constraint. *Optional*.
199
206
 
200
207
  callback : callable, *Optional*
201
208
  If set this function will be called after every sample.
@@ -204,23 +211,41 @@ class RegularizedLinearRTO(LinearRTO):
204
211
  An example is shown in demos/demo31_callback.py.
205
212
 
206
213
  """
207
- def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, **kwargs):
214
+ def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, **kwargs):
208
215
 
209
216
  super().__init__(target=target, initial_point=initial_point, **kwargs)
210
217
 
211
218
  # Other parameters
212
219
  self.stepsize = stepsize
213
- self.abstol = abstol
220
+ self.abstol = abstol
221
+ self.inner_abstol = inner_abstol
214
222
  self.adaptive = adaptive
215
223
  self.maxit = maxit
216
224
  self.inner_max_it = inner_max_it
217
225
  self.penalty_parameter = penalty_parameter
226
+ self.solver = solver
218
227
 
219
228
  def _initialize(self):
220
229
  super()._initialize()
221
- if self._inner_solver == "FISTA":
230
+ if self.solver is None:
231
+ self.solver = "FISTA" if callable(self.proximal) else "ADMM"
232
+ if self.solver == "FISTA":
222
233
  self._stepsize = self._choose_stepsize()
223
234
 
235
+ @property
236
+ def solver(self):
237
+ return self._solver
238
+
239
+ @solver.setter
240
+ def solver(self, value):
241
+ if value == "ScipyLinearLSQ":
242
+ if (self.target.prior._preset == "nonnegativity" or self.target.prior._preset == "box"):
243
+ self._solver = value
244
+ else:
245
+ raise ValueError("ScipyLinearLSQ only supports RegularizedGaussian with box or nonnegativity constraint.")
246
+ else:
247
+ self._solver = value
248
+
224
249
  @property
225
250
  def proximal(self):
226
251
  return self.target.prior.proximal
@@ -229,7 +254,6 @@ class RegularizedLinearRTO(LinearRTO):
229
254
  super().validate_target()
230
255
  if not isinstance(self.target.prior, (cuqi.implicitprior.RegularizedGaussian, cuqi.implicitprior.RegularizedGMRF)):
231
256
  raise TypeError("Prior needs to be RegularizedGaussian or RegularizedGMRF")
232
- self._inner_solver = "FISTA" if callable(self.proximal) else "ADMM"
233
257
 
234
258
  def _choose_stepsize(self):
235
259
  if isinstance(self.stepsize, str):
@@ -254,15 +278,25 @@ class RegularizedLinearRTO(LinearRTO):
254
278
  def step(self):
255
279
  y = self.b_tild + np.random.randn(len(self.b_tild))
256
280
 
257
- if self._inner_solver == "FISTA":
281
+ if self.solver == "FISTA":
258
282
  sim = FISTA(self.M, y, self.proximal,
259
283
  self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
260
- elif self._inner_solver == "ADMM":
284
+ elif self.solver == "ADMM":
261
285
  sim = ADMM(self.M, y, self.proximal,
262
- self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
286
+ self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
287
+ elif self.solver == "ScipyLinearLSQ":
288
+ A_op = sp.sparse.linalg.LinearOperator((sum([llh.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
289
+ matvec=lambda x: self.M(x, 1),
290
+ rmatvec=lambda x: self.M(x, 2)
291
+ )
292
+ sim = ScipyLinearLSQ(A_op, y, self.target.prior._box_bounds,
293
+ max_iter = self.maxit,
294
+ lsmr_maxiter = self.inner_max_it,
295
+ tol = self.abstol,
296
+ lsmr_tol = self.inner_abstol)
263
297
  else:
264
298
  raise ValueError("Choice of solver not supported.")
265
299
 
266
300
  self.current_point, _ = sim.solve()
267
301
  acc = 1
268
- return acc
302
+ return acc
cuqi/implicitprior/_regularizedGaussian.py CHANGED
@@ -113,10 +113,12 @@ class RegularizedGaussian(Distribution):
113
113
  elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
114
114
  self._proximal = lambda z, gamma: ProjectNonnegative(z)
115
115
  self._preset = "nonnegativity"
116
+ self._box_bounds = (np.ones(self.dim)*0, np.ones(self.dim)*np.inf)
116
117
  elif (isinstance(constraint, str) and constraint.lower() == "box"):
117
- lower = optional_regularization_parameters["lower_bound"]
118
- upper = optional_regularization_parameters["upper_bound"]
119
- self._proximal = lambda z, _: ProjectBox(z, lower, upper)
118
+ self._box_lower = optional_regularization_parameters["lower_bound"]
119
+ self._box_upper = optional_regularization_parameters["upper_bound"]
120
+ self._box_bounds = (np.ones(self.dim)*self._box_lower, np.ones(self.dim)*self._box_upper)
121
+ self._proximal = lambda z, _: ProjectBox(z, self._box_lower, self._box_upper)
120
122
  self._preset = "box" # Not supported in Gibbs
121
123
  elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
122
124
  self._strength = optional_regularization_parameters["strength"]
cuqi/likelihood/_likelihood.py CHANGED
@@ -43,6 +43,14 @@ class Likelihood(Density):
43
43
  def name(self, value):
44
44
  self.distribution.name = value
45
45
 
46
+ @property
47
+ def _name(self):
48
+ return self.distribution._name
49
+
50
+ @_name.setter
51
+ def _name(self, value):
52
+ self.distribution._name = value
53
+
46
54
  @property
47
55
  def FD_enabled(self):
48
56
  """ Return FD_enabled of the likelihood from the underlying distribution """
cuqi/solver/__init__.py CHANGED
@@ -2,7 +2,8 @@ from ._solver import (
2
2
  ScipyLBFGSB,
3
3
  ScipyMinimizer,
4
4
  ScipyMaximizer,
5
- ScipyLeastSquares,
5
+ ScipyLSQ,
6
+ ScipyLinearLSQ,
6
7
  CGLS,
7
8
  LM,
8
9
  PDHG,
cuqi/solver/_solver.py CHANGED
@@ -164,7 +164,7 @@ class ScipyMaximizer(ScipyMinimizer):
164
164
 
165
165
 
166
166
 
167
- class ScipyLeastSquares(object):
167
+ class ScipyLSQ(object):
168
168
  """Wrapper for :meth:`scipy.optimize.least_squares`.
169
169
 
170
170
  Solve nonlinear least-squares problems with bounds:
@@ -227,6 +227,44 @@ class ScipyLeastSquares(object):
227
227
  sol = solution['x']
228
228
  return sol, info
229
229
 
230
+ class ScipyLinearLSQ(object):
231
+ """Wrapper for :meth:`scipy.optimize.lsq_linear`.
232
+
233
+ Solve linear least-squares problems with bounds:
234
+
235
+ .. math::
236
+
237
+ \min \|A x - b\|_2^2
238
+
239
+ subject to :math:`lb <= x <= ub`.
240
+
241
+ Parameters
242
+ ----------
243
+ A : ndarray, LinearOperator
244
+ Design matrix (system matrix).
245
+ b : ndarray
246
+ The right-hand side of the linear system.
247
+ bounds : 2-tuple of array_like or scipy.optimize Bounds
248
+ Bounds for variables.
249
+ kwargs : Other keyword arguments passed to Scipy's `lsq_linear`. See documentation of `scipy.optimize.lsq_linear` for details.
250
+ """
251
+ def __init__(self, A, b, bounds=(-np.inf, np.inf), **kwargs):
252
+ self.A = A
253
+ self.b = b
254
+ self.bounds = bounds
255
+ self.kwargs = kwargs
256
+
257
+ def solve(self):
258
+ """Runs optimization algorithm and returns solution and optimization information.
259
+
260
+ Returns
261
+ ----------
262
+ solution : Tuple
263
+ Solution found (array_like) and optimization information (dictionary).
264
+ """
265
+ res = opt.lsq_linear(self.A, self.b, bounds=self.bounds, **self.kwargs)
266
+ x = res.pop('x')
267
+ return x, res
230
268
 
231
269
 
232
270
  class CGLS(object):