gpclarity 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpclarity/__init__.py +190 -0
- gpclarity/_version.py +3 -0
- gpclarity/data_influence.py +501 -0
- gpclarity/exceptions.py +46 -0
- gpclarity/hyperparam_tracker.py +718 -0
- gpclarity/kernel_summary.py +285 -0
- gpclarity/model_complexity.py +619 -0
- gpclarity/plotting.py +337 -0
- gpclarity/uncertainty_analysis.py +647 -0
- gpclarity/utils.py +411 -0
- gpclarity-0.0.2.dist-info/METADATA +248 -0
- gpclarity-0.0.2.dist-info/RECORD +14 -0
- gpclarity-0.0.2.dist-info/WHEEL +4 -0
- gpclarity-0.0.2.dist-info/licenses/LICENSE +37 -0
gpclarity/__init__.py
ADDED
@@ -0,0 +1,190 @@
"""
gpclarity: Interpretability Toolkit for Gaussian Processes
==============================================================

Extends GPy and emukit with human-readable diagnostics, uncertainty
analysis, and model interpretability tools. Designed for researchers
and practitioners who need trustworthy GP models.

Public API Philosophy
---------------------
This module exposes only high-level interpretability interfaces.
Internal utilities and helpers are intentionally not exposed to maintain
a clean, stable API surface for users.

Import Safety
-------------
Heavy scientific computing dependencies (GPy, emukit) are wrapped in
lazy loading mechanisms to enable fast imports and lightweight installs
for documentation or analysis-only workflows.

Quick Start
-----------
>>> from gpclarity import UncertaintyProfiler, summarize_kernel
>>> # Features unavailable without dependencies raise informative errors
"""

import warnings
from importlib.util import find_spec

# Version source of truth - keep this lightweight (no heavy imports)
from ._version import __version__

# Dependency availability flags (checked without importing)
_GPY_AVAILABLE = find_spec("GPy") is not None
_EMUKIT_AVAILABLE = find_spec("emukit") is not None

# Base public names; feature names are appended below unconditionally, so
# names whose dependencies are missing still resolve (to informative stubs)
__all__ = ["__version__", "AVAILABLE"]

# Public availability flag for programmatic checks
AVAILABLE = {
    "gpy": _GPY_AVAILABLE,
    "emukit": _EMUKIT_AVAILABLE,
    "full": _GPY_AVAILABLE and _EMUKIT_AVAILABLE,
}
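
# For downstream code, these flags support explicit capability checks instead
# of try/except around imports. A minimal sketch (hypothetical caller, not
# part of this module):
#
#     import gpclarity
#     if not gpclarity.AVAILABLE["full"]:
#         print("gpclarity running without GPy/emukit-backed features")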


class _UnavailableFeature:
    """
    Stub that raises an informative ImportError when used.
    Allows `from gpclarity import X` to succeed, but fails gracefully on usage.
    """

    def __init__(self, name: str, dependency: str, install_extra: str):
        self.name = name
        self.dependency = dependency
        self.install_extra = install_extra

    def __call__(self, *args, **kwargs):
        raise ImportError(
            f"{self.name} requires {self.dependency} which is not installed. "
            f"Install with: pip install gpclarity[{self.install_extra}]"
        )

    def __getattr__(self, attr):
        return self.__call__

    def __repr__(self):
        return f"<Unavailable: {self.name} (requires {self.dependency})>"


def __getattr__(name: str):
    """
    Lazy loader for heavy modules.
    Only imports GPy/emukit-dependent code when actually accessed.
    """
    # Kernel Summary
    if name in ("summarize_kernel", "format_kernel_tree",
                "interpret_lengthscale", "interpret_variance"):
        if not _GPY_AVAILABLE:
            return _UnavailableFeature(name, "GPy", "full")
        from .kernel_summary import (
            summarize_kernel,
            format_kernel_tree,
            interpret_lengthscale,
            interpret_variance,
        )
        return locals()[name]

    # Uncertainty Analysis
    if name == "UncertaintyProfiler":
        if not _EMUKIT_AVAILABLE:
            return _UnavailableFeature(name, "emukit", "full")
        from .uncertainty_analysis import UncertaintyProfiler
        return UncertaintyProfiler

    # Hyperparameter Tracking
    if name == "HyperparameterTracker":
        if not _GPY_AVAILABLE:
            return _UnavailableFeature(name, "GPy", "full")
        from .hyperparam_tracker import HyperparameterTracker
        return HyperparameterTracker

    # Model Complexity
    if name in ("compute_complexity_score", "count_kernel_components",
                "compute_roughness_score", "compute_noise_ratio"):
        if not _GPY_AVAILABLE:
            return _UnavailableFeature(name, "GPy", "full")
        from .model_complexity import (
            compute_complexity_score,
            count_kernel_components,
            compute_roughness_score,
            compute_noise_ratio,
        )
        return locals()[name]

    # Data Influence
    if name == "DataInfluenceMap":
        if not _GPY_AVAILABLE:
            return _UnavailableFeature(name, "GPy", "full")
        from .data_influence import DataInfluenceMap
        return DataInfluenceMap

    # Utilities
    if name in ("check_model_health", "extract_kernel_params_flat",
                "get_lengthscale", "get_noise_variance"):
        if not _GPY_AVAILABLE:
            return _UnavailableFeature(name, "GPy", "full")
        from .utils import (
            check_model_health,
            extract_kernel_params_flat,
            get_lengthscale,
            get_noise_variance,
        )
        return locals()[name]

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
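
# The function above is the module-level __getattr__ hook from PEP 562
# (Python 3.7+): attribute lookups that miss the module namespace fall
# through to it, so GPy/emukit are imported only on first access. A sketch
# of the assumed effect:
#
#     >>> import gpclarity              # fast; heavy dependencies untouched
#     >>> gpclarity.DataInfluenceMap    # first access triggers the import
#     <class 'gpclarity.data_influence.DataInfluenceMap'>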


def __dir__():
    """Enable tab completion for available features."""
    return sorted(__all__)


# Populate __all__ with potentially available features
__all__.extend([
    # Kernel Summary
    "summarize_kernel",
    "format_kernel_tree",
    "interpret_lengthscale",
    "interpret_variance",
    # Uncertainty Analysis
    "UncertaintyProfiler",
    # Hyperparameter Tracking
    "HyperparameterTracker",
    # Model Complexity
    "compute_complexity_score",
    "count_kernel_components",
    "compute_roughness_score",
    "compute_noise_ratio",
    # Data Influence
    "DataInfluenceMap",
    # Utilities
    "get_lengthscale",
    "get_noise_variance",
    "extract_kernel_params_flat",
    "check_model_health",
])


# Warn once if running in limited mode
if not AVAILABLE["full"]:
    missing = []
    if not _GPY_AVAILABLE:
        missing.append("GPy")
    if not _EMUKIT_AVAILABLE:
        missing.append("emukit")

    warnings.warn(
        f"gpclarity running in limited mode. Missing: {', '.join(missing)}. "
        "Install with 'pip install gpclarity[full]' for complete functionality.",
        ImportWarning,
        stacklevel=2,
    )


# Package metadata
__author__ = "Angad Kumar"
__email__ = "angadkumar16ak@gmail.com"
__description__ = "Interpretability and Diagnostics for Gaussian Processes"
gpclarity/data_influence.py
ADDED
@@ -0,0 +1,501 @@
"""
Data influence analysis: quantify how training points affect model uncertainty.
"""

from __future__ import annotations

import logging
import time
import warnings
from dataclasses import dataclass
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import GPy
import numpy as np
from scipy.linalg import solve_triangular

from gpclarity.exceptions import InfluenceError
from gpclarity.utils import _cholesky_with_jitter, _validate_kernel_matrix

if TYPE_CHECKING:
    import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class InfluenceResult:
    """Container for influence computation results."""
    scores: np.ndarray
    method: str
    computation_time: float
    n_points: int
    metadata: Optional[Dict[str, Any]] = None

    def __array__(self):
        """Allow numpy operations on result directly."""
        return self.scores
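
# Because of __array__ above, numpy coerces an InfluenceResult directly;
# a sketch with hypothetical names:
#
#     result = influence_map.compute_influence_scores(X_train)
#     top3 = np.argsort(result)[-3:]   # indices of most influential points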


def _validate_train_data(func: Callable) -> Callable:
    """Decorator to standardize training data validation."""
    @wraps(func)
    def wrapper(
        self,
        X_train: np.ndarray,
        y_train: Optional[np.ndarray] = None,
        *args,
        **kwargs
    ):
        if X_train is None:
            raise ValueError("X_train cannot be None")

        if not hasattr(X_train, "shape"):
            raise ValueError("X_train must be array-like with shape attribute")

        if X_train.ndim != 2:
            raise ValueError(f"X_train must be 2D, got shape {X_train.shape}")

        if X_train.shape[0] == 0:
            raise ValueError("X_train cannot be empty")

        if y_train is not None:
            if y_train.shape[0] != X_train.shape[0]:
                raise ValueError(
                    f"Shape mismatch: X_train has {X_train.shape[0]} samples, "
                    f"y_train has {y_train.shape[0]}"
                )
            if y_train.ndim != 1 and (y_train.ndim != 2 or y_train.shape[1] != 1):
                warnings.warn(
                    f"y_train should be 1D, got shape {y_train.shape}. "
                    "Flattening automatically.",
                    UserWarning
                )
                y_train = y_train.ravel()

        return func(self, X_train, y_train, *args, **kwargs)

    return wrapper


class DataInfluenceMap:
    """
    Compute influence scores for training data points.

    Identifies which observations most reduce model uncertainty and
    which have high leverage on predictions.

    Attributes:
        model: Trained GP model
        _cache: Internal cache for expensive computations
    """

    def __init__(self, model: GPy.models.GPRegression):
        """
        Initialize with GP model.

        Args:
            model: Trained GP model with 'predict' and 'kern' attributes

        Raises:
            ValueError: If model lacks required attributes
        """
        if not hasattr(model, "predict"):
            raise ValueError("Model must have predict() method")
        if not hasattr(model, "kern"):
            raise ValueError("Model must have 'kern' attribute")

        self.model = model
        self._cache: Dict[str, Any] = {}
        self._cache_key: Optional[str] = None

    def _get_cache_key(self, X: np.ndarray) -> str:
        """Generate a cache key from the input's shape and content hash."""
        return f"{X.shape}_{hash(X.tobytes()) % (2**32)}"

    def _get_cached_kernel(
        self,
        X: np.ndarray,
        noise_var: Optional[float] = None
    ) -> Tuple[np.ndarray, np.ndarray, str]:
        """
        Retrieve or compute kernel matrix with Cholesky decomposition.

        Returns:
            Tuple of (K, L, cache_key)
        """
        cache_key = self._get_cache_key(X)

        if cache_key == self._cache_key and "K" in self._cache:
            logger.debug("Using cached kernel matrix")
            K = self._cache["K"]
            L = self._cache["L"]
        else:
            logger.debug("Computing new kernel matrix")
            K = self.model.kern.K(X, X)
            _validate_kernel_matrix(K)

            if noise_var is None:
                noise_var = float(self.model.Gaussian_noise.variance)

            if not np.isfinite(noise_var) or noise_var < 0:
                logger.warning(f"Invalid noise variance: {noise_var}, using fallback")
                noise_var = 1e-6

            K_stable = K + np.eye(K.shape[0]) * noise_var
            L = _cholesky_with_jitter(K_stable)

            # Update cache
            self._cache = {"K": K, "L": L, "noise_var": noise_var}
            self._cache_key = cache_key

        return K, L, cache_key

    def clear_cache(self) -> None:
        """Clear internal computation cache to free memory."""
        self._cache.clear()
        self._cache_key = None
        logger.debug("Cache cleared")

    @_validate_train_data
    def compute_influence_scores(
        self,
        X_train: np.ndarray,
        *,
        use_cache: bool = True
    ) -> InfluenceResult:
        """
        Compute influence scores using leverage scores (optimized O(n³)).

        Leverage scores are computed via the diagonal of the hat matrix,
        using a cached Cholesky decomposition.

        Args:
            X_train: Training input locations (shape: [n_train, n_dims])
            use_cache: Whether to use internal cache for kernel matrix

        Returns:
            InfluenceResult with scores and metadata

        Raises:
            InfluenceError: If computation fails
        """
        start_time = time.perf_counter()

        try:
            # Get or compute kernel and Cholesky factor
            if use_cache:
                K, L, _ = self._get_cached_kernel(X_train)
            else:
                K = self.model.kern.K(X_train, X_train)
                _validate_kernel_matrix(K)
                noise_var = float(self.model.Gaussian_noise.variance)
                K_stable = K + np.eye(K.shape[0]) * max(noise_var, 1e-6)
                L = _cholesky_with_jitter(K_stable)

            n = X_train.shape[0]

            # Optimized leverage score computation: O(n³) instead of O(n⁴)
            # L @ L.T = K_stable, so K_inv = L_inv.T @ L_inv
            L_inv = solve_triangular(L, np.eye(n), lower=True)
            K_inv = L_inv.T @ L_inv
            scores = 1.0 / np.diag(K_inv)

            # Handle numerical edge cases
            if not np.all(np.isfinite(scores)):
                n_invalid = np.sum(~np.isfinite(scores))
                logger.warning(f"{n_invalid} influence scores are non-finite, clipping to 0")
                scores = np.where(np.isfinite(scores), scores, 0.0)

            if np.any(scores < 0):
                n_neg = np.sum(scores < 0)
                logger.warning(f"{n_neg} negative influence scores found, setting to 0")
                scores = np.maximum(scores, 0)

            computation_time = time.perf_counter() - start_time

            return InfluenceResult(
                scores=scores,
                method="leverage",
                computation_time=computation_time,
                n_points=n,
                metadata={
                    "kernel_type": type(self.model.kern).__name__,
                    "noise_variance": float(self.model.Gaussian_noise.variance),
                    "cache_used": use_cache,
                }
            )

        except Exception as e:
            logger.error(f"Failed to compute influence scores: {e}")
            raise InfluenceError(f"Influence computation failed: {e}") from e

    @_validate_train_data
    def compute_loo_variance_increase(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        *,
        n_jobs: int = 1,
        verbose: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Exact Leave-One-Out variance increase with optional parallelization.

        Args:
            X_train: Training inputs (shape: [n_train, n_dims])
            y_train: Training outputs (shape: [n_train,] or [n_train, 1])
            n_jobs: Number of parallel jobs (-1 for all cores, 1 for sequential)
            verbose: Whether to display progress bar

        Returns:
            Tuple of (variance_increase, prediction_errors)

        Raises:
            InfluenceError: If computation fails
        """
        n = X_train.shape[0]
        variance_increase = np.full(n, np.nan)
        prediction_errors = np.full(n, np.nan)

        try:
            # Pre-compute full covariance once
            K_full = self.model.kern.K(X_train, X_train)
            _validate_kernel_matrix(K_full)

            noise_var = float(self.model.Gaussian_noise.variance)
            if not np.isfinite(noise_var):
                raise InfluenceError(f"Invalid noise variance: {noise_var}")

            K_full = K_full.copy()
            np.fill_diagonal(K_full, np.diag(K_full) + noise_var)

            # Set up iterator with optional progress bar
            iterator = range(n)
            if verbose:
                try:
                    from tqdm import tqdm
                    iterator = tqdm(iterator, desc="Computing LOO")
                except ImportError:
                    logger.warning("tqdm not installed, progress bar disabled")
                    verbose = False

            # Parallel or sequential execution
            if n_jobs != 1:
                result = self._compute_loo_parallel(
                    X_train, y_train, K_full, n_jobs, iterator
                )
                variance_increase, prediction_errors = result
            else:
                for i in iterator:
                    result = self._compute_loo_point(
                        i, X_train, y_train, K_full
                    )
                    variance_increase[i], prediction_errors[i] = result

            return variance_increase, prediction_errors

        except Exception as e:
            logger.error(f"Failed to compute LOO variance: {e}")
            raise InfluenceError(f"LOO computation failed: {e}") from e

    def _compute_loo_point(
        self,
        i: int,
        X_train: np.ndarray,
        y_train: np.ndarray,
        K_full: np.ndarray,
    ) -> Tuple[float, float]:
        """Compute LOO metrics for a single point."""
        n = X_train.shape[0]
        idx = np.arange(n) != i

        try:
            # Extract sub-matrices
            K_loo = K_full[np.ix_(idx, idx)]
            k_star = K_full[np.ix_([i], idx)][0]

            # Quick validation
            if np.any(~np.isfinite(K_loo)) or np.any(~np.isfinite(k_star)):
                logger.warning(f"Non-finite values in LOO matrices for point {i}")
                return np.nan, np.nan

            # Cholesky with jitter if needed
            L = _cholesky_with_jitter(K_loo)

            # Solve for variance without point i
            v = solve_triangular(L, k_star, lower=True)
            k_inv_k = np.dot(v, v)
            var_without_i = K_full[i, i] - k_inv_k

            if var_without_i < 0:
                logger.debug(f"Negative variance at point {i}: {var_without_i}, clamping to 0")
                var_without_i = 0

            # Get model prediction with point i
            mean_with_i, var_with_i = self.model.predict(X_train[i:i+1])

            var_increase = max(0, var_without_i - float(var_with_i[0, 0]))

            # Prediction error
            pred_error = abs(float(mean_with_i[0, 0]) - float(y_train[i]))

            return var_increase, pred_error

        except Exception as e:
            logger.debug(f"LOO failed for point {i}: {e}")
            return np.nan, np.nan

    def _compute_loo_parallel(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        K_full: np.ndarray,
        n_jobs: int,
        iterator,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Parallel LOO computation using joblib."""
        try:
            from joblib import Parallel, delayed
        except ImportError:
            warnings.warn("joblib not installed, falling back to sequential")
            return self.compute_loo_variance_increase(
                X_train, y_train, n_jobs=1, verbose=False
            )

        n = X_train.shape[0]
        results = Parallel(n_jobs=n_jobs)(
            delayed(self._compute_loo_point)(i, X_train, y_train, K_full)
            for i in iterator
        )

        variance_increase = np.array([r[0] for r in results])
        prediction_errors = np.array([r[1] for r in results])
        return variance_increase, prediction_errors
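
    # Design caveat (assumed joblib semantics): the default loky backend
    # pickles the bound method, and with it the GPy model and K_full, for
    # the worker processes. Since each task is LAPACK-bound and releases
    # the GIL, a thread-based backend is a plausible alternative sketch:
    #
    #     results = Parallel(n_jobs=n_jobs, prefer="threads")(
    #         delayed(self._compute_loo_point)(i, X_train, y_train, K_full)
    #         for i in iterator
    #     )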

    def plot_influence(
        self,
        X_train: np.ndarray,
        influence_scores: Union[np.ndarray, InfluenceResult],
        ax: Optional["plt.Axes"] = None,
        **scatter_kwargs,
    ) -> "plt.Axes":
        """
        Visualize data point influence (delegated to plotting module).

        Args:
            X_train: Training input locations
            influence_scores: Computed scores or InfluenceResult
            ax: Matplotlib axes (created if None)
            **scatter_kwargs: Additional scatter arguments

        Returns:
            Matplotlib axes object

        Raises:
            ImportError: If matplotlib not installed
            ValueError: If input dimensions > 2
        """
        # Deferred import to keep computation module lightweight
        from gpclarity.plotting import plot_influence_map

        if isinstance(influence_scores, InfluenceResult):
            influence_scores = influence_scores.scores

        return plot_influence_map(
            X_train, influence_scores, ax=ax, **scatter_kwargs
        )

    @_validate_train_data
    def get_influence_report(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        *,
        compute_loo: bool = True,
        n_jobs: int = 1,
    ) -> Dict[str, Any]:
        """
        Comprehensive influence analysis report.

        Args:
            X_train: Training inputs
            y_train: Training outputs
            compute_loo: Whether to include LOO analysis (slow for large n)
            n_jobs: Parallel jobs for LOO computation

        Returns:
            Dictionary with influence statistics and diagnostics
        """
        start_time = time.perf_counter()

        # Leverage scores (fast)
        leverage_result = self.compute_influence_scores(X_train)
        scores = leverage_result.scores

        # LOO analysis (optional, slower)
        loo_var, loo_err = None, None
        if compute_loo:
            try:
                loo_var, loo_err = self.compute_loo_variance_increase(
                    X_train, y_train, n_jobs=n_jobs
                )
            except Exception as e:
                logger.warning(f"LOO computation skipped: {e}")

        # Compute statistics on finite values
        finite_mask = np.isfinite(scores)
        finite_scores = scores[finite_mask]

        if len(finite_scores) == 0:
            raise InfluenceError("No finite influence scores available")

        # Percentile-based diagnostics
        p95 = np.percentile(finite_scores, 95)
        p5 = np.percentile(finite_scores, 5)

        most_inf_idx = int(np.nanargmax(scores))
        least_inf_idx = int(np.nanargmin(scores))

        report = {
            "computation_summary": {
                "total_time": time.perf_counter() - start_time,
                "leverage_time": leverage_result.computation_time,
                "n_points": leverage_result.n_points,
                "method": leverage_result.method,
            },
            "influence_scores": {
                "mean": float(np.mean(finite_scores)),
                "std": float(np.std(finite_scores)),
                "median": float(np.median(finite_scores)),
                "max": float(np.max(finite_scores)),
                "min": float(np.min(finite_scores)),
                "p95": float(p95),
                "p5": float(p5),
            },
            "most_influential": {
                "index": most_inf_idx,
                "location": X_train[most_inf_idx].tolist(),
                "score": float(scores[most_inf_idx]),
            },
            "least_influential": {
                "index": least_inf_idx,
                "location": X_train[least_inf_idx].tolist(),
                "score": float(scores[least_inf_idx]),
            },
            "diagnostics": {
                "high_leverage_count": int(np.sum(scores > p95)),
                "low_influence_count": int(np.sum(scores < p5)),
                "non_finite_scores": int(np.sum(~finite_mask)),
            },
        }

        # Add LOO results if available
        if loo_var is not None and loo_err is not None:
            finite_loo = loo_err[np.isfinite(loo_err)]
            report["loo_analysis"] = {
                "variance_increase": loo_var.tolist(),
                "prediction_errors": loo_err.tolist(),
                "mean_error": float(np.mean(finite_loo)) if len(finite_loo) > 0 else None,
                "max_error": float(np.max(finite_loo)) if len(finite_loo) > 0 else None,
            }

        return report
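
Taken together, a plausible end-to-end session with this module (a sketch assuming a standard GPy regression setup; the data and model here are illustrative, not examples shipped with the package):

import GPy
import numpy as np
from gpclarity import DataInfluenceMap

rng = np.random.default_rng(1)
X = rng.uniform(0.0, 10.0, size=(40, 1))
y = np.sin(X) + 0.1 * rng.standard_normal(X.shape)

model = GPy.models.GPRegression(X, y, GPy.kern.RBF(input_dim=1))
model.optimize()

dmap = DataInfluenceMap(model)
scores = dmap.compute_influence_scores(X)             # fast, leverage-based
report = dmap.get_influence_report(X, y.ravel(), compute_loo=False)
print(report["most_influential"])
dmap.plot_influence(X, scores)                        # requires matplotlib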