pygeoinf 1.3.1__tar.gz → 1.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/PKG-INFO +3 -1
  2. pygeoinf-1.3.3/pygeoinf/auxiliary.py +29 -0
  3. pygeoinf-1.3.3/pygeoinf/checks/linear_operators.py +197 -0
  4. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/checks/nonlinear_operators.py +63 -19
  5. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pyproject.toml +5 -1
  6. pygeoinf-1.3.1/pygeoinf/checks/linear_operators.py +0 -124
  7. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/LICENSE +0 -0
  8. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/README.md +0 -0
  9. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/__init__.py +0 -0
  10. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/backus_gilbert.py +0 -0
  11. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/checks/hilbert_space.py +0 -0
  12. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/direct_sum.py +0 -0
  13. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/forward_problem.py +0 -0
  14. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/gaussian_measure.py +0 -0
  15. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/hilbert_space.py +0 -0
  16. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/inversion.py +0 -0
  17. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/linear_bayesian.py +0 -0
  18. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/linear_forms.py +0 -0
  19. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/linear_operators.py +0 -0
  20. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/linear_optimisation.py +0 -0
  21. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/linear_solvers.py +0 -0
  22. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/nonlinear_forms.py +0 -0
  23. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/nonlinear_operators.py +0 -0
  24. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/nonlinear_optimisation.py +0 -0
  25. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/parallel.py +0 -0
  26. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/plot.py +0 -0
  27. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/random_matrix.py +0 -0
  28. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/symmetric_space/__init__.py +0 -0
  29. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/symmetric_space/circle.py +0 -0
  30. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/symmetric_space/sphere.py +0 -0
  31. {pygeoinf-1.3.1 → pygeoinf-1.3.3}/pygeoinf/symmetric_space/symmetric_space.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pygeoinf
3
- Version: 1.3.1
3
+ Version: 1.3.3
4
4
  Summary: A package for solving geophysical inference and inverse problems
5
5
  License: BSD-3-Clause
6
6
  License-File: LICENSE
@@ -13,10 +13,12 @@ Classifier: Programming Language :: Python :: 3.12
13
13
  Classifier: Programming Language :: Python :: 3.13
14
14
  Classifier: Programming Language :: Python :: 3.14
15
15
  Provides-Extra: sphere
16
+ Requires-Dist: Cartopy (>=0.23.0,<0.24.0) ; extra == "sphere"
16
17
  Requires-Dist: joblib (>=1.5.2,<2.0.0)
17
18
  Requires-Dist: matplotlib (>=3.0.0)
18
19
  Requires-Dist: numpy (>=1.26.0)
19
20
  Requires-Dist: pyqt6 (>=6.0.0)
21
+ Requires-Dist: pyshtools (>=4.0.0) ; extra == "sphere"
20
22
  Requires-Dist: scipy (>=1.16.1)
21
23
  Description-Content-Type: text/markdown
22
24
 
@@ -0,0 +1,29 @@
1
+ from .gaussian_measure import GaussianMeasure
2
+
3
def empirical_data_error_measure(
    model_measure, forward_operator, n_samples=10, scale_factor=1.0
):
    """
    Generate an empirical data error measure from samples of a model-space measure.

    Samples of ``model_measure`` are pushed forward through ``forward_operator``
    into the data space. Each sample is centred on the sample mean and
    multiplied by ``scale_factor``. The resulting deviations are passed to
    ``GaussianMeasure.from_samples``.

    This is useful for synthetic testing, when a reasonable data error measure
    is needed whose covariance matrix is easily accessible.

    Args:
        model_measure: The measure on the model space used as a basis for the
            error measure (e.g., the model prior measure).
        forward_operator: Linear operator mapping from model space to data
            space (e.g., operator B).
        n_samples: Number of samples used to compute the statistics
            (default: 10). Must be at least 2. With a single sample, the
            centred deviation is identically zero and carries no spread
            information.
        scale_factor: Factor applied to every centred deviation, intended to
            scale the standard deviations (default: 1.0).

    Returns:
        GaussianMeasure: Data error measure on ``forward_operator.codomain``,
        built from the centred, scaled samples.

    Raises:
        ValueError: If ``n_samples`` is less than 2.
    """
    if n_samples < 2:
        raise ValueError(
            f"n_samples must be at least 2 to estimate a covariance, got {n_samples}"
        )

    data_space = forward_operator.codomain

    # Push model samples forward into the data space. Materialise them as a
    # list because they are traversed twice (mean, then centring); a one-shot
    # iterator would otherwise leave the second pass empty.
    pushed_measure = model_measure.affine_mapping(operator=forward_operator)
    data_samples = list(pushed_measure.samples(n_samples))

    # Sample mean in the data space.
    total = data_space.zero
    for sample in data_samples:
        total = data_space.add(total, sample)
    mean = data_space.multiply(1.0 / len(data_samples), total)

    # Centre each sample on the mean and apply the scale factor.
    zeroed_samples = [
        data_space.multiply(scale_factor, data_space.subtract(sample, mean))
        for sample in data_samples
    ]

    return GaussianMeasure.from_samples(data_space, zeroed_samples)
@@ -0,0 +1,197 @@
1
+ """
2
+ Provides a self-checking mechanism for LinearOperator implementations.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ from typing import TYPE_CHECKING
7
+ import numpy as np
8
+
9
+ # Import the base checks from the sibling module
10
+ from .nonlinear_operators import NonLinearOperatorAxiomChecks
11
+
12
+
13
+ if TYPE_CHECKING:
14
+ from ..hilbert_space import Vector
15
+ from ..linear_forms import LinearForm
16
+
17
+
18
class LinearOperatorAxiomChecks(NonLinearOperatorAxiomChecks):
    """
    A mixin for checking the properties of a LinearOperator.

    Inherits the derivative checks from NonLinearOperatorAxiomChecks and adds
    randomized checks for linearity and the adjoint identity. When a second
    operator is supplied, it also checks algebraic identities of adjoints and
    duals.

    The host class is used as a callable on domain vectors and must provide
    ``domain``, ``codomain``, ``adjoint`` and ``dual``.

    Tolerance semantics:
        - Norm-based comparisons fail only when the discrepancy exceeds BOTH
          ``check_atol`` and ``check_rtol`` times the reference norm.
        - The scalar adjoint identity uses ``numpy.isclose``.
    """

    def _check_linearity(
        self,
        x: Vector,
        y: Vector,
        a: float,
        b: float,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ):
        """
        Verify the linearity property L(ax + by) = a*L(x) + b*L(y).

        Args:
            x: A vector in the operator's domain.
            y: A vector in the operator's domain.
            a: Scalar multiplying ``x``.
            b: Scalar multiplying ``y``.
            check_rtol: Relative tolerance, measured against ||a*L(x) + b*L(y)||.
            check_atol: Absolute tolerance on the norm of the discrepancy.

        Raises:
            AssertionError: If the discrepancy exceeds both tolerances.
        """
        domain, codomain = self.domain, self.codomain

        ax_plus_by = domain.add(domain.multiply(a, x), domain.multiply(b, y))
        lhs = self(ax_plus_by)

        aLx = codomain.multiply(a, self(x))
        bLy = codomain.multiply(b, self(y))
        rhs = codomain.add(aLx, bLy)

        # Compare the results in the codomain.
        diff_norm = codomain.norm(codomain.subtract(lhs, rhs))
        rhs_norm = codomain.norm(rhs)
        # The small offset avoids division by zero when rhs is the zero vector.
        relative_error = diff_norm / (rhs_norm + 1e-12)

        if relative_error > check_rtol and diff_norm > check_atol:
            raise AssertionError(
                f"Linearity check failed: L(ax+by) != aL(x)+bL(y). "
                f"Relative error: {relative_error:.2e} (Tol: {check_rtol:.2e}), "
                f"Absolute error: {diff_norm:.2e} (Tol: {check_atol:.2e})"
            )

    def _check_adjoint_definition(
        self,
        x: Vector,
        y: Vector,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ):
        """
        Verify the adjoint identity <L(x), y> = <x, L*(y)>.

        Args:
            x: A vector in the operator's domain.
            y: A vector in the operator's codomain.
            check_rtol: Relative tolerance passed to ``numpy.isclose``.
            check_atol: Absolute tolerance passed to ``numpy.isclose``.

        Raises:
            AssertionError: If the two inner products are not close.
        """
        lhs = self.codomain.inner_product(self(x), y)
        rhs = self.domain.inner_product(x, self.adjoint(y))

        if not np.isclose(lhs, rhs, rtol=check_rtol, atol=check_atol):
            raise AssertionError(
                f"Adjoint definition failed: <L(x),y> = {lhs:.4e}, "
                f"but <x,L*(y)> = {rhs:.4e} (RelTol: {check_rtol:.2e}, AbsTol: {check_atol:.2e})"
            )

    def _check_algebraic_identities(
        self,
        op1,
        op2,
        x,
        y,
        a,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ):
        """
        Verify algebraic properties of the adjoint and dual operators.

        Checks the following identities:
            - (A+B)* = A* + B*
            - (a*A)* = a*A*
            - (A*)* = A
            - (A@B)* = B*@A*, only when ``op1.domain == op2.codomain``
            - (A+B)' = A' + B'

        The scaling identity is written without complex conjugation, so ``a``
        is assumed to be real.

        Args:
            op1: The operator under test.
            op2: A second operator compatible with ``op1``; it is added to
                ``op1`` with ``+``.
            x: A vector in ``op1.domain``.
            y: A vector in ``op1.codomain``.
            a: Scalar used for the scaling identity.
            check_rtol: Relative tolerance for the norm-based comparisons.
            check_atol: Absolute tolerance for the norm-based comparisons.

        Raises:
            AssertionError: If any identity fails.
        """

        def _fail_if_far(diff_norm, ref_norm, axiom_name):
            """Raise if ``diff_norm`` exceeds both the absolute and relative tolerance."""
            scale = ref_norm + 1e-12
            if diff_norm > check_atol and diff_norm > check_rtol * scale:
                raise AssertionError(
                    f"Axiom failed: {axiom_name}. "
                    f"Absolute error: {diff_norm:.2e}, Relative error: {diff_norm / scale:.2e}"
                )

        def _check_norm_based(res1, res2, space, axiom_name):
            """Compare two vectors of ``space`` via the norm of their difference."""
            diff_norm = space.norm(space.subtract(res1, res2))
            _fail_if_far(diff_norm, space.norm(res2), axiom_name)

        # --- Adjoint identities ---
        # (A+B)* = A* + B*
        res1 = (op1 + op2).adjoint(y)
        res2 = (op1.adjoint + op2.adjoint)(y)
        _check_norm_based(res1, res2, op1.domain, "(A+B)* != A* + B*")

        # (a*A)* = a*A*
        res1 = (a * op1).adjoint(y)
        res2 = (a * op1.adjoint)(y)
        _check_norm_based(res1, res2, op1.domain, "(a*A)* != a*A*")

        # (A*)* = A
        res1 = op1.adjoint.adjoint(x)
        res2 = op1(x)
        _check_norm_based(res1, res2, op1.codomain, "(A*)* != A")

        # (A@B)* = B*@A*; only meaningful when the composition is defined.
        if op1.domain == op2.codomain:
            res1 = (op1 @ op2).adjoint(y)
            res2 = (op2.adjoint @ op1.adjoint)(y)
            _check_norm_based(res1, res2, op2.domain, "(A@B)* != B*@A*")

        # --- Dual identities ---
        # (A+B)' = A' + B'
        op_sum_dual = (op1 + op2).dual
        dual_sum = op1.dual + op2.dual
        y_dual = op1.codomain.to_dual(y)

        # The result of applying a dual operator is a LinearForm.
        res1_form: LinearForm = op_sum_dual(y_dual)
        res2_form: LinearForm = dual_sum(y_dual)

        # Keep the try body limited to the operations that may be unsupported
        # (form subtraction, dual-space norm). The tolerance comparison runs
        # in the ``else`` branch.
        try:
            diff_form = res1_form - res2_form
            diff_norm = op1.domain.dual.norm(diff_form)
            norm_res2 = op1.domain.dual.norm(res2_form)
        except (AttributeError, TypeError):
            # Fallback if LinearForm subtraction or the dual-space norm is
            # unavailable: compare the raw component arrays instead.
            if not np.allclose(
                res1_form.components,
                res2_form.components,
                rtol=check_rtol,
                atol=check_atol,
            ):
                raise AssertionError(
                    "Axiom failed: (A+B)' != A' + B' (component check)."
                )
        else:
            _fail_if_far(diff_norm, norm_res2, "(A+B)' != A' + B'")

    def check(
        self,
        n_checks: int = 5,
        op2=None,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ) -> None:
        """
        Run all checks for the LinearOperator.

        This runs the inherited non-linear checks first. It then runs the
        randomized linearity and adjoint checks. When ``op2`` is given, it
        also runs the algebraic identities.

        Args:
            n_checks: The number of randomized trials to perform.
            op2: An optional second operator for testing algebraic rules.
                It is combined with this operator via ``+``. The composition
                identity is only checked when ``self.domain == op2.codomain``.
            check_rtol: The relative tolerance for numerical checks.
            check_atol: The absolute tolerance for numerical checks.

        Raises:
            AssertionError: On the first failed check.
        """
        # First, run the parent (non-linear) checks from the base class.
        super().check(n_checks, op2=op2, check_rtol=check_rtol, check_atol=check_atol)

        # Now, run the linear-specific checks.
        print(
            f"Running {n_checks} additional randomized checks for linearity and adjoints..."
        )
        for _ in range(n_checks):
            x1 = self.domain.random()
            x2 = self.domain.random()
            y = self.codomain.random()
            # Plain Python floats, so ``a * op`` is dispatched straight to the
            # operator's own __rmul__ rather than through NumPy scalar arithmetic.
            a, b = float(np.random.randn()), float(np.random.randn())

            self._check_linearity(
                x1, x2, a, b, check_rtol=check_rtol, check_atol=check_atol
            )
            self._check_adjoint_definition(
                x1, y, check_rtol=check_rtol, check_atol=check_atol
            )

            # Explicit None test: operator objects should not be judged by truthiness.
            if op2 is not None:
                self._check_algebraic_identities(
                    self, op2, x1, y, a, check_rtol=check_rtol, check_atol=check_atol
                )

        print(f"✅ All {n_checks} linear operator checks passed successfully.")
@@ -13,7 +13,9 @@ if TYPE_CHECKING:
13
13
  class NonLinearOperatorAxiomChecks:
14
14
  """A mixin for checking the properties of a NonLinearOperator."""
15
15
 
16
- def _check_derivative_finite_difference(self, x, v, h=1e-7):
16
+ def _check_derivative_finite_difference(
17
+ self, x, v, h=1e-7, check_rtol: float = 1e-5, check_atol: float = 1e-8
18
+ ):
17
19
  """
18
20
  Verifies the derivative using the finite difference formula:
19
21
  D[F](x) @ v ≈ (F(x + h*v) - F(x)) / h
@@ -49,12 +51,20 @@ class NonLinearOperatorAxiomChecks:
49
51
  analytic_norm = self.codomain.norm(analytic_result)
50
52
  relative_error = diff_norm / (analytic_norm + 1e-12)
51
53
 
52
- if relative_error > 1e-4:
54
+ # The finite difference method itself has an error, so we use
55
+ # the max of the requested rtol and a default 1e-4.
56
+ effective_rtol = max(check_rtol, 1e-4)
57
+
58
+ if relative_error > effective_rtol and diff_norm > check_atol:
53
59
  raise AssertionError(
54
- f"Finite difference check failed. Relative error: {relative_error:.2e}"
60
+ f"Finite difference check failed. Relative error: {relative_error:.2e} "
61
+ f"(Tolerance: {effective_rtol:.2e}), "
62
+ f"Absolute error: {diff_norm:.2e} (Tol: {check_atol:.2e})"
55
63
  )
56
64
 
57
- def _check_add_derivative(self, op1, op2, x, v):
65
+ def _check_add_derivative(
66
+ self, op1, op2, x, v, check_rtol: float = 1e-5, check_atol: float = 1e-8
67
+ ):
58
68
  """Verifies the sum rule for derivatives: (F+G)' = F' + G'"""
59
69
  if not (op1.has_derivative and op2.has_derivative):
60
70
  return # Skip if derivatives aren't defined
@@ -70,11 +80,19 @@ class NonLinearOperatorAxiomChecks:
70
80
  res1 = derivative_of_sum(v)
71
81
  res2 = sum_of_derivatives(v)
72
82
 
73
- diff_norm = self.codomain.norm(self.codomain.subtract(res1, res2))
74
- if diff_norm > 1e-9:
75
- raise AssertionError("Axiom failed: Derivative of sum is incorrect.")
83
+ # CORRECTED: Use norm-based comparison, not np.allclose
84
+ diff_norm = op1.codomain.norm(op1.codomain.subtract(res1, res2))
85
+ norm_res2 = op1.codomain.norm(res2)
86
+
87
+ if diff_norm > check_atol and diff_norm > check_rtol * (norm_res2 + 1e-12):
88
+ raise AssertionError(
89
+ f"Axiom failed: Derivative of sum is incorrect. "
90
+ f"Absolute error: {diff_norm:.2e}, Relative error: {diff_norm / (norm_res2 + 1e-12):.2e}"
91
+ )
76
92
 
77
- def _check_scalar_mul_derivative(self, op, x, v, a):
93
+ def _check_scalar_mul_derivative(
94
+ self, op, x, v, a, check_rtol: float = 1e-5, check_atol: float = 1e-8
95
+ ):
78
96
  """Verifies the scalar multiple rule: (a*F)' = a*F'"""
79
97
  if not op.has_derivative:
80
98
  return
@@ -90,13 +108,19 @@ class NonLinearOperatorAxiomChecks:
90
108
  res1 = derivative_of_scaled(v)
91
109
  res2 = scaled_derivative(v)
92
110
 
93
- diff_norm = self.codomain.norm(self.codomain.subtract(res1, res2))
94
- if diff_norm > 1e-9:
111
+ # CORRECTED: Use norm-based comparison
112
+ diff_norm = op.codomain.norm(op.codomain.subtract(res1, res2))
113
+ norm_res2 = op.codomain.norm(res2)
114
+
115
+ if diff_norm > check_atol and diff_norm > check_rtol * (norm_res2 + 1e-12):
95
116
  raise AssertionError(
96
- "Axiom failed: Derivative of scalar multiple is incorrect."
117
+ f"Axiom failed: Derivative of scalar multiple is incorrect. "
118
+ f"Absolute error: {diff_norm:.2e}, Relative error: {diff_norm / (norm_res2 + 1e-12):.2e}"
97
119
  )
98
120
 
99
- def _check_matmul_derivative(self, op1, op2, x, v):
121
+ def _check_matmul_derivative(
122
+ self, op1, op2, x, v, check_rtol: float = 1e-5, check_atol: float = 1e-8
123
+ ):
100
124
  """Verifies the chain rule for derivatives: (F o G)'(x) = F'(G(x)) @ G'(x)"""
101
125
  if not (op1.has_derivative and op2.has_derivative):
102
126
  return
@@ -115,13 +139,23 @@ class NonLinearOperatorAxiomChecks:
115
139
  res1 = derivative_of_composed(v)
116
140
  res2 = chain_rule_derivative(v)
117
141
 
142
+ # CORRECTED: Use norm-based comparison
118
143
  diff_norm = op1.codomain.norm(op1.codomain.subtract(res1, res2))
119
- if diff_norm > 1e-9:
144
+ norm_res2 = op1.codomain.norm(res2)
145
+
146
+ if diff_norm > check_atol and diff_norm > check_rtol * (norm_res2 + 1e-12):
120
147
  raise AssertionError(
121
- "Axiom failed: Chain rule for derivatives is incorrect."
148
+ f"Axiom failed: Chain rule for derivatives is incorrect. "
149
+ f"Absolute error: {diff_norm:.2e}, Relative error: {diff_norm / (norm_res2 + 1e-12):.2e}"
122
150
  )
123
151
 
124
- def check(self, n_checks: int = 5, op2=None) -> None:
152
+ def check(
153
+ self,
154
+ n_checks: int = 5,
155
+ op2=None,
156
+ check_rtol: float = 1e-5,
157
+ check_atol: float = 1e-8,
158
+ ) -> None:
125
159
  """
126
160
  Runs randomized checks to validate the operator's derivative and
127
161
  its algebraic properties.
@@ -129,6 +163,8 @@ class NonLinearOperatorAxiomChecks:
129
163
  Args:
130
164
  n_checks: The number of randomized trials to perform.
131
165
  op2: An optional second operator for testing algebraic rules.
166
+ check_rtol: The relative tolerance for numerical checks.
167
+ check_atol: The absolute tolerance for numerical checks.
132
168
  """
133
169
  print(
134
170
  f"\nRunning {n_checks} randomized checks for {self.__class__.__name__}..."
@@ -143,12 +179,20 @@ class NonLinearOperatorAxiomChecks:
143
179
  v = self.domain.random()
144
180
 
145
181
  # Original check
146
- self._check_derivative_finite_difference(x, v)
182
+ self._check_derivative_finite_difference(
183
+ x, v, check_rtol=check_rtol, check_atol=check_atol
184
+ )
147
185
 
148
186
  # New algebraic checks
149
- self._check_scalar_mul_derivative(self, x, v, a)
187
+ self._check_scalar_mul_derivative(
188
+ self, x, v, a, check_rtol=check_rtol, check_atol=check_atol
189
+ )
150
190
  if op2:
151
- self._check_add_derivative(self, op2, x, v)
152
- self._check_matmul_derivative(self, op2, x, v)
191
+ self._check_add_derivative(
192
+ self, op2, x, v, check_rtol=check_rtol, check_atol=check_atol
193
+ )
194
+ self._check_matmul_derivative(
195
+ self, op2, x, v, check_rtol=check_rtol, check_atol=check_atol
196
+ )
153
197
 
154
198
  print(f"✅ All {n_checks} non-linear operator checks passed successfully.")
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "pygeoinf"
3
- version = "1.3.1"
3
+ version = "1.3.3"
4
4
  description = "A package for solving geophysical inference and inverse problems"
5
5
  authors = ["David Al-Attar and Dan Heathcote"]
6
6
  readme = "README.md"
@@ -14,6 +14,10 @@ matplotlib = ">=3.0.0"
14
14
  pyqt6 = ">=6.0.0"
15
15
  joblib = "^1.5.2"
16
16
 
17
+ # Define optional dependencies WITH versions
18
+ pyshtools = { version = ">=4.0.0", optional = true }
19
+ Cartopy = { version = "^0.23.0", optional = true }
20
+
17
21
  [tool.poetry.extras]
18
22
  sphere = ["pyshtools", "Cartopy"]
19
23
 
@@ -1,124 +0,0 @@
1
- """
2
- Provides a self-checking mechanism for LinearOperator implementations.
3
- """
4
-
5
- from __future__ import annotations
6
- from typing import TYPE_CHECKING
7
- import numpy as np
8
-
9
- # Import the base checks from the sibling module
10
- from .nonlinear_operators import NonLinearOperatorAxiomChecks
11
-
12
-
13
- if TYPE_CHECKING:
14
- from ..hilbert_space import Vector
15
-
16
-
17
- class LinearOperatorAxiomChecks(NonLinearOperatorAxiomChecks):
18
- """
19
- A mixin for checking the properties of a LinearOperator.
20
-
21
- Inherits the derivative check from NonLinearOperatorAxiomChecks and adds
22
- checks for linearity and the adjoint identity.
23
- """
24
-
25
- def _check_linearity(self, x: Vector, y: Vector, a: float, b: float):
26
- """Verifies the linearity property: L(ax + by) = a*L(x) + b*L(y)"""
27
- ax_plus_by = self.domain.add(
28
- self.domain.multiply(a, x), self.domain.multiply(b, y)
29
- )
30
- lhs = self(ax_plus_by)
31
-
32
- aLx = self.codomain.multiply(a, self(x))
33
- bLy = self.codomain.multiply(b, self(y))
34
- rhs = self.codomain.add(aLx, bLy)
35
-
36
- # Compare the results in the codomain
37
- diff_norm = self.codomain.norm(self.codomain.subtract(lhs, rhs))
38
- rhs_norm = self.codomain.norm(rhs)
39
- relative_error = diff_norm / (rhs_norm + 1e-12)
40
-
41
- if relative_error > 1e-9:
42
- raise AssertionError(
43
- f"Linearity check failed: L(ax+by) != aL(x)+bL(y). Relative error: {relative_error:.2e}"
44
- )
45
-
46
- def _check_adjoint_definition(self, x: Vector, y: Vector):
47
- """Verifies the adjoint identity: <L(x), y> = <x, L*(y)>"""
48
- lhs = self.codomain.inner_product(self(x), y)
49
- rhs = self.domain.inner_product(x, self.adjoint(y))
50
-
51
- if not np.isclose(lhs, rhs):
52
- raise AssertionError(
53
- f"Adjoint definition failed: <L(x),y> = {lhs:.4e}, but <x,L*(y)> = {rhs:.4e}"
54
- )
55
-
56
- def _check_algebraic_identities(self, op1, op2, x, y, a):
57
- """
58
- Verifies the algebraic properties of the adjoint and dual operators.
59
- Requires a second compatible operator (op2).
60
- """
61
- # --- Adjoint Identities ---
62
- # (A+B)* = A* + B*
63
- op_sum_adj = (op1 + op2).adjoint
64
- adj_sum = op1.adjoint + op2.adjoint
65
- diff = op1.domain.subtract(op_sum_adj(y), adj_sum(y))
66
- if op1.domain.norm(diff) > 1e-9:
67
- raise AssertionError("Axiom failed: (A+B)* != A* + B*")
68
-
69
- # (a*A)* = a*A*
70
- op_scaled_adj = (a * op1).adjoint
71
- adj_scaled = a * op1.adjoint
72
- diff = op1.domain.subtract(op_scaled_adj(y), adj_scaled(y))
73
- if op1.domain.norm(diff) > 1e-9:
74
- raise AssertionError("Axiom failed: (a*A)* != a*A*")
75
-
76
- # (A*)* = A
77
- op_adj_adj = op1.adjoint.adjoint
78
- diff = op1.codomain.subtract(op_adj_adj(x), op1(x))
79
- if op1.codomain.norm(diff) > 1e-9:
80
- raise AssertionError("Axiom failed: (A*)* != A")
81
-
82
- # (A@B)* = B*@A*
83
- if op1.domain == op2.codomain:
84
- op_comp_adj = (op1 @ op2).adjoint
85
- adj_comp = op2.adjoint @ op1.adjoint
86
- diff = op2.domain.subtract(op_comp_adj(y), adj_comp(y))
87
- if op2.domain.norm(diff) > 1e-9:
88
- raise AssertionError("Axiom failed: (A@B)* != B*@A*")
89
-
90
- # --- Dual Identities ---
91
- # (A+B)' = A' + B'
92
- op_sum_dual = (op1 + op2).dual
93
- dual_sum = op1.dual + op2.dual
94
- y_dual = op1.codomain.to_dual(y)
95
- # The result of applying a dual operator is a LinearForm, which supports subtraction
96
- diff_dual = op_sum_dual(y_dual) - dual_sum(y_dual)
97
- if op1.domain.dual.norm(diff_dual) > 1e-9:
98
- raise AssertionError("Axiom failed: (A+B)' != A' + B'")
99
-
100
- def check(self, n_checks: int = 5, op2=None) -> None:
101
- """
102
- Runs all checks for the LinearOperator, including non-linear checks
103
- and algebraic identities.
104
- """
105
- # First, run the parent (non-linear) checks from the base class
106
- super().check(n_checks, op2=op2)
107
-
108
- # Now, run the linear-specific checks
109
- print(
110
- f"Running {n_checks} additional randomized checks for linearity and adjoints..."
111
- )
112
- for _ in range(n_checks):
113
- x1 = self.domain.random()
114
- x2 = self.domain.random()
115
- y = self.codomain.random()
116
- a, b = np.random.randn(), np.random.randn()
117
-
118
- self._check_linearity(x1, x2, a, b)
119
- self._check_adjoint_definition(x1, y)
120
-
121
- if op2:
122
- self._check_algebraic_identities(self, op2, x1, y, a)
123
-
124
- print(f"✅ All {n_checks} linear operator checks passed successfully.")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes