pygeoinf 1.3.7__tar.gz → 1.3.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/PKG-INFO +2 -1
  2. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/__init__.py +41 -0
  3. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/gaussian_measure.py +42 -12
  4. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/plot.py +185 -117
  5. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/preconditioners.py +1 -1
  6. pygeoinf-1.3.9/pygeoinf/subsets.py +845 -0
  7. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/subspaces.py +173 -23
  8. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/symmetric_space/sphere.py +1 -1
  9. pygeoinf-1.3.9/pygeoinf/utils.py +15 -0
  10. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pyproject.toml +2 -1
  11. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/LICENSE +0 -0
  12. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/README.md +0 -0
  13. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/auxiliary.py +0 -0
  14. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/backus_gilbert.py +0 -0
  15. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/checks/__init__.py +0 -0
  16. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/checks/hilbert_space.py +0 -0
  17. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/checks/linear_operators.py +0 -0
  18. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/checks/nonlinear_operators.py +0 -0
  19. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/direct_sum.py +0 -0
  20. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/forward_problem.py +0 -0
  21. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/hilbert_space.py +0 -0
  22. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/inversion.py +0 -0
  23. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/linear_bayesian.py +0 -0
  24. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/linear_forms.py +0 -0
  25. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/linear_operators.py +0 -0
  26. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/linear_optimisation.py +0 -0
  27. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/linear_solvers.py +0 -0
  28. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/nonlinear_forms.py +0 -0
  29. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/nonlinear_operators.py +0 -0
  30. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/nonlinear_optimisation.py +0 -0
  31. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/parallel.py +0 -0
  32. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/random_matrix.py +0 -0
  33. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/symmetric_space/__init__.py +0 -0
  34. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/symmetric_space/circle.py +0 -0
  35. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/symmetric_space/sh_tools.py +0 -0
  36. {pygeoinf-1.3.7 → pygeoinf-1.3.9}/pygeoinf/symmetric_space/symmetric_space.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pygeoinf
3
- Version: 1.3.7
3
+ Version: 1.3.9
4
4
  Summary: A package for solving geophysical inference and inverse problems
5
5
  License: BSD-3-Clause
6
6
  License-File: LICENSE
@@ -20,6 +20,7 @@ Requires-Dist: numpy (>=1.26.0)
20
20
  Requires-Dist: pyqt6 (>=6.0.0)
21
21
  Requires-Dist: pyshtools (>=4.0.0) ; extra == "sphere"
22
22
  Requires-Dist: scipy (>=1.16.1)
23
+ Requires-Dist: threadpoolctl (>=3.6.0,<4.0.0)
23
24
  Description-Content-Type: text/markdown
24
25
 
25
26
  # pygeoinf: A Python Library for Geophysical Inference
@@ -104,6 +104,27 @@ from .nonlinear_optimisation import (
104
104
 
105
105
  from .subspaces import OrthogonalProjector, AffineSubspace, LinearSubspace
106
106
 
107
+ from .subsets import (
108
+ Subset,
109
+ EmptySet,
110
+ UniversalSet,
111
+ Complement,
112
+ Intersection,
113
+ Union,
114
+ SublevelSet,
115
+ LevelSet,
116
+ ConvexSubset,
117
+ Ellipsoid,
118
+ NormalisedEllipsoid,
119
+ EllipsoidSurface,
120
+ Ball,
121
+ Sphere,
122
+ )
123
+
124
+ from .plot import plot_1d_distributions, plot_corner_distributions
125
+
126
+ from .utils import configure_threading
127
+
107
128
  __all__ = [
108
129
  # random_matrix
109
130
  "fixed_rank_random_range",
@@ -182,4 +203,24 @@ __all__ = [
182
203
  "OrthogonalProjector",
183
204
  "AffineSubspace",
184
205
  "LinearSubspace",
206
+ # Subsets
207
+ "Subset",
208
+ "EmptySet",
209
+ "UniversalSet",
210
+ "Complement",
211
+ "Intersection",
212
+ "Union",
213
+ "SublevelSet",
214
+ "LevelSet",
215
+ "ConvexSubset",
216
+ "Ellipsoid",
217
+ "NormalisedEllipsoid",
218
+ "EllipsoidSurface",
219
+ "Ball",
220
+ "Sphere",
221
+ # plot
222
+ "plot_1d_distributions",
223
+ "plot_corner_distributions",
224
+ # utils
225
+ "configure_threading",
185
226
  ]
@@ -27,7 +27,7 @@ import numpy as np
27
27
  from scipy.linalg import eigh
28
28
  from scipy.sparse import diags
29
29
  from scipy.stats import multivariate_normal
30
-
30
+ from joblib import Parallel, delayed
31
31
 
32
32
  from .hilbert_space import EuclideanSpace, HilbertModule, Vector
33
33
 
@@ -44,7 +44,6 @@ from .direct_sum import (
44
44
  # This block is only processed by type checkers, not at runtime.
45
45
  if TYPE_CHECKING:
46
46
  from .hilbert_space import HilbertSpace
47
- from .typing import Vector
48
47
 
49
48
 
50
49
  class GaussianMeasure:
@@ -402,24 +401,52 @@ class GaussianMeasure:
402
401
  raise NotImplementedError("A sample method is not set for this measure.")
403
402
  return self._sample()
404
403
 
405
- def samples(self, n: int) -> List[Vector]:
406
- """Returns a list of n random samples from the measure."""
404
+ def samples(
405
+ self, n: int, /, *, parallel: bool = False, n_jobs: int = -1
406
+ ) -> List[Vector]:
407
+ """
408
+ Returns a list of n random samples from the measure.
409
+
410
+ Args:
411
+ n: Number of samples to draw.
412
+ parallel: If True, draws samples in parallel.
413
+ n_jobs: Number of CPU cores to use. -1 means all available.
414
+ """
407
415
  if n < 1:
408
416
  raise ValueError("Number of samples must be a positive integer.")
409
- return [self.sample() for _ in range(n)]
410
417
 
411
- def sample_expectation(self, n: int) -> Vector:
412
- """Estimates the expectation by drawing n samples."""
418
+ if not parallel:
419
+ return [self.sample() for _ in range(n)]
420
+
421
+ return Parallel(n_jobs=n_jobs)(delayed(self.sample)() for _ in range(n))
422
+
423
+ def sample_expectation(
424
+ self, n: int, /, *, parallel: bool = False, n_jobs: int = -1
425
+ ) -> Vector:
426
+ """
427
+ Estimates the expectation by drawing n samples.
428
+
429
+ Args:
430
+ n: Number of samples to draw.
431
+ parallel: If True, draws samples in parallel.
432
+ n_jobs: Number of CPU cores to use. -1 means all available.
433
+ """
413
434
  if n < 1:
414
435
  raise ValueError("Number of samples must be a positive integer.")
415
- return self.domain.sample_expectation(self.samples(n))
436
+ return self.domain.sample_expectation(
437
+ self.samples(n, parallel=parallel, n_jobs=n_jobs)
438
+ )
416
439
 
417
- def sample_pointwise_variance(self, n: int) -> Vector:
440
+ def sample_pointwise_variance(
441
+ self, n: int, /, *, parallel: bool = False, n_jobs: int = -1
442
+ ) -> Vector:
418
443
  """
419
444
  Estimates the pointwise variance by drawing n samples.
420
445
 
421
- This method is only available if the domain supports vector
422
- multiplication.
446
+ Args:
447
+ n: Number of samples to draw.
448
+ parallel: If True, draws samples in parallel.
449
+ n_jobs: Number of CPU cores to use. -1 means all available.
423
450
  """
424
451
  if not isinstance(self.domain, HilbertModule):
425
452
  raise NotImplementedError(
@@ -428,7 +455,10 @@ class GaussianMeasure:
428
455
  if n < 1:
429
456
  raise ValueError("Number of samples must be a positive integer.")
430
457
 
431
- samples = self.samples(n)
458
+ # Step 1: Draw samples (Parallelized)
459
+ samples = self.samples(n, parallel=parallel, n_jobs=n_jobs)
460
+
461
+ # Step 2: Compute variance using vector arithmetic
432
462
  expectation = self.expectation
433
463
  variance = self.domain.zero
434
464
 
@@ -4,18 +4,21 @@ import numpy as np
4
4
  import scipy.stats as stats
5
5
  from typing import Union, List, Optional
6
6
 
7
+
7
8
  def plot_1d_distributions(
8
9
  posterior_measures: Union[object, List[object]],
10
+ /,
11
+ *,
9
12
  prior_measures: Optional[Union[object, List[object]]] = None,
10
13
  true_value: Optional[float] = None,
11
14
  xlabel: str = "Property Value",
12
15
  title: str = "Prior and Posterior Probability Distributions",
13
16
  figsize: tuple = (12, 7),
14
- show_plot: bool = True
17
+ show_plot: bool = True,
15
18
  ):
16
19
  """
17
20
  Plot 1D probability distributions for prior and posterior measures using dual y-axes.
18
-
21
+
19
22
  Args:
20
23
  posterior_measures: Single measure or list of measures for posterior distributions
21
24
  prior_measures: Single measure or list of measures for prior distributions (optional)
@@ -24,26 +27,44 @@ def plot_1d_distributions(
24
27
  title: Title for the plot
25
28
  figsize: Figure size tuple
26
29
  show_plot: Whether to display the plot
27
-
30
+
28
31
  Returns:
29
32
  fig, (ax1, ax2): Figure and axes objects
30
33
  """
31
-
34
+
32
35
  # Convert single measures to lists for uniform handling
33
36
  if not isinstance(posterior_measures, list):
34
37
  posterior_measures = [posterior_measures]
35
-
38
+
36
39
  if prior_measures is not None and not isinstance(prior_measures, list):
37
40
  prior_measures = [prior_measures]
38
-
41
+
39
42
  # Define color sequences
40
- prior_colors = ['green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
41
- posterior_colors = ['blue', 'red', 'darkgreen', 'orange', 'purple', 'brown', 'pink', 'gray']
42
-
43
+ prior_colors = [
44
+ "green",
45
+ "orange",
46
+ "purple",
47
+ "brown",
48
+ "pink",
49
+ "gray",
50
+ "olive",
51
+ "cyan",
52
+ ]
53
+ posterior_colors = [
54
+ "blue",
55
+ "red",
56
+ "darkgreen",
57
+ "orange",
58
+ "purple",
59
+ "brown",
60
+ "pink",
61
+ "gray",
62
+ ]
63
+
43
64
  # Calculate statistics for all distributions
44
65
  posterior_stats = []
45
66
  for measure in posterior_measures:
46
- if hasattr(measure, 'expectation') and hasattr(measure, 'covariance'):
67
+ if hasattr(measure, "expectation") and hasattr(measure, "covariance"):
47
68
  # For pygeoinf measures
48
69
  mean = measure.expectation[0]
49
70
  var = measure.covariance.matrix(dense=True)[0, 0]
@@ -53,11 +74,11 @@ def plot_1d_distributions(
53
74
  mean = measure.mean[0]
54
75
  std = np.sqrt(measure.cov[0, 0])
55
76
  posterior_stats.append((mean, std))
56
-
77
+
57
78
  prior_stats = []
58
79
  if prior_measures is not None:
59
80
  for measure in prior_measures:
60
- if hasattr(measure, 'expectation') and hasattr(measure, 'covariance'):
81
+ if hasattr(measure, "expectation") and hasattr(measure, "covariance"):
61
82
  # For pygeoinf measures
62
83
  mean = measure.expectation[0]
63
84
  var = measure.covariance.matrix(dense=True)[0, 0]
@@ -67,99 +88,106 @@ def plot_1d_distributions(
67
88
  mean = measure.mean[0]
68
89
  std = np.sqrt(measure.cov[0, 0])
69
90
  prior_stats.append((mean, std))
70
-
91
+
71
92
  # Determine plot range to include all distributions
72
93
  all_means = [stat[0] for stat in posterior_stats]
73
94
  all_stds = [stat[1] for stat in posterior_stats]
74
-
95
+
75
96
  if prior_measures is not None:
76
97
  all_means.extend([stat[0] for stat in prior_stats])
77
98
  all_stds.extend([stat[1] for stat in prior_stats])
78
-
99
+
79
100
  if true_value is not None:
80
101
  all_means.append(true_value)
81
102
  all_stds.append(0) # No std for true value
82
-
103
+
83
104
  # Calculate x-axis range (6 sigma coverage)
84
105
  x_min = min([mean - 6 * std for mean, std in zip(all_means, all_stds) if std > 0])
85
106
  x_max = max([mean + 6 * std for mean, std in zip(all_means, all_stds) if std > 0])
86
-
107
+
87
108
  # Add some padding around true value if needed
88
109
  if true_value is not None:
89
110
  range_size = x_max - x_min
90
111
  x_min = min(x_min, true_value - 0.1 * range_size)
91
112
  x_max = max(x_max, true_value + 0.1 * range_size)
92
-
113
+
93
114
  x_axis = np.linspace(x_min, x_max, 1000)
94
-
115
+
95
116
  # Create the plot with two y-axes
96
117
  fig, ax1 = plt.subplots(figsize=figsize)
97
-
118
+
98
119
  # Plot priors on the first axis (left y-axis) if provided
99
120
  if prior_measures is not None:
100
- color1 = prior_colors[0] if len(prior_measures) > 0 else 'green'
121
+ color1 = prior_colors[0] if len(prior_measures) > 0 else "green"
101
122
  ax1.set_xlabel(xlabel)
102
- ax1.set_ylabel('Prior Probability Density', color=color1)
103
-
123
+ ax1.set_ylabel("Prior Probability Density", color=color1)
124
+
104
125
  for i, (measure, (mean, std)) in enumerate(zip(prior_measures, prior_stats)):
105
126
  color = prior_colors[i % len(prior_colors)]
106
-
127
+
107
128
  # Calculate PDF values using scipy.stats
108
129
  pdf_values = stats.norm.pdf(x_axis, loc=mean, scale=std)
109
-
130
+
110
131
  # Determine label
111
132
  if len(prior_measures) == 1:
112
- label = f'Prior PDF (Mean: {mean:.5f})'
133
+ label = f"Prior PDF (Mean: {mean:.5f})"
113
134
  else:
114
- label = f'Prior {i+1} (Mean: {mean:.5f})'
115
-
116
- ax1.plot(x_axis, pdf_values, color=color, lw=2, linestyle=':', label=label)
135
+ label = f"Prior {i+1} (Mean: {mean:.5f})"
136
+
137
+ ax1.plot(x_axis, pdf_values, color=color, lw=2, linestyle=":", label=label)
117
138
  ax1.fill_between(x_axis, pdf_values, color=color, alpha=0.15)
118
-
119
- ax1.tick_params(axis='y', labelcolor=color1)
120
- ax1.grid(True, linestyle='--')
139
+
140
+ ax1.tick_params(axis="y", labelcolor=color1)
141
+ ax1.grid(True, linestyle="--")
121
142
  else:
122
143
  # If no priors, use the left axis for posteriors
123
144
  ax1.set_xlabel(xlabel)
124
- ax1.set_ylabel('Probability Density')
125
- ax1.grid(True, linestyle='--')
126
-
145
+ ax1.set_ylabel("Probability Density")
146
+ ax1.grid(True, linestyle="--")
147
+
127
148
  # Create second y-axis for posteriors (or use first if no priors)
128
149
  if prior_measures is not None:
129
150
  ax2 = ax1.twinx()
130
- color2 = posterior_colors[0] if len(posterior_measures) > 0 else 'blue'
131
- ax2.set_ylabel('Posterior Probability Density', color=color2)
132
- ax2.tick_params(axis='y', labelcolor=color2)
151
+ color2 = posterior_colors[0] if len(posterior_measures) > 0 else "blue"
152
+ ax2.set_ylabel("Posterior Probability Density", color=color2)
153
+ ax2.tick_params(axis="y", labelcolor=color2)
133
154
  ax2.grid(False)
134
155
  plot_ax = ax2
135
156
  else:
136
157
  plot_ax = ax1
137
- color2 = posterior_colors[0] if len(posterior_measures) > 0 else 'blue'
138
-
158
+ color2 = posterior_colors[0] if len(posterior_measures) > 0 else "blue"
159
+
139
160
  # Plot posteriors
140
- for i, (measure, (mean, std)) in enumerate(zip(posterior_measures, posterior_stats)):
161
+ for i, (measure, (mean, std)) in enumerate(
162
+ zip(posterior_measures, posterior_stats)
163
+ ):
141
164
  color = posterior_colors[i % len(posterior_colors)]
142
-
165
+
143
166
  # Calculate PDF values using scipy.stats
144
167
  pdf_values = stats.norm.pdf(x_axis, loc=mean, scale=std)
145
-
168
+
146
169
  # Determine label
147
170
  if len(posterior_measures) == 1:
148
- label = f'Posterior PDF (Mean: {mean:.5f})'
171
+ label = f"Posterior PDF (Mean: {mean:.5f})"
149
172
  else:
150
- label = f'Posterior {i+1} (Mean: {mean:.5f})'
151
-
173
+ label = f"Posterior {i+1} (Mean: {mean:.5f})"
174
+
152
175
  plot_ax.plot(x_axis, pdf_values, color=color, lw=2, label=label)
153
176
  plot_ax.fill_between(x_axis, pdf_values, color=color, alpha=0.2)
154
-
177
+
155
178
  # Plot true value if provided
156
179
  if true_value is not None:
157
- ax1.axvline(true_value, color='black', linestyle='-', lw=2,
158
- label=f'True Value: {true_value:.5f}')
159
-
180
+ ax1.axvline(
181
+ true_value,
182
+ color="black",
183
+ linestyle="-",
184
+ lw=2,
185
+ label=f"True Value: {true_value:.5f}",
186
+ )
187
+
160
188
  # Create combined legend
161
189
  handles1, labels1 = ax1.get_legend_handles_labels()
162
-
190
+
163
191
  if prior_measures is not None:
164
192
  handles2, labels2 = ax2.get_legend_handles_labels()
165
193
  all_handles = handles1 + handles2
@@ -167,14 +195,14 @@ def plot_1d_distributions(
167
195
  else:
168
196
  all_handles = handles1
169
197
  all_labels = [h.get_label() for h in all_handles]
170
-
171
- fig.legend(all_handles, all_labels, loc='upper right', bbox_to_anchor=(0.9, 0.9))
198
+
199
+ fig.legend(all_handles, all_labels, loc="upper right", bbox_to_anchor=(0.9, 0.9))
172
200
  fig.suptitle(title, fontsize=16)
173
201
  fig.tight_layout(rect=[0, 0, 1, 0.96])
174
-
202
+
175
203
  if show_plot:
176
204
  plt.show()
177
-
205
+
178
206
  if prior_measures is not None:
179
207
  return fig, (ax1, ax2)
180
208
  else:
@@ -183,17 +211,21 @@ def plot_1d_distributions(
183
211
 
184
212
  def plot_corner_distributions(
185
213
  posterior_measure: object,
214
+ /,
215
+ *,
186
216
  true_values: Optional[Union[List[float], np.ndarray]] = None,
187
217
  labels: Optional[List[str]] = None,
188
218
  title: str = "Joint Posterior Distribution",
189
219
  figsize: Optional[tuple] = None,
190
220
  show_plot: bool = True,
191
221
  include_sigma_contours: bool = True,
192
- colormap: str = "Blues"
222
+ colormap: str = "Blues",
223
+ parallel: bool = False,
224
+ n_jobs: int = -1,
193
225
  ):
194
226
  """
195
227
  Create a corner plot for multi-dimensional posterior distributions.
196
-
228
+
197
229
  Args:
198
230
  posterior_measure: Multi-dimensional posterior measure (pygeoinf object)
199
231
  true_values: True values for each dimension (optional)
@@ -203,148 +235,184 @@ def plot_corner_distributions(
203
235
  show_plot: Whether to display the plot
204
236
  include_sigma_contours: Whether to include 1-sigma contour lines
205
237
  colormap: Colormap for 2D plots
206
-
238
+ parallel: Compute dense covariance matrix in parallel, default False.
239
+ n_jobs: Number of cores to use in parallel calculations, default -1.
240
+
207
241
  Returns:
208
242
  fig, axes: Figure and axes array
209
243
  """
210
-
244
+
211
245
  # Extract statistics from the measure
212
- if hasattr(posterior_measure, 'expectation') and hasattr(posterior_measure, 'covariance'):
246
+ if hasattr(posterior_measure, "expectation") and hasattr(
247
+ posterior_measure, "covariance"
248
+ ):
213
249
  mean_posterior = posterior_measure.expectation
214
- cov_posterior = posterior_measure.covariance.matrix(dense=True, parallel=True)
250
+ cov_posterior = posterior_measure.covariance.matrix(
251
+ dense=True, parallel=parallel, n_jobs=n_jobs
252
+ )
215
253
  else:
216
- raise ValueError("posterior_measure must have 'expectation' and 'covariance' attributes")
217
-
254
+ raise ValueError(
255
+ "posterior_measure must have 'expectation' and 'covariance' attributes"
256
+ )
257
+
218
258
  n_dims = len(mean_posterior)
219
-
259
+
220
260
  # Set default labels if not provided
221
261
  if labels is None:
222
262
  labels = [f"Dimension {i+1}" for i in range(n_dims)]
223
-
263
+
224
264
  # Set figure size based on dimensions if not provided
225
265
  if figsize is None:
226
266
  figsize = (3 * n_dims, 3 * n_dims)
227
-
267
+
228
268
  # Create subplots
229
269
  fig, axes = plt.subplots(n_dims, n_dims, figsize=figsize)
230
270
  fig.suptitle(title, fontsize=16)
231
-
271
+
232
272
  # Ensure axes is always 2D array
233
273
  if n_dims == 1:
234
274
  axes = np.array([[axes]])
235
275
  elif n_dims == 2:
236
276
  axes = axes.reshape(2, 2)
237
-
277
+
238
278
  # Initialize pcm variable for colorbar
239
279
  pcm = None
240
-
280
+
241
281
  for i in range(n_dims):
242
282
  for j in range(n_dims):
243
283
  ax = axes[i, j]
244
-
284
+
245
285
  if i == j: # Diagonal plots (1D marginal distributions)
246
286
  mu = mean_posterior[i]
247
287
  sigma = np.sqrt(cov_posterior[i, i])
248
-
288
+
249
289
  # Create x-axis range
250
290
  x = np.linspace(mu - 3.75 * sigma, mu + 3.75 * sigma, 200)
251
291
  pdf = stats.norm.pdf(x, mu, sigma)
252
-
292
+
253
293
  # Plot the PDF
254
294
  ax.plot(x, pdf, "darkblue", label="Posterior PDF")
255
295
  ax.fill_between(x, pdf, color="lightblue", alpha=0.6)
256
-
296
+
257
297
  # Add true value if provided
258
298
  if true_values is not None:
259
299
  true_val = true_values[i]
260
- ax.axvline(true_val, color="black", linestyle="-",
261
- label=f"True: {true_val:.2f}")
262
-
300
+ ax.axvline(
301
+ true_val,
302
+ color="black",
303
+ linestyle="-",
304
+ label=f"True: {true_val:.2f}",
305
+ )
306
+
263
307
  ax.set_xlabel(labels[i])
264
308
  ax.set_ylabel("Density" if i == 0 else "")
265
309
  ax.set_yticklabels([])
266
-
310
+
267
311
  elif i > j: # Lower triangle: 2D joint distributions
268
312
  # Extract 2D mean and covariance
269
313
  mean_2d = np.array([mean_posterior[j], mean_posterior[i]])
270
- cov_2d = np.array([
271
- [cov_posterior[j, j], cov_posterior[j, i]],
272
- [cov_posterior[i, j], cov_posterior[i, i]]
273
- ])
274
-
314
+ cov_2d = np.array(
315
+ [
316
+ [cov_posterior[j, j], cov_posterior[j, i]],
317
+ [cov_posterior[i, j], cov_posterior[i, i]],
318
+ ]
319
+ )
320
+
275
321
  # Create 2D grid
276
322
  sigma_j = np.sqrt(cov_posterior[j, j])
277
323
  sigma_i = np.sqrt(cov_posterior[i, i])
278
-
279
- x_range = np.linspace(mean_2d[0] - 3.75 * sigma_j,
280
- mean_2d[0] + 3.75 * sigma_j, 100)
281
- y_range = np.linspace(mean_2d[1] - 3.75 * sigma_i,
282
- mean_2d[1] + 3.75 * sigma_i, 100)
283
-
324
+
325
+ x_range = np.linspace(
326
+ mean_2d[0] - 3.75 * sigma_j, mean_2d[0] + 3.75 * sigma_j, 100
327
+ )
328
+ y_range = np.linspace(
329
+ mean_2d[1] - 3.75 * sigma_i, mean_2d[1] + 3.75 * sigma_i, 100
330
+ )
331
+
284
332
  X, Y = np.meshgrid(x_range, y_range)
285
333
  pos = np.dstack((X, Y))
286
-
334
+
287
335
  # Calculate PDF values
288
336
  rv = stats.multivariate_normal(mean_2d, cov_2d)
289
337
  Z = rv.pdf(pos)
290
-
338
+
291
339
  # Create filled contour plot using pcolormesh like the original
292
340
  pcm = ax.pcolormesh(
293
- X, Y, Z, shading="auto", cmap=colormap,
294
- norm=colors.LogNorm(vmin=Z.min(), vmax=Z.max())
341
+ X,
342
+ Y,
343
+ Z,
344
+ shading="auto",
345
+ cmap=colormap,
346
+ norm=colors.LogNorm(vmin=Z.min(), vmax=Z.max()),
295
347
  )
296
-
348
+
297
349
  # Add contour lines
298
350
  ax.contour(X, Y, Z, colors="black", linewidths=0.5, alpha=0.6)
299
-
351
+
300
352
  # Add 1-sigma contour if requested
301
353
  if include_sigma_contours:
302
354
  # Calculate 1-sigma level (approximately 39% of peak for 2D Gaussian)
303
355
  sigma_level = rv.pdf(mean_2d) * np.exp(-0.5)
304
- ax.contour(X, Y, Z, levels=[sigma_level], colors="red",
305
- linewidths=1, linestyles="--", alpha=0.8)
306
-
356
+ ax.contour(
357
+ X,
358
+ Y,
359
+ Z,
360
+ levels=[sigma_level],
361
+ colors="red",
362
+ linewidths=1,
363
+ linestyles="--",
364
+ alpha=0.8,
365
+ )
366
+
307
367
  # Plot mean point
308
- ax.plot(mean_posterior[j], mean_posterior[i], "r+",
309
- markersize=10, mew=2, label="Posterior Mean")
310
-
368
+ ax.plot(
369
+ mean_posterior[j],
370
+ mean_posterior[i],
371
+ "r+",
372
+ markersize=10,
373
+ mew=2,
374
+ label="Posterior Mean",
375
+ )
376
+
311
377
  # Plot true value if provided
312
378
  if true_values is not None:
313
- ax.plot(true_values[j], true_values[i], "kx",
314
- markersize=10, mew=2, label="True Value")
315
-
379
+ ax.plot(
380
+ true_values[j],
381
+ true_values[i],
382
+ "kx",
383
+ markersize=10,
384
+ mew=2,
385
+ label="True Value",
386
+ )
387
+
316
388
  ax.set_xlabel(labels[j])
317
389
  ax.set_ylabel(labels[i])
318
-
390
+
319
391
  else: # Upper triangle: hide these plots
320
392
  ax.axis("off")
321
-
393
+
322
394
  # Create legend similar to the original
323
395
  handles, labels_leg = axes[0, 0].get_legend_handles_labels()
324
396
  if n_dims > 1:
325
397
  handles2, labels2 = axes[1, 0].get_legend_handles_labels()
326
398
  handles.extend(handles2)
327
399
  labels_leg.extend(labels2)
328
-
400
+
329
401
  # Clean up labels by removing values after colons
330
402
  cleaned_labels = [label.split(":")[0] for label in labels_leg]
331
-
332
- fig.legend(
333
- handles, cleaned_labels,
334
- loc="upper right",
335
- bbox_to_anchor=(0.9, 0.95)
336
- )
337
-
403
+
404
+ fig.legend(handles, cleaned_labels, loc="upper right", bbox_to_anchor=(0.9, 0.95))
405
+
338
406
  # Adjust main plot layout to make room on the right for the colorbar
339
407
  plt.tight_layout(rect=[0, 0, 0.88, 0.96])
340
-
408
+
341
409
  # Add a colorbar if we have 2D plots
342
410
  if n_dims > 1 and pcm is not None:
343
411
  cbar_ax = fig.add_axes([0.9, 0.15, 0.03, 0.7])
344
412
  cbar = fig.colorbar(pcm, cax=cbar_ax)
345
413
  cbar.set_label("Probability Density", size=12)
346
-
414
+
347
415
  if show_plot:
348
416
  plt.show()
349
-
350
- return fig, axes
417
+
418
+ return fig, axes