gpclarity 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpclarity/__init__.py +190 -0
- gpclarity/_version.py +3 -0
- gpclarity/data_influence.py +501 -0
- gpclarity/exceptions.py +46 -0
- gpclarity/hyperparam_tracker.py +718 -0
- gpclarity/kernel_summary.py +285 -0
- gpclarity/model_complexity.py +619 -0
- gpclarity/plotting.py +337 -0
- gpclarity/uncertainty_analysis.py +647 -0
- gpclarity/utils.py +411 -0
- gpclarity-0.0.2.dist-info/METADATA +248 -0
- gpclarity-0.0.2.dist-info/RECORD +14 -0
- gpclarity-0.0.2.dist-info/WHEEL +4 -0
- gpclarity-0.0.2.dist-info/licenses/LICENSE +37 -0
|
@@ -0,0 +1,619 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model complexity quantification for Gaussian Processes.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from enum import Enum, auto
|
|
10
|
+
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from gpclarity.exceptions import ComplexityError, KernelError
|
|
15
|
+
from gpclarity.kernel_summary import count_kernel_components, extract_kernel_params_flat
|
|
16
|
+
from gpclarity.utils import _cholesky_with_jitter, _validate_kernel_matrix
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ComplexityCategory(Enum):
    """Five-step scale describing how complex a GP model appears.

    Members are declared from ``TOO_SIMPLE`` up to ``TOO_COMPLEX`` with
    ``MODERATE`` in the middle. Each member exposes a human-readable
    ``description`` and a coarse ``risk_level`` label.
    """

    TOO_SIMPLE = auto()
    SIMPLE = auto()
    MODERATE = auto()
    COMPLEX = auto()
    TOO_COMPLEX = auto()

    @property
    def description(self) -> str:
        """Return the one-line interpretation text for this category."""
        if self is ComplexityCategory.TOO_SIMPLE:
            return "Overly simplistic (high underfitting risk)"
        if self is ComplexityCategory.SIMPLE:
            return "Simple model (possible underfitting)"
        if self is ComplexityCategory.MODERATE:
            return "Well-balanced complexity"
        if self is ComplexityCategory.COMPLEX:
            return "Complex model (monitor for overfitting)"
        return "Overly complex (high overfitting risk)"

    @property
    def risk_level(self) -> str:
        """Return ``"LOW"``, ``"MEDIUM"`` or ``"HIGH"`` for this category.

        The scale is symmetric: the balanced middle is low risk, its two
        neighbours are medium risk, and both extremes are high risk.
        """
        if self is ComplexityCategory.MODERATE:
            return "LOW"
        if self in (ComplexityCategory.SIMPLE, ComplexityCategory.COMPLEX):
            return "MEDIUM"
        return "HIGH"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True)
class ComplexityThresholds:
    """
    Immutable cut-points used to interpret a complexity score.

    The four category boundaries are expressed in log10(complexity_score)
    space and must be strictly increasing; ``jitter`` must be positive.
    Both constraints are enforced at construction time.
    """

    too_simple: float = -0.5
    simple: float = 0.5
    complex: float = 1.5
    too_complex: float = 2.5
    high_noise_ratio: float = 0.1  # signal/noise < this is noisy
    low_signal_ratio: float = 10.0  # signal/noise > this is dominated by signal
    jitter: float = 1e-10

    def __post_init__(self):
        """Validate boundary ordering and jitter sign.

        Raises:
            ValueError: If the four boundaries are not strictly increasing
                or if ``jitter`` is zero or negative.
        """
        bounds = (self.too_simple, self.simple, self.complex, self.too_complex)
        for lower, upper in zip(bounds, bounds[1:]):
            if not lower < upper:
                raise ValueError("Thresholds must be strictly increasing")
        if self.jitter <= 0:
            raise ValueError("jitter must be positive")

    def categorize(self, log_score: float) -> ComplexityCategory:
        """Map a log10 score onto a :class:`ComplexityCategory`.

        The first boundary that ``log_score`` falls strictly below decides
        the category; a score below no boundary is ``TOO_COMPLEX``.
        """
        ladder = (
            (self.too_simple, ComplexityCategory.TOO_SIMPLE),
            (self.simple, ComplexityCategory.SIMPLE),
            (self.complex, ComplexityCategory.MODERATE),
            (self.too_complex, ComplexityCategory.COMPLEX),
        )
        for upper_bound, label in ladder:
            if log_score < upper_bound:
                return label
        return ComplexityCategory.TOO_COMPLEX
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass
class ComplexityMetrics:
    """Bundle of complexity diagnostics produced for one GP model."""

    total_score: float
    log_score: float
    category: ComplexityCategory
    n_parameters: int
    n_kernel_components: int
    roughness_score: float
    signal_noise_ratio: float
    effective_degrees_of_freedom: float
    capacity_ratio: float  # DOF / n_samples
    geometric_complexity: float  # Alternative metric based on eigenvalues

    @property
    def is_well_specified(self) -> bool:
        """Return True only when the category is ``MODERATE``."""
        balanced = ComplexityCategory.MODERATE
        return self.category == balanced

    @property
    def risk_factors(self) -> List[str]:
        """List the warning messages that apply to these metrics.

        Messages are emitted in a fixed order: underfitting, overfitting,
        high noise (SNR below 0.1), high capacity (capacity ratio above 0.8).
        """
        underfit_side = (ComplexityCategory.TOO_SIMPLE, ComplexityCategory.SIMPLE)
        overfit_side = (ComplexityCategory.COMPLEX, ComplexityCategory.TOO_COMPLEX)
        checks = (
            (self.category in underfit_side,
             "Underfitting risk: model may be too restrictive"),
            (self.category in overfit_side,
             "Overfitting risk: model may be too flexible"),
            (self.signal_noise_ratio < 0.1,
             "High noise: predictions may be unreliable"),
            (self.capacity_ratio > 0.8,
             "High capacity: model can memorize training data"),
        )
        return [message for triggered, message in checks if triggered]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class ComplexityScorer(Protocol):
    """Structural type for pluggable complexity scoring strategies.

    Any callable taking ``(model, X, **kwargs)`` and returning a float
    matches this protocol.
    """

    def __call__(self, model: Any, X: np.ndarray, **kwargs) -> float:
        """Return a complexity score for ``model`` evaluated on inputs ``X``."""
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass
class ComplexityAnalyzer:
    """
    Configurable complexity analysis with multiple scoring strategies.

    This class provides the core analysis logic, separated from
    the high-level API functions.

    Attributes:
        thresholds: Category boundaries and the jitter used to keep
            divisions and logarithms finite.
        scoring_strategy: Name of the registry entry used for the headline
            score. Unknown names fall back to the default strategy with a
            warning.
    """

    thresholds: ComplexityThresholds = field(default_factory=ComplexityThresholds)
    scoring_strategy: str = "default"

    # Registry of scoring strategies (name -> callable). When left empty the
    # built-in strategies are registered in __post_init__; a non-empty
    # caller-supplied registry is used as given.
    _strategies: Dict[str, ComplexityScorer] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        if not self._strategies:
            self._register_default_strategies()

    def _register_default_strategies(self):
        """Register built-in scoring strategies."""
        self._strategies["default"] = self._default_score
        self._strategies["geometric"] = self._geometric_score
        self._strategies["bayesian"] = self._bayesian_score

    @staticmethod
    def _default_score(model: Any, X: np.ndarray, **kwargs) -> float:
        """
        Default complexity score based on components, roughness, and SNR.

        Score = (n_components * roughness) / (SNR * capacity_ratio + jitter)

        ``model`` and ``X`` are accepted for signature compatibility with the
        other strategies but are not read; all inputs come from ``kwargs``
        (``n_components``, ``roughness``, ``snr``, ``capacity_ratio``,
        ``jitter``), each with a neutral default.
        """
        n_comp = kwargs.get('n_components', 1)
        roughness = kwargs.get('roughness', 1.0)
        snr = kwargs.get('snr', 1.0)
        capacity = kwargs.get('capacity_ratio', 0.5)
        jitter = kwargs.get('jitter', 1e-10)

        # jitter keeps the denominator non-zero when SNR or capacity is 0.
        return (n_comp * roughness) / (snr * capacity + jitter)

    @staticmethod
    def _geometric_score(model: Any, X: np.ndarray, **kwargs) -> float:
        """
        Geometric complexity based on eigenvalue spectrum of kernel matrix.

        Uses the effective rank (participation ratio)
        ``(sum(l))**2 / sum(l**2)`` of ``model.kern.K(X, X)``, divided by the
        number of rows of ``X``. Returns 1.0 when all eigenvalues are zero.
        On any failure a warning is logged and the default score computed
        from ``kwargs`` is returned instead.
        """
        # Honour the caller's jitter; 1e-10 is the historical constant.
        jitter = kwargs.get('jitter', 1e-10)
        try:
            K = model.kern.K(X, X)
            eigenvals = np.linalg.eigvalsh(K)
            eigenvals = np.maximum(eigenvals, 0)  # Numerical safety

            spectrum_sum = np.sum(eigenvals)
            if spectrum_sum > 0:
                effective_rank = (spectrum_sum ** 2) / (np.sum(eigenvals ** 2) + jitter)
                return float(effective_rank / X.shape[0])
            return 1.0
        except Exception as e:
            logger.warning("Geometric score failed: %s", e)
            return ComplexityAnalyzer._default_score(model, X, **kwargs)

    @staticmethod
    def _bayesian_score(model: Any, X: np.ndarray, **kwargs) -> float:
        """
        Heuristic score from the log-likelihood gradient.

        When ``model`` exposes ``log_likelihood()`` and a non-empty
        ``gradient``, returns ``||gradient|| / (|log_likelihood| + 1)``.
        Otherwise, or on any exception (logged as a warning), returns the
        default score computed from ``kwargs``.
        """
        try:
            if hasattr(model, 'log_likelihood') and hasattr(model, 'gradient'):
                ll = model.log_likelihood()
                grad = model.gradient
                if grad is not None and len(grad) > 0:
                    # Complexity ~ ||gradient|| / |LL| (steep LL = complex);
                    # the +1 avoids division by zero when LL is 0.
                    complexity = np.linalg.norm(grad) / (abs(ll) + 1.0)
                    return float(complexity)
            return ComplexityAnalyzer._default_score(model, X, **kwargs)
        except Exception as e:
            logger.warning("Bayesian score failed: %s", e)
            return ComplexityAnalyzer._default_score(model, X, **kwargs)

    def analyze(self, model: Any, X: np.ndarray) -> ComplexityMetrics:
        """
        Perform comprehensive complexity analysis.

        Args:
            model: GP model with kern and likelihood attributes
            X: Training data (n_samples, n_features)

        Returns:
            ComplexityMetrics with detailed diagnostics

        Raises:
            ComplexityError: If ``X`` is None, has no ``shape``, or has zero rows.
        """
        if X is None or not hasattr(X, 'shape'):
            raise ComplexityError("X must be a valid array")
        if X.shape[0] == 0:
            raise ComplexityError("X cannot be empty")

        n_samples = X.shape[0]
        jitter = self.thresholds.jitter

        # Collect component metrics (each helper falls back on failure).
        n_components = self._count_components(model)
        n_params = self._count_parameters(model)
        roughness = self._compute_roughness(model)
        snr = self._compute_snr(model)
        effective_dof = self._compute_effective_dof(model, X)
        capacity_ratio = effective_dof / n_samples

        # The geometric metric is always reported. Use .get so that a custom
        # registry lacking a "geometric" entry does not raise KeyError.
        geometric_fn = self._strategies.get("geometric", self._geometric_score)
        geom_complexity = geometric_fn(
            model, X, n_components=n_components, jitter=jitter
        )

        # Resolve the headline scoring function; make the fallback visible.
        score_fn = self._strategies.get(self.scoring_strategy)
        if score_fn is None:
            logger.warning(
                "Unknown scoring strategy %r; using default",
                self.scoring_strategy,
            )
            score_fn = self._default_score

        score_inputs = {
            "n_components": n_components,
            "roughness": roughness,
            "snr": snr,
            "capacity_ratio": capacity_ratio,
            "jitter": jitter,
        }
        score = score_fn(model, X, **score_inputs)

        # NaN compares False against every threshold and would be reported
        # as TOO_COMPLEX; fall back to the default score instead.
        if np.isnan(score):
            logger.warning(
                "Scoring strategy %r returned NaN; using default score",
                self.scoring_strategy,
            )
            score = self._default_score(model, X, **score_inputs)

        # Categorize in log10 space; jitter floors the score so log10 is finite.
        log_score = np.log10(max(score, jitter))
        category = self.thresholds.categorize(log_score)

        return ComplexityMetrics(
            total_score=float(score),
            log_score=float(log_score),
            category=category,
            n_parameters=n_params,
            n_kernel_components=n_components,
            roughness_score=float(roughness),
            signal_noise_ratio=float(snr),
            effective_degrees_of_freedom=float(effective_dof),
            capacity_ratio=float(capacity_ratio),
            geometric_complexity=float(geom_complexity),
        )

    def _count_components(self, model: Any) -> int:
        """Count kernel components of ``model.kern``; return 1 on failure."""
        try:
            return count_kernel_components(model.kern)
        except Exception as e:
            logger.warning("Component counting failed: %s", e)
            return 1

    def _count_parameters(self, model: Any) -> int:
        """Count flattened kernel parameters of ``model``; return 0 on failure."""
        try:
            params = extract_kernel_params_flat(model)
            return len(params)
        except Exception as e:
            logger.warning("Parameter counting failed: %s", e)
            return 0

    def _compute_roughness(self, model: Any) -> float:
        """Compute the roughness score of ``model.kern``; return 1.0 on failure."""
        try:
            return compute_roughness_score(model.kern)
        except Exception as e:
            logger.warning("Roughness computation failed: %s", e)
            return 1.0

    def _compute_snr(self, model: Any) -> float:
        """Compute the signal-to-noise ratio of ``model``; return 1.0 on failure."""
        try:
            return compute_noise_ratio(model)
        except Exception as e:
            logger.warning("SNR computation failed: %s", e)
            return 1.0

    def _compute_effective_dof(self, model: Any, X: np.ndarray) -> float:
        """
        Compute effective degrees of freedom using trace of hat matrix.

        For GP regression: DOF = trace(K @ (K + sigma^2 I)^{-1}), evaluated
        as ``sum(l / (l + sigma^2 + jitter))`` over the eigenvalues ``l`` of
        ``K`` and clipped to ``[0, n_samples]``.

        ``sigma^2`` is read from ``model.Gaussian_noise.variance`` when
        present; it defaults to 1.0 when absent, non-finite or negative.
        On failure the kernel parameter count is used as a proxy, capped at
        the number of samples so the capacity ratio cannot exceed 1.
        """
        try:
            K = model.kern.K(X, X)
            _validate_kernel_matrix(K)

            noise_var = 1.0
            if hasattr(model, 'Gaussian_noise') and hasattr(model.Gaussian_noise, 'variance'):
                noise_var = float(model.Gaussian_noise.variance)
                if not np.isfinite(noise_var) or noise_var < 0:
                    noise_var = 1.0

            # Stable computation via eigendecomposition
            eigenvals = np.linalg.eigvalsh(K)
            eigenvals = np.maximum(eigenvals, 0)

            # DOF = sum(eigenvals / (eigenvals + noise_var))
            dof = np.sum(eigenvals / (eigenvals + noise_var + self.thresholds.jitter))

            return float(np.clip(dof, 0, X.shape[0]))

        except Exception as e:
            logger.warning("Effective DOF computation failed: %s", e)
            # Fallback: use parameter count as proxy
            proxy = float(self._count_parameters(model))
            try:
                n_samples = X.shape[0]
            except (AttributeError, IndexError, TypeError):
                return proxy
            # Keep the proxy within the same upper bound as the main path.
            return float(min(proxy, n_samples))
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# High-level API functions
|
|
338
|
+
def compute_complexity_score(
    model: Any,
    X: np.ndarray,
    *,
    strategy: str = "default",
    thresholds: Optional[ComplexityThresholds] = None,
    return_diagnostics: bool = False,
) -> Union[Dict[str, Any], ComplexityMetrics]:
    """
    Quantify the complexity of a GP model and summarise the findings.

    Args:
        model: Trained GP model with kern and likelihood
        X: Training data (n_samples, n_features)
        strategy: Scoring strategy ('default', 'geometric', 'bayesian')
        thresholds: Custom thresholds (uses defaults if None)
        return_diagnostics: If True, return full ComplexityMetrics object

    Returns:
        The ComplexityMetrics object when ``return_diagnostics`` is True,
        otherwise a dictionary with the score, category, interpretation,
        risk information, a nested ``"metrics"`` dictionary and a list of
        recommendations.

    Raises:
        ComplexityError: If analysis fails. Any other exception raised
            during analysis is wrapped in a ComplexityError.
    """
    try:
        config = thresholds or ComplexityThresholds()
        engine = ComplexityAnalyzer(thresholds=config, scoring_strategy=strategy)
        metrics = engine.analyze(model, X)

        if return_diagnostics:
            return metrics

        # Plain-dict summary; keys are part of the public return contract.
        return dict(
            score=metrics.total_score,
            log_score=metrics.log_score,
            category=metrics.category.name,
            interpretation=metrics.category.description,
            risk_level=metrics.category.risk_level,
            risk_factors=metrics.risk_factors,
            metrics=dict(
                n_parameters=metrics.n_parameters,
                n_kernel_components=metrics.n_kernel_components,
                roughness_score=metrics.roughness_score,
                signal_noise_ratio=metrics.signal_noise_ratio,
                effective_dof=metrics.effective_degrees_of_freedom,
                capacity_ratio=metrics.capacity_ratio,
            ),
            recommendations=_generate_recommendations(metrics),
        )

    except ComplexityError:
        # Already the documented exception type: propagate untouched.
        raise
    except Exception as e:
        raise ComplexityError(f"Complexity analysis failed: {e}") from e
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def compute_roughness_score(kern: Any) -> float:
    """
    Compute function roughness as inverse of characteristic lengthscale.

    Higher roughness = more wiggly function = higher complexity.

    The kernel tree is walked through each kernel's ``parts`` attribute.
    Every kernel exposing ``lengthscale`` contributes the mean of ``1 / l``
    over its finite, positive lengthscales; this equals the reciprocal of
    their harmonic mean, so small lengthscales dominate. Lengthscale arrays
    of any shape are flattened first. The per-kernel values are combined
    with a geometric mean.

    Args:
        kern: Kernel object, possibly composite (with ``parts``).

    Returns:
        Roughness score; 1.0 when no usable lengthscale is found.

    Raises:
        ComplexityError: If traversal hits the recursion limit or fails
            for any other reason.
    """
    roughness_values = []

    def traverse(kernel: Any, path: str = "") -> None:
        """Recursively collect lengthscale-based roughness."""
        kernel_name = getattr(kernel, 'name', 'unknown')
        current_path = f"{path}.{kernel_name}" if path else kernel_name

        if hasattr(kernel, 'lengthscale'):
            try:
                ls = kernel.lengthscale
                # Unwrap parameter containers that expose their raw numbers.
                if hasattr(ls, 'values'):
                    ls_val = ls.values
                elif hasattr(ls, 'param_array'):
                    ls_val = ls.param_array
                else:
                    ls_val = ls

                # Flatten so that a (1, k) or (1, 1) array is handled like
                # a plain vector or scalar instead of failing in float().
                arr = np.ravel(np.asarray(ls_val, dtype=float))

                # Drop invalid entries (NaN, inf, zero, negative).
                valid = arr[np.isfinite(arr) & (arr > 0)]
                if valid.size > 0:
                    # mean(1/l) == 1 / harmonic_mean(l); for a single
                    # lengthscale this is simply 1/l.
                    roughness_values.append(float(np.mean(1.0 / valid)))

            except Exception as e:
                logger.debug(
                    "Could not extract lengthscale from %s: %s", current_path, e
                )

        # Recurse into composite kernels
        if hasattr(kernel, 'parts') and kernel.parts:
            for i, part in enumerate(kernel.parts):
                traverse(part, f"{current_path}[{i}]")

    try:
        traverse(kern)
    except RecursionError as e:
        raise ComplexityError(
            "Kernel structure too deep (possible circular reference)"
        ) from e
    except Exception as e:
        raise ComplexityError(f"Roughness computation failed: {e}") from e

    if not roughness_values:
        logger.debug("No lengthscales found, returning unit roughness")
        return 1.0

    # Return geometric mean of roughness values
    log_roughness = np.mean(np.log(roughness_values))
    return float(np.exp(log_roughness))
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def compute_noise_ratio(model: Any) -> float:
    """
    Return the signal-to-noise ratio of ``model`` (signal variance divided
    by noise variance).

    Returns:
        SNR value (>1 means signal dominates, <1 means noise dominates),
        clipped to the range [1e-6, 1e6]. Falls back to 1.0 when either
        variance is non-finite or when extraction raises, and to 10.0 when
        the noise variance is zero or negative.
    """
    try:
        kernel_variance = _extract_signal_variance(model)
        likelihood_variance = _extract_noise_variance(model)

        both_finite = np.isfinite(kernel_variance) and np.isfinite(likelihood_variance)
        if not both_finite:
            logger.warning("Non-finite variance values detected")
            return 1.0

        if noise_free := (likelihood_variance <= 0):
            logger.warning("Zero or negative noise variance")
            return 10.0  # Assume high SNR if no noise

        # Sanity bounds keep downstream divisions and logs well-behaved.
        ratio = kernel_variance / likelihood_variance
        return float(np.clip(ratio, 1e-6, 1e6))

    except Exception as e:
        logger.debug(f"SNR computation failed: {e}")
        return 1.0
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _extract_signal_variance(model: Any) -> float:
    """Extract signal variance from model kernel.

    Sources are tried in order:

    1. ``model.kern.variance`` converted to float.
    2. The sum of ``variance`` over the direct entries of
       ``model.kern.parts`` (used only when the sum is positive).
    3. The mean diagonal of the kernel matrix on an all-zero dummy input
       whose column count matches ``model.kern.input_dim`` (1 when the
       kernel does not declare it).
    4. The default 1.0.

    Raises:
        ComplexityError: If ``model`` has no ``kern`` attribute.
    """
    if not hasattr(model, 'kern'):
        raise ComplexityError("Model has no kernel")

    kern = model.kern

    # Try to get variance from kernel
    if hasattr(kern, 'variance'):
        try:
            return float(kern.variance)
        except (TypeError, ValueError):
            # e.g. a multi-element variance; try the next source.
            pass

    # For composite kernels, sum variances
    if hasattr(kern, 'parts') and kern.parts:
        total_var = 0.0
        for part in kern.parts:
            if hasattr(part, 'variance'):
                try:
                    total_var += float(part.variance)
                except (TypeError, ValueError):
                    pass
        if total_var > 0:
            return total_var

    # Fallback: estimate from kernel diagonal
    try:
        # Match the kernel's declared input dimension; a fixed single
        # column makes K() fail for kernels defined over several inputs.
        input_dim = max(int(getattr(kern, 'input_dim', 1) or 1), 1)
        x_dummy = np.zeros((10, input_dim))  # Dummy input
        K = kern.K(x_dummy, x_dummy)
        return float(np.mean(np.diag(K)))
    except Exception as e:
        logger.debug("Kernel-diagonal variance estimate failed: %s", e)

    logger.debug("Could not extract signal variance, using default")
    return 1.0
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
def _extract_noise_variance(model: Any) -> float:
    """Return the noise variance recorded on ``model``.

    ``model.Gaussian_noise.variance`` is preferred, then
    ``model.likelihood.variance``. A candidate is skipped when it is
    missing or cannot be converted to float. Returns 0.1 when neither
    source yields a value.
    """
    # Checked in priority order: Gaussian_noise (GPy style) first, then the
    # general likelihood attribute.
    for holder_name in ('Gaussian_noise', 'likelihood'):
        if not hasattr(model, holder_name):
            continue
        holder = getattr(model, holder_name)
        if not hasattr(holder, 'variance'):
            continue
        try:
            return float(holder.variance)
        except (TypeError, ValueError):
            continue

    logger.debug("Could not extract noise variance, using default")
    return 0.1
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def _generate_recommendations(metrics: ComplexityMetrics) -> List[str]:
    """Build the list of suggested actions for a set of metrics.

    Advice for the complexity category comes first (``MODERATE`` has
    none), followed by at most one signal-to-noise hint and at most one
    capacity hint.
    """
    advice_by_category = (
        (ComplexityCategory.TOO_SIMPLE, (
            "Add more kernel components (e.g., RBF + Linear)",
            "Increase kernel flexibility (reduce lengthscale)",
            "Check if model captures all data trends",
        )),
        (ComplexityCategory.SIMPLE, (
            "Consider more expressive kernel structure",
            "Verify that lengthscales are appropriate for data",
        )),
        (ComplexityCategory.COMPLEX, (
            "Monitor validation performance for overfitting",
            "Consider kernel simplification or regularization",
            "Collect more training data if possible",
        )),
        (ComplexityCategory.TOO_COMPLEX, (
            "Simplify kernel structure (remove components)",
            "Add strong priors on hyperparameters",
            "Increase noise variance to regularize",
            "Use sparse approximation methods",
        )),
    )

    suggestions = []
    for category, advice in advice_by_category:
        if metrics.category == category:
            suggestions.extend(advice)
            break  # only the first matching category contributes

    # SNR-specific recommendations (the two cases are mutually exclusive)
    snr = metrics.signal_noise_ratio
    if snr < 0.1:
        suggestions.append("High noise level: consider denoising preprocessing")
    elif snr > 100:
        suggestions.append("Very clean signal: can use simpler model")

    # Capacity recommendations
    if metrics.capacity_ratio > 0.9:
        suggestions.append("Model has capacity to interpolate: risk of overfitting")

    return suggestions
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
# Backwards compatibility wrappers
|
|
599
|
+
def check_variance_reasonable(variance: float, max_val: float = 1e6, min_val: float = 0.0) -> bool:
    """Return True when ``variance`` lies strictly between ``min_val`` and ``max_val``.

    Both bounds are exclusive, so with the defaults a variance of exactly
    0.0 or 1e6 is rejected.
    """
    return (min_val < variance) and (variance < max_val)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
# Convenience function for quick assessment
|
|
605
|
+
def quick_complexity_check(model: Any, X: np.ndarray) -> str:
    """
    Summarise model complexity in a single line of text.

    Runs ``compute_complexity_score`` with its default settings and formats
    the category, interpretation and log-score. This function does not
    raise: any exception is turned into an explanatory string.

    Returns:
        Human-readable complexity assessment string
    """
    try:
        summary = compute_complexity_score(model, X)
        # Missing keys degrade to placeholders instead of failing.
        cat, interp, score = (
            summary.get('category', 'UNKNOWN'),
            summary.get('interpretation', ''),
            summary.get('log_score', 0),
        )
        return f"{cat}: {interp} (log-score={score:.2f})"
    except Exception as e:
        return f"Could not assess complexity: {e}"
|