CUQIpy 1.0.0.post0.dev127__tar.gz → 1.0.0.post0.dev145__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

Files changed (108) hide show
  1. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/CUQIpy.egg-info/PKG-INFO +1 -1
  2. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/PKG-INFO +1 -1
  3. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/_version.py +3 -3
  4. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/_cwmh.py +2 -22
  5. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/_langevin_algorithm.py +12 -41
  6. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/_mh.py +4 -12
  7. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/_pcn.py +5 -15
  8. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/_rto.py +2 -7
  9. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/_sampler.py +109 -24
  10. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/CUQIpy.egg-info/SOURCES.txt +0 -0
  11. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/CUQIpy.egg-info/dependency_links.txt +0 -0
  12. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/CUQIpy.egg-info/requires.txt +0 -0
  13. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/CUQIpy.egg-info/top_level.txt +0 -0
  14. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/LICENSE +0 -0
  15. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/README.md +0 -0
  16. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/__init__.py +0 -0
  17. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/_messages.py +0 -0
  18. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/array/__init__.py +0 -0
  19. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/array/_array.py +0 -0
  20. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/config.py +0 -0
  21. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/__init__.py +0 -0
  22. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/_data.py +0 -0
  23. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/astronaut.npz +0 -0
  24. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/camera.npz +0 -0
  25. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/cat.npz +0 -0
  26. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/cookie.png +0 -0
  27. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/data/satellite.mat +0 -0
  28. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/density/__init__.py +0 -0
  29. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/density/_density.py +0 -0
  30. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/diagnostics.py +0 -0
  31. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/__init__.py +0 -0
  32. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_beta.py +0 -0
  33. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_cauchy.py +0 -0
  34. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_cmrf.py +0 -0
  35. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_custom.py +0 -0
  36. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_distribution.py +0 -0
  37. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_gamma.py +0 -0
  38. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_gaussian.py +0 -0
  39. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_gmrf.py +0 -0
  40. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_inverse_gamma.py +0 -0
  41. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_joint_distribution.py +0 -0
  42. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_laplace.py +0 -0
  43. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_lmrf.py +0 -0
  44. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_lognormal.py +0 -0
  45. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_normal.py +0 -0
  46. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_posterior.py +0 -0
  47. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/distribution/_uniform.py +0 -0
  48. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/__init__.py +0 -0
  49. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/experimental/mcmc/__init__.py +0 -0
  50. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/geometry/__init__.py +0 -0
  51. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/geometry/_geometry.py +0 -0
  52. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/implicitprior/__init__.py +0 -0
  53. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  54. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  55. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/likelihood/__init__.py +0 -0
  56. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/likelihood/_likelihood.py +0 -0
  57. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/model/__init__.py +0 -0
  58. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/model/_model.py +0 -0
  59. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/operator/__init__.py +0 -0
  60. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/operator/_operator.py +0 -0
  61. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/pde/__init__.py +0 -0
  62. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/pde/_pde.py +0 -0
  63. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/problem/__init__.py +0 -0
  64. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/problem/_problem.py +0 -0
  65. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/__init__.py +0 -0
  66. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_conjugate.py +0 -0
  67. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_conjugate_approx.py +0 -0
  68. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_cwmh.py +0 -0
  69. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_gibbs.py +0 -0
  70. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_hmc.py +0 -0
  71. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_langevin_algorithm.py +0 -0
  72. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_laplace_approximation.py +0 -0
  73. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_mh.py +0 -0
  74. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_pcn.py +0 -0
  75. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_rto.py +0 -0
  76. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/sampler/_sampler.py +0 -0
  77. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/samples/__init__.py +0 -0
  78. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/samples/_samples.py +0 -0
  79. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/solver/__init__.py +0 -0
  80. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/solver/_solver.py +0 -0
  81. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/testproblem/__init__.py +0 -0
  82. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/testproblem/_testproblem.py +0 -0
  83. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/utilities/__init__.py +0 -0
  84. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/utilities/_get_python_variable_name.py +0 -0
  85. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/cuqi/utilities/_utilities.py +0 -0
  86. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/pyproject.toml +0 -0
  87. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/requirements.txt +0 -0
  88. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/setup.cfg +0 -0
  89. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/setup.py +0 -0
  90. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_MRFs.py +0 -0
  91. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_abstract_distribution_density.py +0 -0
  92. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_bayesian_inversion.py +0 -0
  93. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_density.py +0 -0
  94. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_distribution.py +0 -0
  95. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_distributions_shape.py +0 -0
  96. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_geometry.py +0 -0
  97. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_implicit_priors.py +0 -0
  98. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_joint_distribution.py +0 -0
  99. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_likelihood.py +0 -0
  100. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_model.py +0 -0
  101. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_pde.py +0 -0
  102. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_posterior.py +0 -0
  103. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_problem.py +0 -0
  104. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_sampler.py +0 -0
  105. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_samples.py +0 -0
  106. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_solver.py +0 -0
  107. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_testproblem.py +0 -0
  108. {CUQIpy-1.0.0.post0.dev127 → cuqipy-1.0.0.post0.dev145}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev127
3
+ Version: 1.0.0.post0.dev145
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev127
3
+ Version: 1.0.0.post0.dev145
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-04-05T11:26:24+0200",
11
+ "date": "2024-04-15T12:29:37+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "476cc08200034336b46144fac4c819f8298fa587",
15
- "version": "1.0.0.post0.dev127"
14
+ "full-revisionid": "744dccfaf90c7e50a11b7c5a4a59d5438c37525d",
15
+ "version": "1.0.0.post0.dev145"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -144,7 +144,7 @@ class CWMHNew(ProposalBasedSamplerNew):
144
144
 
145
145
  # Propose a sample x_all_components from the proposal distribution
146
146
  # for all the components
147
- target_eval_t = self.current_target
147
+ target_eval_t = self.current_target_logd
148
148
  if isinstance(self.proposal,cuqi.distribution.Distribution):
149
149
  x_all_components = self.proposal(
150
150
  location= self.current_point, scale=self.scale).sample()
@@ -175,7 +175,7 @@ class CWMHNew(ProposalBasedSamplerNew):
175
175
 
176
176
  x_star = x_t.copy()
177
177
 
178
- self.current_target = target_eval_t
178
+ self.current_target_logd = target_eval_t
179
179
  self.current_point = x_t
180
180
 
181
181
  return acc
@@ -199,23 +199,3 @@ class CWMHNew(ProposalBasedSamplerNew):
199
199
  # Update the scale parameter
200
200
  self.scale = np.minimum(scale_temp, np.ones(self.dim))
201
201
  self._scale_temp = scale_temp
202
-
203
- def get_state(self):
204
- current_point = self.current_point
205
- if isinstance(current_point, CUQIarray):
206
- current_point = current_point.to_numpy()
207
-
208
- return {'sampler_type': 'CWMH',
209
- 'current_point': current_point,
210
- 'current_target': self.current_target,
211
- 'scale': self.scale}
212
-
213
- def set_state(self, state):
214
- current_point = state['current_point']
215
- if not isinstance(current_point, CUQIarray):
216
- current_point = CUQIarray(current_point,
217
- geometry=self.target.geometry)
218
-
219
- self.current_point = current_point
220
- self.current_target = state['current_target']
221
- self.scale = state['scale']
@@ -60,14 +60,17 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
60
60
  A Deblur example can be found in demos/demo27_ULA.py
61
61
  # TODO: update demo once sampler merged
62
62
  """
63
+
64
+ _STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale', 'current_target_grad'})
65
+
63
66
  def __init__(self, target, scale=1.0, **kwargs):
64
67
 
65
68
  super().__init__(target, **kwargs)
66
69
 
67
70
  self.scale = scale
68
71
  self.current_point = self.initial_point
69
- self.current_target_eval = self.target.logd(self.current_point)
70
- self.current_target_grad_eval = self.target.gradient(self.current_point)
72
+ self.current_target_logd = self.target.logd(self.current_point)
73
+ self.current_target_grad = self.target.gradient(self.current_point)
71
74
  self._acc = [1] # TODO. Check if we need this
72
75
 
73
76
  def validate_target(self):
@@ -99,15 +102,15 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
99
102
  1 (accepted)
100
103
  """
101
104
  self.current_point = x_star
102
- self.current_target_eval = target_eval_star
103
- self.current_target_grad_eval = target_grad_star
105
+ self.current_target_logd = target_eval_star
106
+ self.current_target_grad = target_grad_star
104
107
  acc = 1
105
108
  return acc
106
109
 
107
110
  def step(self):
108
111
  # propose state
109
112
  xi = cuqi.distribution.Normal(mean=np.zeros(self.dim), std=np.sqrt(self.scale)).sample()
110
- x_star = self.current_point + 0.5*self.scale*self.current_target_grad_eval + xi
113
+ x_star = self.current_point + 0.5*self.scale*self.current_target_grad + xi
111
114
 
112
115
  # evaluate target
113
116
  target_eval_star, target_grad_star = self.target.logd(x_star), self.target.gradient(x_star)
@@ -120,26 +123,6 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
120
123
  def tune(self, skip_len, update_count):
121
124
  pass
122
125
 
123
- def get_state(self):
124
- if isinstance(self.current_point, CUQIarray):
125
- self.current_point = self.current_point.to_numpy()
126
- if isinstance(self.current_target_eval, CUQIarray):
127
- self.current_target_eval = self.current_target_eval.to_numpy()
128
- if isinstance(self.current_target_grad_eval, CUQIarray):
129
- self.current_target_grad_eval = self.current_target_grad_eval.to_numpy()
130
- return {'sampler_type': 'ULA', 'current_point': self.current_point, \
131
- 'current_target_eval': self.current_target_eval, \
132
- 'current_target_grad_eval': self.current_target_grad_eval, \
133
- 'scale': self.scale}
134
-
135
- def set_state(self, state):
136
- temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
137
- self.current_point = temp
138
- temp = CUQIarray(state['current_target_eval'] , geometry=self.target.geometry)
139
- self.current_target_eval = temp
140
- temp = CUQIarray(state['current_target_grad_eval'] , geometry=self.target.geometry)
141
- self.current_target_grad_eval = temp
142
- self.scale = state['scale']
143
126
 
144
127
  class MALANew(ULANew): # Refactor to Proposal-based sampler?
145
128
  """ Metropolis-adjusted Langevin algorithm (MALA) (Roberts and Tweedie, 1996)
@@ -219,9 +202,9 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
219
202
  scaler
220
203
  1 if accepted, 0 otherwise
221
204
  """
222
- log_target_ratio = target_eval_star - self.current_target_eval
205
+ log_target_ratio = target_eval_star - self.current_target_logd
223
206
  log_prop_ratio = self._log_proposal(self.current_point, x_star, target_grad_star) \
224
- - self._log_proposal(x_star, self.current_point, self.current_target_grad_eval)
207
+ - self._log_proposal(x_star, self.current_point, self.current_target_grad)
225
208
  log_alpha = min(0, log_target_ratio + log_prop_ratio)
226
209
 
227
210
  # accept/reject with Metropolis
@@ -229,8 +212,8 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
229
212
  log_u = np.log(np.random.rand())
230
213
  if (log_u <= log_alpha) and (np.isnan(target_eval_star) == False):
231
214
  self.current_point = x_star
232
- self.current_target_eval = target_eval_star
233
- self.current_target_grad_eval = target_grad_star
215
+ self.current_target_logd = target_eval_star
216
+ self.current_target_grad = target_grad_star
234
217
  acc = 1
235
218
  return acc
236
219
 
@@ -241,15 +224,3 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
241
224
  mu = theta_k + ((self.scale)/2)*g_logpi_k
242
225
  misfit = theta_star - mu
243
226
  return -0.5*((1/(self.scale))*(misfit.T @ misfit))
244
-
245
- def get_state(self):
246
- if isinstance(self.current_point, CUQIarray):
247
- self.current_point = self.current_point.to_numpy()
248
- if isinstance(self.current_target_eval, CUQIarray):
249
- self.current_target_eval = self.current_target_eval.to_numpy()
250
- if isinstance(self.current_target_grad_eval, CUQIarray):
251
- self.current_target_grad_eval = self.current_target_grad_eval.to_numpy()
252
- return {'sampler_type': 'MALA', 'current_point': self.current_point, \
253
- 'current_target_eval': self.current_target_eval, \
254
- 'current_target_grad_eval': self.current_target_grad_eval, \
255
- 'scale': self.scale}
@@ -23,6 +23,8 @@ class MHNew(ProposalBasedSamplerNew):
23
23
 
24
24
  """
25
25
 
26
+ _STATE_KEYS = ProposalBasedSamplerNew._STATE_KEYS.union({'scale', '_scale_temp'})
27
+
26
28
  def __init__(self, target, proposal=None, scale=1, **kwargs):
27
29
  super().__init__(target, proposal=proposal, scale=scale, **kwargs)
28
30
  # Due to a bug? in old MH, we must keep track of this extra variable to match behavior.
@@ -54,7 +56,7 @@ class MHNew(ProposalBasedSamplerNew):
54
56
  target_eval_star = self.target.logd(x_star)
55
57
 
56
58
  # ratio and acceptance probability
57
- ratio = target_eval_star - self.current_target # proposal is symmetric
59
+ ratio = target_eval_star - self.current_target_logd # proposal is symmetric
58
60
  alpha = min(0, ratio)
59
61
 
60
62
  # accept/reject
@@ -62,7 +64,7 @@ class MHNew(ProposalBasedSamplerNew):
62
64
  acc = 0
63
65
  if (u_theta <= alpha):
64
66
  self.current_point = x_star
65
- self.current_target = target_eval_star
67
+ self.current_target_logd = target_eval_star
66
68
  acc = 1
67
69
 
68
70
  return acc
@@ -79,13 +81,3 @@ class MHNew(ProposalBasedSamplerNew):
79
81
 
80
82
  # update parameters
81
83
  self.scale = min(self._scale_temp, 1)
82
-
83
- def get_state(self):
84
- return {'sampler_type': 'MH', 'current_point': self.current_point.to_numpy(), 'current_target': self.current_target.to_numpy(), 'scale': self.scale}
85
-
86
- def set_state(self, state):
87
- temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
88
- self.current_point = temp
89
- temp = CUQIarray(state['current_target'] , geometry=self.target.geometry)
90
- self.current_target = temp
91
- self.scale = state['scale']
@@ -5,13 +5,15 @@ from cuqi.array import CUQIarray
5
5
 
6
6
  class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
7
7
 
8
+ _STATE_KEYS = SamplerNew._STATE_KEYS.union({'scale', 'current_likelihood_logd'})
9
+
8
10
  def __init__(self, target, scale=1.0, **kwargs):
9
11
 
10
12
  super().__init__(target, **kwargs)
11
13
 
12
14
  self.scale = scale
13
15
  self.current_point = self.initial_point
14
- self.current_loglike_eval = self._loglikelihood(self.current_point)
16
+ self.current_likelihood_logd = self._loglikelihood(self.current_point)
15
17
 
16
18
  self._acc = [1] # TODO. Check if we need this
17
19
 
@@ -33,7 +35,7 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
33
35
  loglike_eval_star = self._loglikelihood(x_star)
34
36
 
35
37
  # ratio and acceptance probability
36
- ratio = loglike_eval_star - self.current_loglike_eval # proposal is symmetric
38
+ ratio = loglike_eval_star - self.current_likelihood_logd # proposal is symmetric
37
39
  alpha = min(0, ratio)
38
40
 
39
41
  # accept/reject
@@ -41,7 +43,7 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
41
43
  u_theta = np.log(np.random.rand())
42
44
  if (u_theta <= alpha):
43
45
  self.current_point = x_star
44
- self.current_loglike_eval = loglike_eval_star
46
+ self.current_likelihood_logd = loglike_eval_star
45
47
  acc = 1
46
48
 
47
49
  return acc
@@ -87,15 +89,3 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
87
89
 
88
90
  def tune(self, skip_len, update_count):
89
91
  pass
90
-
91
- def get_state(self):
92
- return {'sampler_type': 'PCN', 'current_point': self.current_point.to_numpy(), \
93
- 'current_loglike_eval': self.current_loglike_eval.to_numpy(), \
94
- 'scale': self.scale}
95
-
96
- def set_state(self, state):
97
- temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
98
- self.current_point = temp
99
- temp = CUQIarray(state['current_loglike_eval'] , geometry=self.target.geometry)
100
- self.current_loglike_eval = temp
101
- self.scale = state['scale']
@@ -50,6 +50,7 @@ class LinearRTONew(SamplerNew):
50
50
 
51
51
  if initial_point is None: #TODO: Replace later with a getter
52
52
  self.initial_point = np.zeros(self.dim)
53
+ self._samples = [self.initial_point]
53
54
 
54
55
  self.current_point = self.initial_point
55
56
  self._acc = [1] # TODO. Check if we need this
@@ -188,12 +189,6 @@ class LinearRTONew(SamplerNew):
188
189
  if not hasattr(self.prior, "sqrtprecTimesMean"):
189
190
  raise TypeError("Prior must contain a sqrtprecTimesMean attribute")
190
191
 
191
- def get_state(self): #TODO: LinearRTO only need initial_point for reproducibility?
192
- return {'sampler_type': 'LinearRTO'}
193
-
194
- def set_state(self, state): #TODO: LinearRTO only need initial_point for reproducibility?
195
- pass
196
-
197
192
  class RegularizedLinearRTONew(LinearRTONew):
198
193
  """
199
194
  Regularized Linear RTO (Randomize-Then-Optimize) sampler.
@@ -272,4 +267,4 @@ class RegularizedLinearRTONew(LinearRTONew):
272
267
  maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
273
268
  self.current_point, _ = sim.solve()
274
269
  acc = 1
275
- return acc
270
+ return acc
@@ -21,6 +21,11 @@ class SamplerNew(ABC):
21
21
  Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.
22
22
 
23
23
  """
24
+ _STATE_KEYS = {'current_point'}
25
+ """ Set of keys for the state dictionary. """
26
+
27
+ _HISTORY_KEYS = {'_samples', '_acc'}
28
+ """ Set of keys for the history dictionary. """
24
29
 
25
30
  def __init__(self, target: cuqi.density.Density, initial_point=None, callback=None):
26
31
  """ Initializer for abstract base class for all samplers.
@@ -45,8 +50,10 @@ class SamplerNew(ABC):
45
50
  # Choose initial point if not given
46
51
  if initial_point is None:
47
52
  initial_point = np.ones(self.dim)
53
+
54
+ self.initial_point = initial_point
48
55
 
49
- self._samples = [initial_point]
56
+ self._samples = [initial_point] # Remove. See #324.
50
57
 
51
58
  # ------------ Abstract methods to be implemented by subclasses ------------
52
59
 
@@ -65,29 +72,7 @@ class SamplerNew(ABC):
65
72
  """ Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
66
73
  pass
67
74
 
68
- @abstractmethod
69
- def get_state(self):
70
- """ Return the state of the sampler. """
71
- pass
72
-
73
- @abstractmethod
74
- def set_state(self, state):
75
- """ Set the state of the sampler. """
76
- pass
77
-
78
-
79
75
  # ------------ Public attributes ------------
80
-
81
- @property
82
- def initial_point(self):
83
- """ Return the initial point of the sampler. This is always the first sample. """
84
- return self._samples[0]
85
-
86
- @initial_point.setter
87
- def initial_point(self, value):
88
- """ Set the initial point of the sampler. """
89
- self._samples[0] = value
90
-
91
76
  @property
92
77
  def dim(self):
93
78
  """ Dimension of the target density. """
@@ -109,6 +94,15 @@ class SamplerNew(ABC):
109
94
  self._target = value
110
95
  self.validate_target()
111
96
 
97
+ @property
98
+ def current_point(self):
99
+ """ The current point of the sampler. """
100
+ return self._current_point
101
+
102
+ @current_point.setter
103
+ def current_point(self, value):
104
+ """ Set the current point of the sampler. """
105
+ self._current_point = value
112
106
 
113
107
  # ------------ Public methods ------------
114
108
 
@@ -209,6 +203,94 @@ class SamplerNew(ABC):
209
203
  self._call_callback(self.current_point, len(self._samples)-1)
210
204
 
211
205
  return self
206
+
207
+ def get_state(self) -> dict:
208
+ """ Return the state of the sampler.
209
+
210
+ The state is used when checkpointing the sampler.
211
+
212
+ The state of the sampler is a dictionary with keys 'metadata' and 'state'.
213
+ The 'metadata' key contains information about the sampler type.
214
+ The 'state' key contains the state of the sampler.
215
+
216
+ For example, the state of a "MH" sampler could be:
217
+
218
+ state = {
219
+ 'metadata': {
220
+ 'sampler_type': 'MH'
221
+ },
222
+ 'state': {
223
+ 'current_point': np.array([...]),
224
+ 'current_target_logd': -123.45,
225
+ 'scale': 1.0,
226
+ ...
227
+ }
228
+ }
229
+ """
230
+ state = {
231
+ 'metadata': {
232
+ 'sampler_type': self.__class__.__name__
233
+ },
234
+ 'state': {
235
+ key: getattr(self, key) for key in self._STATE_KEYS
236
+ }
237
+ }
238
+ return state
239
+
240
+ def set_state(self, state: dict):
241
+ """ Set the state of the sampler.
242
+
243
+ The state is used when loading the sampler from a checkpoint.
244
+
245
+ The state of the sampler is a dictionary with keys 'metadata' and 'state'.
246
+
247
+ For example, the state of a "MH" sampler could be:
248
+
249
+ state = {
250
+ 'metadata': {
251
+ 'sampler_type': 'MH'
252
+ },
253
+ 'state': {
254
+ 'current_point': np.array([...]),
255
+ 'current_target_logd': -123.45,
256
+ 'scale': 1.0,
257
+ ...
258
+ }
259
+ }
260
+ """
261
+ if state['metadata']['sampler_type'] != self.__class__.__name__:
262
+ raise ValueError(f"Sampler type in state dictionary ({state['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
263
+
264
+ for key, value in state['state'].items():
265
+ if key in self._STATE_KEYS:
266
+ setattr(self, key, value)
267
+ else:
268
+ raise ValueError(f"Key {key} not recognized in state dictionary of sampler {self.__class__.__name__}.")
269
+
270
+ def get_history(self) -> dict:
271
+ """ Return the history of the sampler. """
272
+ history = {
273
+ 'metadata': {
274
+ 'sampler_type': self.__class__.__name__
275
+ },
276
+ 'history': {
277
+ key: getattr(self, key) for key in self._HISTORY_KEYS
278
+ }
279
+ }
280
+ return history
281
+
282
+ def set_history(self, history: dict):
283
+ """ Set the history of the sampler. """
284
+ if history['metadata']['sampler_type'] != self.__class__.__name__:
285
+ raise ValueError(f"Sampler type in history dictionary ({history['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
286
+
287
+ for key, value in history['history'].items():
288
+ if key in self._HISTORY_KEYS:
289
+ setattr(self, key, value)
290
+ else:
291
+ raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
292
+
293
+ # ------------ Private methods ------------
212
294
 
213
295
  def _call_callback(self, sample, sample_index):
214
296
  """ Calls the callback function. Assumes input is sample and sample index"""
@@ -218,6 +300,9 @@ class SamplerNew(ABC):
218
300
 
219
301
  class ProposalBasedSamplerNew(SamplerNew, ABC):
220
302
  """ Abstract base class for samplers that use a proposal distribution. """
303
+
304
+ _STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale'})
305
+
221
306
  def __init__(self, target, proposal=None, scale=1, **kwargs):
222
307
  """ Initializer for proposal based samplers.
223
308
 
@@ -240,7 +325,7 @@ class ProposalBasedSamplerNew(SamplerNew, ABC):
240
325
  super().__init__(target, **kwargs)
241
326
 
242
327
  self.current_point = self.initial_point
243
- self.current_target = self.target.logd(self.current_point)
328
+ self.current_target_logd = self.target.logd(self.current_point)
244
329
  self.proposal = proposal
245
330
  self.scale = scale
246
331