CUQIpy 1.2.0.post0.dev501__py3-none-any.whl → 1.2.0.post0.dev522__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev501
3
+ Version: 1.2.0.post0.dev522
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
2
2
  cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
3
- cuqi/_version.py,sha256=-wStiLQQiCB-PTnm1EIuY2Bqw9TKX4zIsPlS3OkQHg8,510
3
+ cuqi/_version.py,sha256=JIWqq8WjGrOEv1FekTSRr2dHyBuR8Q4UxWr6ffpxqww,510
4
4
  cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
5
5
  cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
6
6
  cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
@@ -42,7 +42,7 @@ cuqi/experimental/algebra/_randomvariable.py,sha256=isbFtIWsWXF-yF5Vb56nLy4MCkQM
42
42
  cuqi/experimental/geometry/__init__.py,sha256=kgoKegfz3Jhr7fpORB_l55z9zLZRtloTLyXFDh1oF2o,47
43
43
  cuqi/experimental/geometry/_productgeometry.py,sha256=G-hIYnfLiRS5IWD2EPXORNBKNP2zSaCCHAeBlDC_R3I,7177
44
44
  cuqi/experimental/mcmc/__init__.py,sha256=zSqLZmxOqQ-F94C9-gPv7g89TX1XxlrlNm071Eb167I,4487
45
- cuqi/experimental/mcmc/_conjugate.py,sha256=MT_On2figPyYPVwrL19ocRFzfcJIvG2SDr3yRSmCSno,18971
45
+ cuqi/experimental/mcmc/_conjugate.py,sha256=yyN5ZKcz8xuWDa7wGJLAcx8xN_XpE_Hvg0DzcJTG9-g,19488
46
46
  cuqi/experimental/mcmc/_conjugate_approx.py,sha256=jmxe2FEbO9fwpc8opyjJ2px0oed3dGyj0qDwyHo4aOk,3545
47
47
  cuqi/experimental/mcmc/_cwmh.py,sha256=50v3uZaWhlVnfrEB5-lB_7pn8QoUVBe-xWxKGKbmNHg,7234
48
48
  cuqi/experimental/mcmc/_direct.py,sha256=9pQS_2Qk2-ybt6m8WTfPoKetcxQ00WaTRN85-Z0FrBY,777
@@ -52,15 +52,15 @@ cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=NIoCLKL5x89Bxm-JLDLR_NTunRE
52
52
  cuqi/experimental/mcmc/_laplace_approximation.py,sha256=XcGIa2wl9nCSTtAFurejYYOKkDVAJ22q75xQKsyu2nI,5803
53
53
  cuqi/experimental/mcmc/_mh.py,sha256=MXo0ahXP4KGFkaY4HtvcBE-TMQzsMlTmLKzSvpz7drU,2941
54
54
  cuqi/experimental/mcmc/_pcn.py,sha256=wqJBZLuRFSwxihaI53tumAg6AWVuceLMOmXssTetd1A,3374
55
- cuqi/experimental/mcmc/_rto.py,sha256=pFFzKaEDT2yLp-kstGs0tPdoq20gh3axj3XAszj9xFQ,13905
55
+ cuqi/experimental/mcmc/_rto.py,sha256=pnhgBR63A_OboHto-I9_o-GsY5yZaH3wbU8S6eXLXX0,13931
56
56
  cuqi/experimental/mcmc/_sampler.py,sha256=BZHnpB6s-YSddd46wQSds0vNF61RA58Nc9ZU05WngdU,20184
57
57
  cuqi/experimental/mcmc/_utilities.py,sha256=kUzHbhIS3HYZRbneNBK41IogUYX5dS_bJxqEGm7TQBI,525
58
58
  cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
59
59
  cuqi/geometry/_geometry.py,sha256=tsWMca6E-KEXwr_LhjwP7Lsdi5TWCyu0T956Cj5LEXQ,47091
60
60
  cuqi/implicitprior/__init__.py,sha256=6z3lvw-tWDyjZSpB3pYzvijSMK9Zlf1IYqOVTtMD2h4,309
61
- cuqi/implicitprior/_regularizedGMRF.py,sha256=Ck1JGo8jTb-Z8zDQl4-shMEB2_T0Az9fBpbxDjwatRU,6308
62
- cuqi/implicitprior/_regularizedGaussian.py,sha256=QlaloekbKojhdXVdmSEFwq__T15XoKBt3uL75mdi0KU,14935
63
- cuqi/implicitprior/_regularizedUnboundedUniform.py,sha256=Ez7TuyR3Y9Km4qeqGnUJl5tQ8-G3assAQm_id4yBNlI,3491
61
+ cuqi/implicitprior/_regularizedGMRF.py,sha256=BUeT4rwJzary9K56fkxCNGCeKZd-2VSgOT8XNHxFPRE,6345
62
+ cuqi/implicitprior/_regularizedGaussian.py,sha256=whTitoB5ZvWg8jm7ipugR_1ouK1M2EGrfwAnr46xDnE,19395
63
+ cuqi/implicitprior/_regularizedUnboundedUniform.py,sha256=uHGYYnTjVxdPbY-5JwocFOH0sHRfGrrLiHWahzH9R8A,3533
64
64
  cuqi/implicitprior/_restorator.py,sha256=Z350XUJEt7N59Qw-SIUaBljQNDJk4Zb0i_KRFrt2DCg,10087
65
65
  cuqi/likelihood/__init__.py,sha256=QXif382iwZ5bT3ZUqmMs_n70JVbbjxbqMrlQYbMn4Zo,1776
66
66
  cuqi/likelihood/_likelihood.py,sha256=PuW8ufRefLt6w40JQWqNnEh3YCLxu4pz0h0PcpT8inc,7075
@@ -73,7 +73,7 @@ cuqi/pde/_pde.py,sha256=WRkOYyIdT_T3aZepRh0aS9C5nBbUZUcHaA80iSRvgoo,12572
73
73
  cuqi/problem/__init__.py,sha256=JxJty4JqHTOqSG6NeTGiXRQ7OLxiRK9jvVq3lXLeIRw,38
74
74
  cuqi/problem/_problem.py,sha256=31ByO279-6hM8PhWjwD5k7i9aBAkk9S1tcgMzxv1PiQ,38256
75
75
  cuqi/sampler/__init__.py,sha256=D-dYa0gFgIwQukP8_VKhPGmlGKXbvVo7YqaET4SdAeQ,382
76
- cuqi/sampler/_conjugate.py,sha256=ztmUR3V3qZk9zelKx48ULnmMs_zKTDUfohc256VOIe8,2753
76
+ cuqi/sampler/_conjugate.py,sha256=x5OsFk1zDm2tvoFsSxbCKwjSqBHUGbcUvcTwDOvL-tw,2841
77
77
  cuqi/sampler/_conjugate_approx.py,sha256=xX-X71EgxGnZooOY6CIBhuJTs3dhcKfoLnoFxX3CO2g,1938
78
78
  cuqi/sampler/_cwmh.py,sha256=VlAVT1SXQU0yD5ZeR-_ckWvX-ifJrMweFFdFbxdfB_k,7775
79
79
  cuqi/sampler/_gibbs.py,sha256=N7qcePwMkRtxINN5JF0FaMIdDCXZGqsfKjfha_KHCck,8627
@@ -93,8 +93,8 @@ cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7s
93
93
  cuqi/utilities/__init__.py,sha256=RB84VstmFcZgPOz58LKSzOvCVebbeKDcKl9MGk-EwoA,515
94
94
  cuqi/utilities/_get_python_variable_name.py,sha256=wxpCaj9f3ZtBNqlGmmuGiITgBaTsY-r94lUIlK6UAU4,2043
95
95
  cuqi/utilities/_utilities.py,sha256=gc9YAj7wFKzyZTE1H5iI_1Tt4AtjT1g5l1-zxBdH-Co,15281
96
- CUQIpy-1.2.0.post0.dev501.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
97
- CUQIpy-1.2.0.post0.dev501.dist-info/METADATA,sha256=qAa9TTLdoTmo7L8fOcRrE_9sj6Sfk2z2DpfCY8eitD0,18529
98
- CUQIpy-1.2.0.post0.dev501.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
99
- CUQIpy-1.2.0.post0.dev501.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
100
- CUQIpy-1.2.0.post0.dev501.dist-info/RECORD,,
96
+ CUQIpy-1.2.0.post0.dev522.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
97
+ CUQIpy-1.2.0.post0.dev522.dist-info/METADATA,sha256=LpjrZejnpUX-w1ewT2GAXeRihKYXiE4WYBNp3OGq5eU,18529
98
+ CUQIpy-1.2.0.post0.dev522.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
99
+ CUQIpy-1.2.0.post0.dev522.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
100
+ CUQIpy-1.2.0.post0.dev522.dist-info/RECORD,,
cuqi/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-01-30T19:07:50+0100",
11
+ "date": "2025-01-31T10:17:23+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "ead079a625cad175f6a0159d46c90d7cdf80bc0f",
15
- "version": "1.2.0.post0.dev501"
14
+ "full-revisionid": "0303b8f2322f41aeb6304e705b7bbcfde5e2de23",
15
+ "version": "1.2.0.post0.dev522"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -147,7 +147,7 @@ class _RegularizedGaussianGammaPair(_ConjugatePair):
147
147
  if self.target.prior.dim != 1:
148
148
  raise ValueError("RegularizedGaussian-Gamma conjugacy only works with univariate ModifiedHalfNormal prior")
149
149
 
150
- if self.target.likelihood.distribution.preset not in ["nonnegativity"]:
150
+ if self.target.likelihood.distribution.preset["constraint"] not in ["nonnegativity"]:
151
151
  raise ValueError("RegularizedGaussian-Gamma conjugacy only works with implicit regularized Gaussian likelihood with nonnegativity constraints")
152
152
 
153
153
  key_value_pairs = _get_conjugate_parameter(self.target)
@@ -183,7 +183,7 @@ class _RegularizedUnboundedUniformGammaPair(_ConjugatePair):
183
183
  if self.target.prior.dim != 1:
184
184
  raise ValueError("RegularizedUnboundedUniform-Gamma conjugacy only works with univariate Gamma prior")
185
185
 
186
- if self.target.likelihood.distribution.preset not in ["l1", "tv"]:
186
+ if self.target.likelihood.distribution.preset["regularization"] not in ["l1", "tv"]:
187
187
  raise ValueError("RegularizedUnboundedUniform-Gamma conjugacy only works with implicit regularized Gaussian likelihood with l1 or tv regularization")
188
188
 
189
189
  key_value_pairs = _get_conjugate_parameter(self.target)
@@ -203,12 +203,7 @@ class _RegularizedUnboundedUniformGammaPair(_ConjugatePair):
203
203
 
204
204
  # Compute likelihood quantities
205
205
  x = self.target.likelihood.data
206
- if self.target.likelihood.distribution.preset == "l1":
207
- m = count_nonzero(x)
208
- elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, Continuous1D):
209
- m = count_constant_components_1D(x)
210
- elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
211
- m = count_constant_components_2D(self.target.likelihood.distribution.geometry.par2fun(x))
206
+ m = _compute_sparsity_level(self.target)
212
207
 
213
208
  reg_op = self.target.likelihood.distribution._regularization_oper
214
209
  reg_strength = self.target.likelihood.distribution(np.array([1])).strength
@@ -224,7 +219,7 @@ class _RegularizedGaussianModifiedHalfNormalPair(_ConjugatePair):
224
219
  if self.target.prior.dim != 1:
225
220
  raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy only works with univariate ModifiedHalfNormal prior")
226
221
 
227
- if self.target.likelihood.distribution.preset not in ["l1", "tv"]:
222
+ if self.target.likelihood.distribution.preset["regularization"] not in ["l1", "tv"]:
228
223
  raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy only works with implicit regularized Gaussian likelihood with l1 or tv regularization")
229
224
 
230
225
  key_value_pairs = _get_conjugate_parameter(self.target)
@@ -254,13 +249,8 @@ class _RegularizedGaussianModifiedHalfNormalPair(_ConjugatePair):
254
249
  x = self.target.likelihood.data
255
250
  mu = self.target.likelihood.distribution.mean
256
251
  L = self.target.likelihood.distribution(np.array([1])).sqrtprec
257
-
258
- if self.target.likelihood.distribution.preset == "l1":
259
- m = count_nonzero(x)
260
- elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, Continuous1D):
261
- m = count_constant_components_1D(x)
262
- elif self.target.likelihood.distribution.preset == "tv" and isinstance(self.target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
263
- m = count_constant_components_2D(self.target.likelihood.distribution.geometry.par2fun(x))
252
+
253
+ m = _compute_sparsity_level(self.target)
264
254
 
265
255
  reg_op = self.target.likelihood.distribution._regularization_oper
266
256
  reg_strength = self.target.likelihood.distribution(np.array([1])).strength
@@ -275,6 +265,26 @@ class _RegularizedGaussianModifiedHalfNormalPair(_ConjugatePair):
275
265
  return ModifiedHalfNormal(conj_alpha, conj_beta, conj_gamma)
276
266
 
277
267
 
268
+ def _compute_sparsity_level(target):
269
+ """Computes the sparsity level in accordance with Section 4 from [2],"""
270
+ x = target.likelihood.data
271
+ if target.likelihood.distribution.preset["constraint"] == "nonnegativity":
272
+ if target.likelihood.distribution.preset["regularization"] == "l1":
273
+ m = count_nonzero(x)
274
+ elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
275
+ m = count_constant_components_1D(x, lower = 0.0)
276
+ elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
277
+ m = count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x), lower = 0.0)
278
+ else: # No constraints, only regularization
279
+ if target.likelihood.distribution.preset["regularization"] == "l1":
280
+ m = count_nonzero(x)
281
+ elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
282
+ m = count_constant_components_1D(x)
283
+ elif target.likelihood.distribution.preset["regularization"] == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
284
+ m = count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x))
285
+ return m
286
+
287
+
278
288
  def _get_conjugate_parameter(target):
279
289
  """Extract the conjugate parameter name (e.g. d), and returns the mutable variable that is defined by the conjugate parameter, e.g. cov and its value e.g. lambda d:1/d"""
280
290
  par_name = target.prior.name
@@ -239,7 +239,7 @@ class RegularizedLinearRTO(LinearRTO):
239
239
  @solver.setter
240
240
  def solver(self, value):
241
241
  if value == "ScipyLinearLSQ":
242
- if (self.target.prior._preset == "nonnegativity" or self.target.prior._preset == "box"):
242
+ if (self.target.prior.preset["constraint"] == "nonnegativity" or self.target.prior.preset["constraint"] == "box"):
243
243
  self._solver = value
244
244
  else:
245
245
  raise ValueError("ScipyLinearLSQ only supports RegularizedGaussian with box or nonnegativity constraint.")
@@ -68,6 +68,7 @@ class RegularizedGMRF(RegularizedGaussian):
68
68
  # Init from abstract distribution class
69
69
  super(Distribution, self).__init__(**kwargs)
70
70
 
71
+ self._force_list = False
71
72
  self._parse_regularization_input_arguments(proximal, projector, constraint, regularization, args)
72
73
 
73
74
 
@@ -2,9 +2,10 @@ from cuqi.utilities import get_non_default_args
2
2
  from cuqi.distribution import Distribution, Gaussian
3
3
  from cuqi.solver import ProjectNonnegative, ProjectBox, ProximalL1
4
4
  from cuqi.geometry import Continuous1D, Continuous2D, Image2D
5
- from cuqi.operator import FirstOrderFiniteDifference
5
+ from cuqi.operator import FirstOrderFiniteDifference, Operator
6
6
 
7
7
  import numpy as np
8
+ import scipy.sparse as sparse
8
9
  from copy import copy
9
10
 
10
11
 
@@ -48,6 +49,8 @@ class RegularizedGaussian(Distribution):
48
49
  min_z 0.5||x-z||_2^2+scale*g(x).
49
50
  If list of tuples (callable proximal operator of f_i, linear operator L_i):
50
51
  Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
52
+ Each linear operator needs to have the '__matmul__', 'T' and 'shape' attributes;
53
+ this includes numpy.ndarray, scipy.sparse.sparray, scipy.sparse.linalg.LinearOperator and cuqi.operator.Operator.
51
54
  The corresponding regularization takes the form
52
55
  sum_i f_i(L_i x),
53
56
  where the sum ranges from 1 to an arbitrary n.
@@ -88,59 +91,137 @@ class RegularizedGaussian(Distribution):
88
91
  # Init from abstract distribution class
89
92
  super().__init__(**kwargs)
90
93
 
94
+ self._force_list = False
91
95
  self._parse_regularization_input_arguments(proximal, projector, constraint, regularization, optional_regularization_parameters)
92
96
 
93
97
  def _parse_regularization_input_arguments(self, proximal, projector, constraint, regularization, optional_regularization_parameters):
94
98
  """ Parse regularization input arguments with guarding statements and store internal states """
95
99
 
96
- # Check that only one of proximal, projector, constraint or regularization is provided
97
- if (proximal is not None) + (projector is not None) + (constraint is not None) + (regularization is not None) != 1:
98
- raise ValueError("Precisely one of proximal, projector, constraint or regularization needs to be provided.")
100
+ # Guards checking whether the regularization inputs are valid
101
+ if (proximal is not None) + (projector is not None) + max((constraint is not None), (regularization is not None)) == 0:
102
+ raise ValueError("At least some constraint or regularization has to be specified for RegularizedGaussian")
103
+
104
+ if (proximal is not None) + (projector is not None) == 2:
105
+ raise ValueError("Only one of proximal or projector can be used.")
106
+
107
+ if (proximal is not None) + (projector is not None) + max((constraint is not None), (regularization is not None)) > 1:
108
+ raise ValueError("User-defined proximals and projectors cannot be combined with pre-defined constraints and regularization.")
109
+
110
+ # Branch between user-defined and preset
111
+ if (proximal is not None) + (projector is not None) >= 1:
112
+ self._parse_user_specified_input(proximal, projector)
113
+ else:
114
+ # Set constraint and regularization presets for use with Gibbs
115
+ self._preset = {"constraint": None,
116
+ "regularization": None}
99
117
 
118
+ self._parse_preset_constraint_input(constraint, optional_regularization_parameters)
119
+ self._parse_preset_regularization_input(regularization, optional_regularization_parameters)
120
+
121
+ # Merge
122
+ self._merge_predefined_option()
123
+
124
+ def _parse_user_specified_input(self, proximal, projector):
125
+ # Guard for checking partial validy of proximals or projectors
126
+ if proximal is not None:
127
+ if callable(proximal):
128
+ if len(get_non_default_args(proximal)) != 2:
129
+ raise ValueError("Proximal should take 2 arguments.")
130
+ elif isinstance(proximal, list):
131
+ for val in proximal:
132
+ if len(val) != 2:
133
+ raise ValueError("Each value in the proximal list needs to consistent of two elements: a proximal operator and a linear operator.")
134
+ if callable(val[0]):
135
+ if len(get_non_default_args(val[0])) != 2:
136
+ raise ValueError("Proximal should take 2 arguments.")
137
+ else:
138
+ raise ValueError("Proximal operators need to be callable.")
139
+ if not (hasattr(val[1], '__matmul__') and hasattr(val[1], 'T') and hasattr(val[1], 'shape')):
140
+ raise ValueError("Linear operator not supported, must have '__matmul__', 'T' and 'shape' attributes.")
141
+ else:
142
+ raise ValueError("Proximal needs to be callable or a list. See documentation.")
143
+
100
144
  if projector is not None:
101
- if not callable(projector):
102
- raise ValueError("Projector needs to be callable.")
103
- if len(get_non_default_args(projector)) != 1:
104
- raise ValueError("Projector should take 1 argument.")
145
+ if callable(projector):
146
+ if len(get_non_default_args(projector)) != 1:
147
+ raise ValueError("Projector should take 1 argument.")
148
+ else:
149
+ raise ValueError("Projector needs to be callable")
105
150
 
106
- # Preset information, for use in Gibbs
107
- self._preset = None
108
-
151
+ # Set user-defined proximals or projectors
109
152
  if proximal is not None:
110
- # No need to generate the proximal and associated information
111
- self.proximal = proximal
112
- elif projector is not None:
153
+ self._preset = None
154
+ self._proximal = proximal
155
+ return
156
+
157
+ if projector is not None:
158
+ self._preset = None
113
159
  self._proximal = lambda z, gamma: projector(z)
114
- elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
115
- self._proximal = lambda z, gamma: ProjectNonnegative(z)
116
- self._preset = "nonnegativity"
117
- self._box_bounds = (np.ones(self.dim)*0, np.ones(self.dim)*np.inf)
118
- elif (isinstance(constraint, str) and constraint.lower() == "box"):
119
- self._box_lower = optional_regularization_parameters["lower_bound"]
120
- self._box_upper = optional_regularization_parameters["upper_bound"]
121
- self._box_bounds = (np.ones(self.dim)*self._box_lower, np.ones(self.dim)*self._box_upper)
122
- self._proximal = lambda z, _: ProjectBox(z, self._box_lower, self._box_upper)
123
- self._preset = "box" # Not supported in Gibbs
124
- elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
125
- self._strength = optional_regularization_parameters["strength"]
126
- self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
127
- self._preset = "l1"
128
- elif (isinstance(regularization, str) and regularization.lower() in ["tv"]):
160
+ return
161
+
162
+ def _parse_preset_constraint_input(self, constraint, optional_regularization_parameters):
163
+ # Create data for constraints
164
+ self._constraint_prox = None
165
+ self._constraint_oper = None
166
+ if constraint is not None:
167
+ if not isinstance(constraint, str):
168
+ raise ValueError("Constraint needs to be specified as a string.")
169
+
170
+ c_lower = constraint.lower()
171
+ if c_lower == "nonnegativity":
172
+ self._constraint_prox = lambda z, gamma: ProjectNonnegative(z)
173
+ self._box_bounds = (np.ones(self.dim)*0, np.ones(self.dim)*np.inf)
174
+ self._preset["constraint"] = "nonnegativity"
175
+ elif c_lower == "box":
176
+ _box_lower = optional_regularization_parameters["lower_bound"]
177
+ _box_upper = optional_regularization_parameters["upper_bound"]
178
+ self._proximal = lambda z, _: ProjectBox(z, _box_lower, _box_upper)
179
+ self._box_bounds = (np.ones(self.dim)*_box_lower, np.ones(self.dim)*_box_upper)
180
+ self._preset["constraint"] = "box"
181
+ else:
182
+ raise ValueError("Constraint not supported.")
183
+
184
+ def _parse_preset_regularization_input(self, regularization, optional_regularization_parameters):
185
+ # Create data for regularization
186
+ self._regularization_prox = None
187
+ self._regularization_oper = None
188
+ if regularization is not None:
189
+ if not isinstance(regularization, str):
190
+ raise ValueError("Regularization needs to be specified as a string.")
191
+
129
192
  self._strength = optional_regularization_parameters["strength"]
130
- if isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
131
- self._transformation = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
193
+ r_lower = regularization.lower()
194
+ if r_lower == "l1":
195
+ self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
196
+ self._preset["regularization"] = "l1"
197
+ elif r_lower == "tv":
198
+ # Store the transformation to reuse when modifying the strength
199
+ if not isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
200
+ raise ValueError("Geometry not supported for total variation")
201
+ self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
202
+ self._regularization_oper = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
203
+ self._preset["regularization"] = "tv"
132
204
  else:
133
- raise ValueError("Geometry not supported for total variation")
134
-
135
- self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
136
- self._regularization_oper = self._transformation
137
-
138
- self._proximal = [(self._regularization_prox, self._regularization_oper)]
139
- self._preset = "tv"
140
- else:
141
- raise ValueError("Regularization not supported")
142
-
205
+ raise ValueError("Regularization not supported.")
143
206
 
207
+ def _merge_predefined_option(self):
208
+ # Check whether it is a single proximal and hence FISTA could be used in RegularizedLinearRTO
209
+ if ((not self._force_list) and
210
+ ((self._constraint_prox is not None) + (self._regularization_prox is not None) == 1) and
211
+ ((self._constraint_oper is not None) + (self._regularization_oper is not None) == 0)):
212
+ if self._constraint_prox is not None:
213
+ self._proximal = self._constraint_prox
214
+ else:
215
+ self._proximal = self._regularization_prox
216
+ return
217
+
218
+ # Merge regularization choices in list for use in ADMM by RegularizedLinearRTO
219
+ self._proximal = []
220
+ if self._constraint_prox is not None:
221
+ self._proximal += [(self._constraint_prox, self._constraint_oper if self._constraint_oper is not None else sparse.eye(self.geometry.par_dim))]
222
+ if self._regularization_prox is not None:
223
+ self._proximal += [(self._regularization_prox, self._regularization_oper if self._regularization_oper is not None else sparse.eye(self.geometry.par_dim))]
224
+
144
225
  @property
145
226
  def transformation(self):
146
227
  return self._transformation
@@ -151,15 +232,15 @@ class RegularizedGaussian(Distribution):
151
232
 
152
233
  @strength.setter
153
234
  def strength(self, value):
154
- if self._preset not in self.regularization_options():
235
+ if self._preset is None or self._preset["regularization"] is None:
155
236
  raise TypeError("Strength is only used when the regularization is set to l1 or TV.")
156
237
 
157
238
  self._strength = value
158
- if self._preset == "tv":
239
+ if self._preset["regularization"] in ["l1", "tv"]:
159
240
  self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
160
- self._proximal = [(self._regularization_prox, self._regularization_oper)]
161
- elif self._preset == "l1":
162
- self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
241
+
242
+ # Create new list of proximals based on updated regularization
243
+ self._merge_predefined_option()
163
244
 
164
245
  # This is a getter only attribute for the underlying Gaussian
165
246
  # It also ensures that the name of the underlying Gaussian
@@ -266,7 +347,7 @@ class RegularizedGaussian(Distribution):
266
347
 
267
348
  def get_mutable_variables(self):
268
349
  mutable_vars = self.gaussian.get_mutable_variables().copy()
269
- if self.preset in self.regularization_options():
350
+ if self.preset is not None and self.preset['regularization'] in ["l1", "tv"]:
270
351
  mutable_vars += ["strength"]
271
352
  return mutable_vars
272
353
 
@@ -63,4 +63,5 @@ class RegularizedUnboundedUniform(RegularizedGaussian):
63
63
  # Init from abstract distribution class
64
64
  super(Distribution, self).__init__(**kwargs)
65
65
 
66
+ self._force_list = False
66
67
  self._parse_regularization_input_arguments(proximal, projector, constraint, regularization, args)
@@ -28,7 +28,7 @@ class Conjugate: # TODO: Subclass from Sampler once updated
28
28
  if not target.prior.dim == 1:
29
29
  raise ValueError("Conjugate sampler only works with univariate Gamma prior")
30
30
 
31
- if isinstance(target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and target.likelihood.distribution.preset not in ["nonnegativity"]:
31
+ if isinstance(target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and (target.likelihood.distribution.preset["constraint"] not in ["nonnegativity"] or target.likelihood.distribution.preset["regularization"] is not None) :
32
32
  raise ValueError("Conjugate sampler only works implicit regularized Gaussian likelihood with nonnegativity constraints")
33
33
 
34
34
  self.target = target