pygeoinf 1.3.6__py3-none-any.whl → 1.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pygeoinf/plot.py CHANGED
@@ -4,18 +4,21 @@ import numpy as np
4
4
  import scipy.stats as stats
5
5
  from typing import Union, List, Optional
6
6
 
7
+
7
8
  def plot_1d_distributions(
8
9
  posterior_measures: Union[object, List[object]],
10
+ /,
11
+ *,
9
12
  prior_measures: Optional[Union[object, List[object]]] = None,
10
13
  true_value: Optional[float] = None,
11
14
  xlabel: str = "Property Value",
12
15
  title: str = "Prior and Posterior Probability Distributions",
13
16
  figsize: tuple = (12, 7),
14
- show_plot: bool = True
17
+ show_plot: bool = True,
15
18
  ):
16
19
  """
17
20
  Plot 1D probability distributions for prior and posterior measures using dual y-axes.
18
-
21
+
19
22
  Args:
20
23
  posterior_measures: Single measure or list of measures for posterior distributions
21
24
  prior_measures: Single measure or list of measures for prior distributions (optional)
@@ -24,26 +27,44 @@ def plot_1d_distributions(
24
27
  title: Title for the plot
25
28
  figsize: Figure size tuple
26
29
  show_plot: Whether to display the plot
27
-
30
+
28
31
  Returns:
29
32
  fig, (ax1, ax2): Figure and axes objects
30
33
  """
31
-
34
+
32
35
  # Convert single measures to lists for uniform handling
33
36
  if not isinstance(posterior_measures, list):
34
37
  posterior_measures = [posterior_measures]
35
-
38
+
36
39
  if prior_measures is not None and not isinstance(prior_measures, list):
37
40
  prior_measures = [prior_measures]
38
-
41
+
39
42
  # Define color sequences
40
- prior_colors = ['green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']
41
- posterior_colors = ['blue', 'red', 'darkgreen', 'orange', 'purple', 'brown', 'pink', 'gray']
42
-
43
+ prior_colors = [
44
+ "green",
45
+ "orange",
46
+ "purple",
47
+ "brown",
48
+ "pink",
49
+ "gray",
50
+ "olive",
51
+ "cyan",
52
+ ]
53
+ posterior_colors = [
54
+ "blue",
55
+ "red",
56
+ "darkgreen",
57
+ "orange",
58
+ "purple",
59
+ "brown",
60
+ "pink",
61
+ "gray",
62
+ ]
63
+
43
64
  # Calculate statistics for all distributions
44
65
  posterior_stats = []
45
66
  for measure in posterior_measures:
46
- if hasattr(measure, 'expectation') and hasattr(measure, 'covariance'):
67
+ if hasattr(measure, "expectation") and hasattr(measure, "covariance"):
47
68
  # For pygeoinf measures
48
69
  mean = measure.expectation[0]
49
70
  var = measure.covariance.matrix(dense=True)[0, 0]
@@ -53,11 +74,11 @@ def plot_1d_distributions(
53
74
  mean = measure.mean[0]
54
75
  std = np.sqrt(measure.cov[0, 0])
55
76
  posterior_stats.append((mean, std))
56
-
77
+
57
78
  prior_stats = []
58
79
  if prior_measures is not None:
59
80
  for measure in prior_measures:
60
- if hasattr(measure, 'expectation') and hasattr(measure, 'covariance'):
81
+ if hasattr(measure, "expectation") and hasattr(measure, "covariance"):
61
82
  # For pygeoinf measures
62
83
  mean = measure.expectation[0]
63
84
  var = measure.covariance.matrix(dense=True)[0, 0]
@@ -67,99 +88,106 @@ def plot_1d_distributions(
67
88
  mean = measure.mean[0]
68
89
  std = np.sqrt(measure.cov[0, 0])
69
90
  prior_stats.append((mean, std))
70
-
91
+
71
92
  # Determine plot range to include all distributions
72
93
  all_means = [stat[0] for stat in posterior_stats]
73
94
  all_stds = [stat[1] for stat in posterior_stats]
74
-
95
+
75
96
  if prior_measures is not None:
76
97
  all_means.extend([stat[0] for stat in prior_stats])
77
98
  all_stds.extend([stat[1] for stat in prior_stats])
78
-
99
+
79
100
  if true_value is not None:
80
101
  all_means.append(true_value)
81
102
  all_stds.append(0) # No std for true value
82
-
103
+
83
104
  # Calculate x-axis range (6 sigma coverage)
84
105
  x_min = min([mean - 6 * std for mean, std in zip(all_means, all_stds) if std > 0])
85
106
  x_max = max([mean + 6 * std for mean, std in zip(all_means, all_stds) if std > 0])
86
-
107
+
87
108
  # Add some padding around true value if needed
88
109
  if true_value is not None:
89
110
  range_size = x_max - x_min
90
111
  x_min = min(x_min, true_value - 0.1 * range_size)
91
112
  x_max = max(x_max, true_value + 0.1 * range_size)
92
-
113
+
93
114
  x_axis = np.linspace(x_min, x_max, 1000)
94
-
115
+
95
116
  # Create the plot with two y-axes
96
117
  fig, ax1 = plt.subplots(figsize=figsize)
97
-
118
+
98
119
  # Plot priors on the first axis (left y-axis) if provided
99
120
  if prior_measures is not None:
100
- color1 = prior_colors[0] if len(prior_measures) > 0 else 'green'
121
+ color1 = prior_colors[0] if len(prior_measures) > 0 else "green"
101
122
  ax1.set_xlabel(xlabel)
102
- ax1.set_ylabel('Prior Probability Density', color=color1)
103
-
123
+ ax1.set_ylabel("Prior Probability Density", color=color1)
124
+
104
125
  for i, (measure, (mean, std)) in enumerate(zip(prior_measures, prior_stats)):
105
126
  color = prior_colors[i % len(prior_colors)]
106
-
127
+
107
128
  # Calculate PDF values using scipy.stats
108
129
  pdf_values = stats.norm.pdf(x_axis, loc=mean, scale=std)
109
-
130
+
110
131
  # Determine label
111
132
  if len(prior_measures) == 1:
112
- label = f'Prior PDF (Mean: {mean:.5f})'
133
+ label = f"Prior PDF (Mean: {mean:.5f})"
113
134
  else:
114
- label = f'Prior {i+1} (Mean: {mean:.5f})'
115
-
116
- ax1.plot(x_axis, pdf_values, color=color, lw=2, linestyle=':', label=label)
135
+ label = f"Prior {i+1} (Mean: {mean:.5f})"
136
+
137
+ ax1.plot(x_axis, pdf_values, color=color, lw=2, linestyle=":", label=label)
117
138
  ax1.fill_between(x_axis, pdf_values, color=color, alpha=0.15)
118
-
119
- ax1.tick_params(axis='y', labelcolor=color1)
120
- ax1.grid(True, linestyle='--')
139
+
140
+ ax1.tick_params(axis="y", labelcolor=color1)
141
+ ax1.grid(True, linestyle="--")
121
142
  else:
122
143
  # If no priors, use the left axis for posteriors
123
144
  ax1.set_xlabel(xlabel)
124
- ax1.set_ylabel('Probability Density')
125
- ax1.grid(True, linestyle='--')
126
-
145
+ ax1.set_ylabel("Probability Density")
146
+ ax1.grid(True, linestyle="--")
147
+
127
148
  # Create second y-axis for posteriors (or use first if no priors)
128
149
  if prior_measures is not None:
129
150
  ax2 = ax1.twinx()
130
- color2 = posterior_colors[0] if len(posterior_measures) > 0 else 'blue'
131
- ax2.set_ylabel('Posterior Probability Density', color=color2)
132
- ax2.tick_params(axis='y', labelcolor=color2)
151
+ color2 = posterior_colors[0] if len(posterior_measures) > 0 else "blue"
152
+ ax2.set_ylabel("Posterior Probability Density", color=color2)
153
+ ax2.tick_params(axis="y", labelcolor=color2)
133
154
  ax2.grid(False)
134
155
  plot_ax = ax2
135
156
  else:
136
157
  plot_ax = ax1
137
- color2 = posterior_colors[0] if len(posterior_measures) > 0 else 'blue'
138
-
158
+ color2 = posterior_colors[0] if len(posterior_measures) > 0 else "blue"
159
+
139
160
  # Plot posteriors
140
- for i, (measure, (mean, std)) in enumerate(zip(posterior_measures, posterior_stats)):
161
+ for i, (measure, (mean, std)) in enumerate(
162
+ zip(posterior_measures, posterior_stats)
163
+ ):
141
164
  color = posterior_colors[i % len(posterior_colors)]
142
-
165
+
143
166
  # Calculate PDF values using scipy.stats
144
167
  pdf_values = stats.norm.pdf(x_axis, loc=mean, scale=std)
145
-
168
+
146
169
  # Determine label
147
170
  if len(posterior_measures) == 1:
148
- label = f'Posterior PDF (Mean: {mean:.5f})'
171
+ label = f"Posterior PDF (Mean: {mean:.5f})"
149
172
  else:
150
- label = f'Posterior {i+1} (Mean: {mean:.5f})'
151
-
173
+ label = f"Posterior {i+1} (Mean: {mean:.5f})"
174
+
152
175
  plot_ax.plot(x_axis, pdf_values, color=color, lw=2, label=label)
153
176
  plot_ax.fill_between(x_axis, pdf_values, color=color, alpha=0.2)
154
-
177
+
155
178
  # Plot true value if provided
156
179
  if true_value is not None:
157
- ax1.axvline(true_value, color='black', linestyle='-', lw=2,
158
- label=f'True Value: {true_value:.5f}')
159
-
180
+ ax1.axvline(
181
+ true_value,
182
+ color="black",
183
+ linestyle="-",
184
+ lw=2,
185
+ label=f"True Value: {true_value:.5f}",
186
+ )
187
+
160
188
  # Create combined legend
161
189
  handles1, labels1 = ax1.get_legend_handles_labels()
162
-
190
+
163
191
  if prior_measures is not None:
164
192
  handles2, labels2 = ax2.get_legend_handles_labels()
165
193
  all_handles = handles1 + handles2
@@ -167,14 +195,14 @@ def plot_1d_distributions(
167
195
  else:
168
196
  all_handles = handles1
169
197
  all_labels = [h.get_label() for h in all_handles]
170
-
171
- fig.legend(all_handles, all_labels, loc='upper right', bbox_to_anchor=(0.9, 0.9))
198
+
199
+ fig.legend(all_handles, all_labels, loc="upper right", bbox_to_anchor=(0.9, 0.9))
172
200
  fig.suptitle(title, fontsize=16)
173
201
  fig.tight_layout(rect=[0, 0, 1, 0.96])
174
-
202
+
175
203
  if show_plot:
176
204
  plt.show()
177
-
205
+
178
206
  if prior_measures is not None:
179
207
  return fig, (ax1, ax2)
180
208
  else:
@@ -183,17 +211,19 @@ def plot_1d_distributions(
183
211
 
184
212
  def plot_corner_distributions(
185
213
  posterior_measure: object,
214
+ /,
215
+ *,
186
216
  true_values: Optional[Union[List[float], np.ndarray]] = None,
187
217
  labels: Optional[List[str]] = None,
188
218
  title: str = "Joint Posterior Distribution",
189
219
  figsize: Optional[tuple] = None,
190
220
  show_plot: bool = True,
191
221
  include_sigma_contours: bool = True,
192
- colormap: str = "Blues"
222
+ colormap: str = "Blues",
193
223
  ):
194
224
  """
195
225
  Create a corner plot for multi-dimensional posterior distributions.
196
-
226
+
197
227
  Args:
198
228
  posterior_measure: Multi-dimensional posterior measure (pygeoinf object)
199
229
  true_values: True values for each dimension (optional)
@@ -203,148 +233,180 @@ def plot_corner_distributions(
203
233
  show_plot: Whether to display the plot
204
234
  include_sigma_contours: Whether to include 1-sigma contour lines
205
235
  colormap: Colormap for 2D plots
206
-
236
+
207
237
  Returns:
208
238
  fig, axes: Figure and axes array
209
239
  """
210
-
240
+
211
241
  # Extract statistics from the measure
212
- if hasattr(posterior_measure, 'expectation') and hasattr(posterior_measure, 'covariance'):
242
+ if hasattr(posterior_measure, "expectation") and hasattr(
243
+ posterior_measure, "covariance"
244
+ ):
213
245
  mean_posterior = posterior_measure.expectation
214
246
  cov_posterior = posterior_measure.covariance.matrix(dense=True, parallel=True)
215
247
  else:
216
- raise ValueError("posterior_measure must have 'expectation' and 'covariance' attributes")
217
-
248
+ raise ValueError(
249
+ "posterior_measure must have 'expectation' and 'covariance' attributes"
250
+ )
251
+
218
252
  n_dims = len(mean_posterior)
219
-
253
+
220
254
  # Set default labels if not provided
221
255
  if labels is None:
222
256
  labels = [f"Dimension {i+1}" for i in range(n_dims)]
223
-
257
+
224
258
  # Set figure size based on dimensions if not provided
225
259
  if figsize is None:
226
260
  figsize = (3 * n_dims, 3 * n_dims)
227
-
261
+
228
262
  # Create subplots
229
263
  fig, axes = plt.subplots(n_dims, n_dims, figsize=figsize)
230
264
  fig.suptitle(title, fontsize=16)
231
-
265
+
232
266
  # Ensure axes is always 2D array
233
267
  if n_dims == 1:
234
268
  axes = np.array([[axes]])
235
269
  elif n_dims == 2:
236
270
  axes = axes.reshape(2, 2)
237
-
271
+
238
272
  # Initialize pcm variable for colorbar
239
273
  pcm = None
240
-
274
+
241
275
  for i in range(n_dims):
242
276
  for j in range(n_dims):
243
277
  ax = axes[i, j]
244
-
278
+
245
279
  if i == j: # Diagonal plots (1D marginal distributions)
246
280
  mu = mean_posterior[i]
247
281
  sigma = np.sqrt(cov_posterior[i, i])
248
-
282
+
249
283
  # Create x-axis range
250
284
  x = np.linspace(mu - 3.75 * sigma, mu + 3.75 * sigma, 200)
251
285
  pdf = stats.norm.pdf(x, mu, sigma)
252
-
286
+
253
287
  # Plot the PDF
254
288
  ax.plot(x, pdf, "darkblue", label="Posterior PDF")
255
289
  ax.fill_between(x, pdf, color="lightblue", alpha=0.6)
256
-
290
+
257
291
  # Add true value if provided
258
292
  if true_values is not None:
259
293
  true_val = true_values[i]
260
- ax.axvline(true_val, color="black", linestyle="-",
261
- label=f"True: {true_val:.2f}")
262
-
294
+ ax.axvline(
295
+ true_val,
296
+ color="black",
297
+ linestyle="-",
298
+ label=f"True: {true_val:.2f}",
299
+ )
300
+
263
301
  ax.set_xlabel(labels[i])
264
302
  ax.set_ylabel("Density" if i == 0 else "")
265
303
  ax.set_yticklabels([])
266
-
304
+
267
305
  elif i > j: # Lower triangle: 2D joint distributions
268
306
  # Extract 2D mean and covariance
269
307
  mean_2d = np.array([mean_posterior[j], mean_posterior[i]])
270
- cov_2d = np.array([
271
- [cov_posterior[j, j], cov_posterior[j, i]],
272
- [cov_posterior[i, j], cov_posterior[i, i]]
273
- ])
274
-
308
+ cov_2d = np.array(
309
+ [
310
+ [cov_posterior[j, j], cov_posterior[j, i]],
311
+ [cov_posterior[i, j], cov_posterior[i, i]],
312
+ ]
313
+ )
314
+
275
315
  # Create 2D grid
276
316
  sigma_j = np.sqrt(cov_posterior[j, j])
277
317
  sigma_i = np.sqrt(cov_posterior[i, i])
278
-
279
- x_range = np.linspace(mean_2d[0] - 3.75 * sigma_j,
280
- mean_2d[0] + 3.75 * sigma_j, 100)
281
- y_range = np.linspace(mean_2d[1] - 3.75 * sigma_i,
282
- mean_2d[1] + 3.75 * sigma_i, 100)
283
-
318
+
319
+ x_range = np.linspace(
320
+ mean_2d[0] - 3.75 * sigma_j, mean_2d[0] + 3.75 * sigma_j, 100
321
+ )
322
+ y_range = np.linspace(
323
+ mean_2d[1] - 3.75 * sigma_i, mean_2d[1] + 3.75 * sigma_i, 100
324
+ )
325
+
284
326
  X, Y = np.meshgrid(x_range, y_range)
285
327
  pos = np.dstack((X, Y))
286
-
328
+
287
329
  # Calculate PDF values
288
330
  rv = stats.multivariate_normal(mean_2d, cov_2d)
289
331
  Z = rv.pdf(pos)
290
-
332
+
291
333
  # Create filled contour plot using pcolormesh like the original
292
334
  pcm = ax.pcolormesh(
293
- X, Y, Z, shading="auto", cmap=colormap,
294
- norm=colors.LogNorm(vmin=Z.min(), vmax=Z.max())
335
+ X,
336
+ Y,
337
+ Z,
338
+ shading="auto",
339
+ cmap=colormap,
340
+ norm=colors.LogNorm(vmin=Z.min(), vmax=Z.max()),
295
341
  )
296
-
342
+
297
343
  # Add contour lines
298
344
  ax.contour(X, Y, Z, colors="black", linewidths=0.5, alpha=0.6)
299
-
345
+
300
346
  # Add 1-sigma contour if requested
301
347
  if include_sigma_contours:
302
348
  # Calculate 1-sigma level (approximately 39% of peak for 2D Gaussian)
303
349
  sigma_level = rv.pdf(mean_2d) * np.exp(-0.5)
304
- ax.contour(X, Y, Z, levels=[sigma_level], colors="red",
305
- linewidths=1, linestyles="--", alpha=0.8)
306
-
350
+ ax.contour(
351
+ X,
352
+ Y,
353
+ Z,
354
+ levels=[sigma_level],
355
+ colors="red",
356
+ linewidths=1,
357
+ linestyles="--",
358
+ alpha=0.8,
359
+ )
360
+
307
361
  # Plot mean point
308
- ax.plot(mean_posterior[j], mean_posterior[i], "r+",
309
- markersize=10, mew=2, label="Posterior Mean")
310
-
362
+ ax.plot(
363
+ mean_posterior[j],
364
+ mean_posterior[i],
365
+ "r+",
366
+ markersize=10,
367
+ mew=2,
368
+ label="Posterior Mean",
369
+ )
370
+
311
371
  # Plot true value if provided
312
372
  if true_values is not None:
313
- ax.plot(true_values[j], true_values[i], "kx",
314
- markersize=10, mew=2, label="True Value")
315
-
373
+ ax.plot(
374
+ true_values[j],
375
+ true_values[i],
376
+ "kx",
377
+ markersize=10,
378
+ mew=2,
379
+ label="True Value",
380
+ )
381
+
316
382
  ax.set_xlabel(labels[j])
317
383
  ax.set_ylabel(labels[i])
318
-
384
+
319
385
  else: # Upper triangle: hide these plots
320
386
  ax.axis("off")
321
-
387
+
322
388
  # Create legend similar to the original
323
389
  handles, labels_leg = axes[0, 0].get_legend_handles_labels()
324
390
  if n_dims > 1:
325
391
  handles2, labels2 = axes[1, 0].get_legend_handles_labels()
326
392
  handles.extend(handles2)
327
393
  labels_leg.extend(labels2)
328
-
394
+
329
395
  # Clean up labels by removing values after colons
330
396
  cleaned_labels = [label.split(":")[0] for label in labels_leg]
331
-
332
- fig.legend(
333
- handles, cleaned_labels,
334
- loc="upper right",
335
- bbox_to_anchor=(0.9, 0.95)
336
- )
337
-
397
+
398
+ fig.legend(handles, cleaned_labels, loc="upper right", bbox_to_anchor=(0.9, 0.95))
399
+
338
400
  # Adjust main plot layout to make room on the right for the colorbar
339
401
  plt.tight_layout(rect=[0, 0, 0.88, 0.96])
340
-
402
+
341
403
  # Add a colorbar if we have 2D plots
342
404
  if n_dims > 1 and pcm is not None:
343
405
  cbar_ax = fig.add_axes([0.9, 0.15, 0.03, 0.7])
344
406
  cbar = fig.colorbar(pcm, cax=cbar_ax)
345
407
  cbar.set_label("Probability Density", size=12)
346
-
408
+
347
409
  if show_plot:
348
410
  plt.show()
349
-
350
- return fig, axes
411
+
412
+ return fig, axes
@@ -0,0 +1,140 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Optional
3
+ import numpy as np
4
+
5
+ from .linear_operators import LinearOperator, DiagonalSparseMatrixLinearOperator
6
+ from .linear_solvers import LinearSolver, IterativeLinearSolver
7
+ from .random_matrix import random_diagonal
8
+
9
+ if TYPE_CHECKING:
10
+ from .hilbert_space import Vector
11
+
12
+
13
class IdentityPreconditioningMethod(LinearSolver):
    """
    Preconditioning method that performs no preconditioning at all.

    Calling an instance on an operator yields the identity operator on that
    operator's domain, so it can stand in wherever a preconditioner is
    expected (for example as a baseline when benchmarking, or as a default).
    """

    def __call__(self, operator: LinearOperator) -> LinearOperator:
        """
        Build the (trivial) identity preconditioner for ``operator``.

        Args:
            operator: The operator to be preconditioned. Only its domain
                is consulted.

        Returns:
            The identity operator on ``operator.domain``.
        """
        domain = operator.domain
        return domain.identity_operator()
26
+
27
+
28
class JacobiPreconditioningMethod(LinearSolver):
    """
    A LinearSolver wrapper that generates a Jacobi (diagonal) preconditioner.

    Calling an instance on an operator obtains that operator's diagonal
    (either estimated with ``random_diagonal`` or extracted exactly) and
    returns a diagonal operator holding the reciprocals of those entries.
    """

    def __init__(
        self,
        num_samples: Optional[int] = 20,
        method: str = "variable",
        rtol: float = 1e-2,
        block_size: int = 10,
        parallel: bool = True,
        n_jobs: int = -1,
        zero_tol: float = 1e-14,
    ) -> None:
        """
        Args:
            num_samples: Number of samples passed to ``random_diagonal`` for a
                randomized estimate of the diagonal. If None, the diagonal is
                extracted exactly with ``operator.extract_diagonal``.
            method: Forwarded to ``random_diagonal``.
            rtol: Forwarded to ``random_diagonal``.
            block_size: Forwarded to ``random_diagonal``.
            parallel: Forwarded to ``random_diagonal`` or
                ``extract_diagonal``, whichever is used.
            n_jobs: Forwarded to ``random_diagonal`` or
                ``extract_diagonal``, whichever is used.
            zero_tol: Diagonal entries whose magnitude is not greater than
                this are treated as zero; the preconditioner uses 1.0 for
                those entries instead of a reciprocal.
        """
        # No damping parameter: the operator passed to __call__ is expected
        # to already include any damping term.
        self._num_samples = num_samples
        self._method = method
        self._rtol = rtol
        self._block_size = block_size
        self._parallel = parallel
        self._n_jobs = n_jobs
        self._zero_tol = zero_tol

    def __call__(self, operator: LinearOperator) -> LinearOperator:
        """
        Build the Jacobi preconditioner for ``operator``.

        Args:
            operator: The operator to precondition. Its Galerkin
                representation is used to obtain the diagonal.

        Returns:
            A diagonal operator from ``operator.domain`` to itself, built with
            ``galerkin=True``, whose diagonal holds the reciprocals of the
            operator's diagonal entries (1.0 where an entry is numerically
            zero).
        """
        if self._num_samples is not None:
            # Randomized (Hutchinson-type) estimate of the diagonal.
            diag_values = random_diagonal(
                operator.matrix(galerkin=True),
                self._num_samples,
                method=self._method,
                rtol=self._rtol,
                block_size=self._block_size,
                parallel=self._parallel,
                n_jobs=self._n_jobs,
            )
        else:
            # Exact extraction of the diagonal.
            diag_values = operator.extract_diagonal(
                galerkin=True, parallel=self._parallel, n_jobs=self._n_jobs
            )

        diag = np.asarray(diag_values)
        usable = np.abs(diag) > self._zero_tol
        inv_diag = np.ones(diag.shape, dtype=np.result_type(diag, 1.0))
        # Divide only where the entry is safely non-zero. Unlike computing
        # 1.0 / diag for every entry, this raises no divide-by-zero warnings;
        # the remaining entries keep their fallback value of 1.0.
        np.divide(1.0, diag, out=inv_diag, where=usable)

        return DiagonalSparseMatrixLinearOperator.from_diagonal_values(
            operator.domain, operator.domain, inv_diag, galerkin=True
        )
72
+
73
+
74
class SpectralPreconditioningMethod(LinearSolver):
    """
    A LinearSolver wrapper that generates a spectral (low-rank) preconditioner.

    Given an operator ``A``, a randomized eigendecomposition ``A ~ U S U*`` of
    rank ``rank`` is computed. The returned operator applies

        P r = (1 / damping**2) * (r - U D U* r),  D = S / (S + damping**2),

    which (assuming the columns of ``U`` are orthonormal) is the Woodbury form
    of ``(U S U* + damping**2 I)^{-1}``.
    """

    def __init__(
        self,
        damping: float,
        rank: int = 20,
        power: int = 2,
    ) -> None:
        """
        Args:
            damping: Damping parameter; ``damping**2`` is the shift added to
                the low-rank approximation. Must be non-zero.
            rank: Target rank of the randomized eigendecomposition.
            power: Forwarded to ``operator.random_eig`` as ``power``.

        Raises:
            ValueError: If ``damping`` is zero, since the preconditioner
                divides by ``damping**2``.
        """
        if damping == 0:
            raise ValueError("damping must be non-zero")
        self._damping = damping
        self._rank = rank
        self._power = power

    def __call__(self, operator: LinearOperator) -> LinearOperator:
        """
        Generates a spectral preconditioner.

        Note: This assumes the operator provided is the data-misfit operator
        A*WA (i.e. without the damping term, which is added here).

        Args:
            operator: Operator whose dominant eigenpairs are approximated.

        Returns:
            A LinearOperator on ``operator.domain`` applying the
            preconditioner described in the class docstring.
        """
        space = operator.domain

        # Use randomized eigendecomposition to find dominant modes.
        U, S = operator.random_eig(self._rank, power=self._power)

        # Capture damping once so the shrinkage factors and the overall scale
        # are guaranteed to use the same value.
        damping_sq = self._damping**2
        s_vals = S.extract_diagonal()
        d_vals = s_vals / (s_vals + damping_sq)
        scale = 1.0 / damping_sq

        def mapping(r: Vector) -> Vector:
            correction = U(d_vals * U.adjoint(r))
            return space.multiply(scale, space.subtract(r, correction))

        # The map has the symmetric form c (I - U D U*), so it is reused as
        # its own adjoint.
        return LinearOperator(space, space, mapping, adjoint_mapping=mapping)
111
+
112
+
113
class IterativePreconditioningMethod(LinearSolver):
    """
    Wraps an iterative solver to act as a preconditioner.

    The preconditioner is the inner solver's approximate inverse of the
    operator, run with a small iteration cap and a loose tolerance. Because
    such inexact inner solves can vary between applications, this is best
    used with FCGSolver to handle the potential variability of the inner
    iterations.
    """

    def __init__(
        self,
        inner_solver: IterativeLinearSolver,
        max_inner_iter: int = 5,
        rtol: float = 1e-1,
    ) -> None:
        """
        Args:
            inner_solver: Iterative solver used for the inner solves. It is
                not modified; a shallow copy with overridden settings is
                used instead.
            max_inner_iter: Iteration cap applied to the inner solver.
            rtol: Relative tolerance applied to the inner solver.
        """
        self._inner_solver = inner_solver
        self._max_iter = max_inner_iter
        self._rtol = rtol

    def __call__(self, operator: LinearOperator) -> LinearOperator:
        """
        Returns a LinearOperator whose action is 'solve the system'.

        Args:
            operator: The operator to approximately invert.

        Returns:
            The result of calling the configured inner solver on ``operator``.
        """
        import copy

        # Override the inner solver parameters for efficiency, but on a
        # shallow copy so the caller's solver object keeps its own settings.
        # NOTE(review): assumes IterativeLinearSolver stores its settings in
        # ``_maxiter`` / ``_rtol``; confirm against linear_solvers.
        solver = copy.copy(self._inner_solver)
        solver._maxiter = self._max_iter
        solver._rtol = self._rtol

        # The solver's __call__ returns the InverseLinearOperator.
        return solver(operator)
pygeoinf/random_matrix.py CHANGED
@@ -182,11 +182,14 @@ def variable_rank_random_range(
182
182
  basis_vectors = np.hstack([basis_vectors, new_basis[:, :cols_to_add]])
183
183
 
184
184
  if not converged and basis_vectors.shape[1] >= max_rank:
185
- warnings.warn(
186
- f"Tolerance {rtol} not met before reaching max_rank={max_rank}. "
187
- "Result may be inaccurate. Consider increasing `max_rank` or `power`.",
188
- UserWarning,
189
- )
185
+ # If we reached the full dimension of the matrix,
186
+ # the result is exact, so no warning is needed.
187
+ if max_rank < min(m, n):
188
+ warnings.warn(
189
+ f"Tolerance {rtol} not met before reaching max_rank={max_rank}. "
190
+ "Result may be inaccurate. Consider increasing `max_rank` or `power`.",
191
+ UserWarning,
192
+ )
190
193
 
191
194
  return basis_vectors
192
195