gpclarity-0.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,619 @@
+ """
+ Model complexity quantification for Gaussian Processes.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass, field
+ from enum import Enum, auto
+ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
+
+ import numpy as np
+
+ from gpclarity.exceptions import ComplexityError, KernelError
+ from gpclarity.kernel_summary import count_kernel_components, extract_kernel_params_flat
+ from gpclarity.utils import _cholesky_with_jitter, _validate_kernel_matrix
+
+ logger = logging.getLogger(__name__)
+
+
+ class ComplexityCategory(Enum):
+     """Categorization of model complexity levels."""
+     TOO_SIMPLE = auto()
+     SIMPLE = auto()
+     MODERATE = auto()
+     COMPLEX = auto()
+     TOO_COMPLEX = auto()
+
+     @property
+     def description(self) -> str:
+         descriptions = {
+             ComplexityCategory.TOO_SIMPLE: "Overly simplistic (high underfitting risk)",
+             ComplexityCategory.SIMPLE: "Simple model (possible underfitting)",
+             ComplexityCategory.MODERATE: "Well-balanced complexity",
+             ComplexityCategory.COMPLEX: "Complex model (monitor for overfitting)",
+             ComplexityCategory.TOO_COMPLEX: "Overly complex (high overfitting risk)",
+         }
+         return descriptions[self]
+
+     @property
+     def risk_level(self) -> str:
+         levels = {
+             ComplexityCategory.TOO_SIMPLE: "HIGH",
+             ComplexityCategory.SIMPLE: "MEDIUM",
+             ComplexityCategory.MODERATE: "LOW",
+             ComplexityCategory.COMPLEX: "MEDIUM",
+             ComplexityCategory.TOO_COMPLEX: "HIGH",
+         }
+         return levels[self]
+
+
+ @dataclass(frozen=True)
+ class ComplexityThresholds:
+     """
+     Data-adaptive thresholds for complexity interpretation.
+
+     Thresholds are in log10(complexity_score) space.
+     """
+     too_simple: float = -0.5
+     simple: float = 0.5
+     complex: float = 1.5
+     too_complex: float = 2.5
+     high_noise_ratio: float = 0.1  # signal/noise < this is noisy
+     low_signal_ratio: float = 10.0  # signal/noise > this is dominated by signal
+     jitter: float = 1e-10
+
+     def __post_init__(self):
+         # Validate ordering
+         thresholds = [self.too_simple, self.simple, self.complex, self.too_complex]
+         if not all(t < u for t, u in zip(thresholds, thresholds[1:])):
+             raise ValueError("Thresholds must be strictly increasing")
+         if self.jitter <= 0:
+             raise ValueError("jitter must be positive")
+
+     def categorize(self, log_score: float) -> ComplexityCategory:
+         """Categorize complexity based on log score."""
+         if log_score < self.too_simple:
+             return ComplexityCategory.TOO_SIMPLE
+         elif log_score < self.simple:
+             return ComplexityCategory.SIMPLE
+         elif log_score < self.complex:
+             return ComplexityCategory.MODERATE
+         elif log_score < self.too_complex:
+             return ComplexityCategory.COMPLEX
+         return ComplexityCategory.TOO_COMPLEX
+
+
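A quick illustration of how the default bands behave, with the definitions above in scope (the log-score values below are arbitrary and chosen only to hit each band):

thresholds = ComplexityThresholds()
for log_score in (-1.0, 0.0, 1.0, 2.0, 3.0):
    category = thresholds.categorize(log_score)
    print(f"{log_score:+.1f} -> {category.name} (risk={category.risk_level})")
# With the defaults: < -0.5 TOO_SIMPLE, [-0.5, 0.5) SIMPLE, [0.5, 1.5) MODERATE,
# [1.5, 2.5) COMPLEX, >= 2.5 TOO_COMPLEX.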
+ @dataclass
+ class ComplexityMetrics:
+     """Detailed complexity metrics for a GP model."""
+     total_score: float
+     log_score: float
+     category: ComplexityCategory
+     n_parameters: int
+     n_kernel_components: int
+     roughness_score: float
+     signal_noise_ratio: float
+     effective_degrees_of_freedom: float
+     capacity_ratio: float  # DOF / n_samples
+     geometric_complexity: float  # Alternative metric based on eigenvalues
+
+     @property
+     def is_well_specified(self) -> bool:
+         """Check if model complexity is appropriate."""
+         return self.category == ComplexityCategory.MODERATE
+
+     @property
+     def risk_factors(self) -> List[str]:
+         """Identify specific risk factors."""
+         risks = []
+         if self.category in (ComplexityCategory.TOO_SIMPLE, ComplexityCategory.SIMPLE):
+             risks.append("Underfitting risk: model may be too restrictive")
+         if self.category in (ComplexityCategory.COMPLEX, ComplexityCategory.TOO_COMPLEX):
+             risks.append("Overfitting risk: model may be too flexible")
+         if self.signal_noise_ratio < 0.1:
+             risks.append("High noise: predictions may be unreliable")
+         if self.capacity_ratio > 0.8:
+             risks.append("High capacity: model can memorize training data")
+         return risks
+
+
+ class ComplexityScorer(Protocol):
+     """Protocol for pluggable complexity scoring strategies."""
+     def __call__(self, model: Any, X: np.ndarray, **kwargs) -> float:
+         ...
+
+
+ @dataclass
+ class ComplexityAnalyzer:
+     """
+     Configurable complexity analysis with multiple scoring strategies.
+
+     This class provides the core analysis logic, separated from
+     the high-level API functions.
+     """
+     thresholds: ComplexityThresholds = field(default_factory=ComplexityThresholds)
+     scoring_strategy: str = "default"
+
+     # Registry of scoring strategies
+     _strategies: Dict[str, ComplexityScorer] = field(default_factory=dict, repr=False)
+
+     def __post_init__(self):
+         if not self._strategies:
+             self._register_default_strategies()
+
+     def _register_default_strategies(self):
+         """Register built-in scoring strategies."""
+         self._strategies["default"] = self._default_score
+         self._strategies["geometric"] = self._geometric_score
+         self._strategies["bayesian"] = self._bayesian_score
+
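The _strategies dict is a name-keyed registry of callables conforming to the ComplexityScorer protocol. The file exposes no public registration hook, so adding a custom strategy means writing to the private registry directly; a sketch under that assumption (component_density_score is a made-up scorer, illustrative only):

def component_density_score(model, X, **kwargs):
    # Toy scorer: kernel components per 100 training points.
    n_components = kwargs.get('n_components', 1)
    return 100.0 * n_components / max(X.shape[0], 1)

analyzer = ComplexityAnalyzer(scoring_strategy="component_density")
analyzer._strategies["component_density"] = component_density_score  # private field, illustrative only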
+     @staticmethod
+     def _default_score(model: Any, X: np.ndarray, **kwargs) -> float:
+         """
+         Default complexity score based on components, roughness, and SNR.
+
+         Score = (n_components * roughness) / (SNR * capacity_ratio + jitter)
+         """
+         n_comp = kwargs.get('n_components', 1)
+         roughness = kwargs.get('roughness', 1.0)
+         snr = kwargs.get('snr', 1.0)
+         capacity = kwargs.get('capacity_ratio', 0.5)
+         jitter = kwargs.get('jitter', 1e-10)
+
+         return (n_comp * roughness) / (snr * capacity + jitter)
+
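A worked instance of that formula, using hypothetical inputs (three kernel components, roughness 2.0, SNR 4.0, capacity ratio 0.5):

score = ComplexityAnalyzer._default_score(
    None, np.zeros((10, 1)),            # model and X are unused by this scorer
    n_components=3, roughness=2.0, snr=4.0, capacity_ratio=0.5,
)
# (3 * 2.0) / (4.0 * 0.5 + 1e-10) ~= 3.0, and log10(3.0) ~= 0.48,
# which falls in the SIMPLE band under the default thresholds.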
+     @staticmethod
+     def _geometric_score(model: Any, X: np.ndarray, **kwargs) -> float:
+         """
+         Geometric complexity based on eigenvalue spectrum of kernel matrix.
+
+         Uses effective rank as complexity measure.
+         """
+         try:
+             K = model.kern.K(X, X)
+             eigenvals = np.linalg.eigvalsh(K)
+             eigenvals = np.maximum(eigenvals, 0)  # Numerical safety
+
+             # Effective rank (participation ratio)
+             if np.sum(eigenvals) > 0:
+                 effective_rank = (np.sum(eigenvals) ** 2) / (np.sum(eigenvals ** 2) + 1e-10)
+                 return float(effective_rank / X.shape[0])
+             return 1.0
+         except Exception as e:
+             logger.warning(f"Geometric score failed: {e}")
+             return ComplexityAnalyzer._default_score(model, X, **kwargs)
+
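The participation ratio used here is (sum of eigenvalues)^2 / (sum of squared eigenvalues); it equals n for a flat spectrum and approaches 1 when a single eigenvalue dominates. A small standalone check with hand-picked spectra (values are illustrative):

flat = np.array([4.0, 4.0, 4.0, 4.0])
peaked = np.array([10.0, 0.1, 0.1, 0.1])
pr_flat = (flat.sum() ** 2) / (flat ** 2).sum()        # 4.0: full effective rank
pr_peaked = (peaked.sum() ** 2) / (peaked ** 2).sum()  # ~1.06: effectively rank one
# _geometric_score then divides by X.shape[0], normalising the result into (0, 1].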
+     @staticmethod
+     def _bayesian_score(model: Any, X: np.ndarray, **kwargs) -> float:
+         """
+         Bayesian model complexity from the log marginal likelihood surface.
+
+         Approximates complexity as the gradient norm of the log-likelihood relative to its magnitude.
+         """
+         try:
+             # Approximate using gradient of log-likelihood
+             if hasattr(model, 'log_likelihood') and hasattr(model, 'gradient'):
+                 ll = model.log_likelihood()
+                 grad = model.gradient
+                 if grad is not None and len(grad) > 0:
+                     # Complexity ~ ||gradient|| / |LL| (steep LL = complex)
+                     complexity = np.linalg.norm(grad) / (abs(ll) + 1.0)
+                     return float(complexity)
+             return ComplexityAnalyzer._default_score(model, X, **kwargs)
+         except Exception as e:
+             logger.warning(f"Bayesian score failed: {e}")
+             return ComplexityAnalyzer._default_score(model, X, **kwargs)
+
+     def analyze(self, model: Any, X: np.ndarray) -> ComplexityMetrics:
+         """
+         Perform comprehensive complexity analysis.
+
+         Args:
+             model: GP model with kern and likelihood attributes
+             X: Training data (n_samples, n_features)
+
+         Returns:
+             ComplexityMetrics with detailed diagnostics
+         """
+         if X is None or not hasattr(X, 'shape'):
+             raise ComplexityError("X must be a valid array")
+         if X.shape[0] == 0:
+             raise ComplexityError("X cannot be empty")
+
+         n_samples = X.shape[0]
+
+         # Collect component metrics
+         n_components = self._count_components(model)
+         n_params = self._count_parameters(model)
+         roughness = self._compute_roughness(model)
+         snr = self._compute_snr(model)
+         effective_dof = self._compute_effective_dof(model, X)
+         capacity_ratio = effective_dof / n_samples
+
+         # Compute geometric complexity
+         geom_complexity = self._strategies["geometric"](
+             model, X, n_components=n_components
+         )
+
+         # Get scoring function
+         score_fn = self._strategies.get(
+             self.scoring_strategy,
+             self._default_score
+         )
+
+         # Compute composite score
+         score = score_fn(
+             model, X,
+             n_components=n_components,
+             roughness=roughness,
+             snr=snr,
+             capacity_ratio=capacity_ratio,
+             jitter=self.thresholds.jitter,
+         )
+
+         # Categorize
+         log_score = np.log10(max(score, self.thresholds.jitter))
+         category = self.thresholds.categorize(log_score)
+
+         return ComplexityMetrics(
+             total_score=float(score),
+             log_score=float(log_score),
+             category=category,
+             n_parameters=n_params,
+             n_kernel_components=n_components,
+             roughness_score=float(roughness),
+             signal_noise_ratio=float(snr),
+             effective_degrees_of_freedom=float(effective_dof),
+             capacity_ratio=float(capacity_ratio),
+             geometric_complexity=float(geom_complexity),
+         )
+
+     def _count_components(self, model: Any) -> int:
+         """Safely count kernel components."""
+         try:
+             return count_kernel_components(model.kern)
+         except Exception as e:
+             logger.warning(f"Component counting failed: {e}")
+             return 1
+
+     def _count_parameters(self, model: Any) -> int:
+         """Count total trainable parameters."""
+         try:
+             params = extract_kernel_params_flat(model)
+             return len(params)
+         except Exception as e:
+             logger.warning(f"Parameter counting failed: {e}")
+             return 0
+
+     def _compute_roughness(self, model: Any) -> float:
+         """Compute function roughness score."""
+         try:
+             return compute_roughness_score(model.kern)
+         except Exception as e:
+             logger.warning(f"Roughness computation failed: {e}")
+             return 1.0
+
+     def _compute_snr(self, model: Any) -> float:
+         """Compute signal-to-noise ratio."""
+         try:
+             return compute_noise_ratio(model)
+         except Exception as e:
+             logger.warning(f"SNR computation failed: {e}")
+             return 1.0
+
+     def _compute_effective_dof(self, model: Any, X: np.ndarray) -> float:
+         """
+         Compute effective degrees of freedom using trace of hat matrix.
+
+         For GP regression: DOF = trace(K @ (K + sigma^2 I)^{-1})
+         """
+         try:
+             K = model.kern.K(X, X)
+             _validate_kernel_matrix(K)
+
+             noise_var = 1.0
+             if hasattr(model, 'Gaussian_noise') and hasattr(model.Gaussian_noise, 'variance'):
+                 noise_var = float(model.Gaussian_noise.variance)
+                 if not np.isfinite(noise_var) or noise_var < 0:
+                     noise_var = 1.0
+
+             # Stable computation via eigendecomposition
+             eigenvals = np.linalg.eigvalsh(K)
+             eigenvals = np.maximum(eigenvals, 0)
+
+             # DOF = sum(eigenvals / (eigenvals + noise_var))
+             dof = np.sum(eigenvals / (eigenvals + noise_var + self.thresholds.jitter))
+
+             return float(np.clip(dof, 0, X.shape[0]))
+
+         except Exception as e:
+             logger.debug(f"Effective DOF computation failed: {e}")
+             # Fallback: use parameter count as proxy
+             return float(self._count_parameters(model))
+
+
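The eigenvalue form used above relies on the identity trace(K @ (K + sigma^2 I)^{-1}) = sum_i lambda_i / (lambda_i + sigma^2) for symmetric PSD K. It can be checked numerically on a small arbitrary PSD matrix (a sketch; the matrix and noise level are made up):

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
K = A @ A.T                          # arbitrary symmetric PSD stand-in for a kernel matrix
noise_var = 0.5

dof_trace = np.trace(K @ np.linalg.inv(K + noise_var * np.eye(5)))
eigvals = np.linalg.eigvalsh(K)
dof_eig = np.sum(eigvals / (eigvals + noise_var))

assert np.isclose(dof_trace, dof_eig)  # both routes give the same effective DOF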
+ # High-level API functions
+ def compute_complexity_score(
+     model: Any,
+     X: np.ndarray,
+     *,
+     strategy: str = "default",
+     thresholds: Optional[ComplexityThresholds] = None,
+     return_diagnostics: bool = False,
+ ) -> Union[Dict[str, Any], ComplexityMetrics]:
+     """
+     Comprehensive model complexity quantification.
+
+     Args:
+         model: Trained GP model with kern and likelihood
+         X: Training data (n_samples, n_features)
+         strategy: Scoring strategy ('default', 'geometric', 'bayesian')
+         thresholds: Custom thresholds (uses defaults if None)
+         return_diagnostics: If True, return full ComplexityMetrics object
+
+     Returns:
+         Dictionary summary or ComplexityMetrics object
+
+     Raises:
+         ComplexityError: If analysis fails
+     """
+     try:
+         analyzer = ComplexityAnalyzer(
+             thresholds=thresholds or ComplexityThresholds(),
+             scoring_strategy=strategy,
+         )
+
+         metrics = analyzer.analyze(model, X)
+
+         if return_diagnostics:
+             return metrics
+
+         # Build summary dictionary
+         result = {
+             "score": metrics.total_score,
+             "log_score": metrics.log_score,
+             "category": metrics.category.name,
+             "interpretation": metrics.category.description,
+             "risk_level": metrics.category.risk_level,
+             "risk_factors": metrics.risk_factors,
+             "metrics": {
+                 "n_parameters": metrics.n_parameters,
+                 "n_kernel_components": metrics.n_kernel_components,
+                 "roughness_score": metrics.roughness_score,
+                 "signal_noise_ratio": metrics.signal_noise_ratio,
+                 "effective_dof": metrics.effective_degrees_of_freedom,
+                 "capacity_ratio": metrics.capacity_ratio,
+             },
+             "recommendations": _generate_recommendations(metrics),
+         }
+
+         return result
+
+     except ComplexityError:
+         raise
+     except Exception as e:
+         raise ComplexityError(f"Complexity analysis failed: {e}") from e
+
+
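A usage sketch for the summary path. The module itself is framework-agnostic (it only probes kern, Gaussian_noise and likelihood attributes), so the GPy model below is just one plausible input and is not implied by this file:

import GPy  # not a dependency of this module; used here only for illustration

X = np.linspace(0, 10, 50)[:, None]
y = np.sin(X) + 0.1 * np.random.randn(50, 1)
gp = GPy.models.GPRegression(X, y, GPy.kern.RBF(1))
gp.optimize()

report = compute_complexity_score(gp, X, strategy="default")
print(report["category"], report["risk_level"])
for rec in report["recommendations"]:
    print("-", rec)

# Or request the full dataclass instead of the dictionary summary:
metrics = compute_complexity_score(gp, X, return_diagnostics=True)
print(metrics.effective_degrees_of_freedom, metrics.capacity_ratio)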
+ def compute_roughness_score(kern: Any) -> float:
+     """
+     Compute function roughness as inverse of characteristic lengthscale.
+
+     Higher roughness = more wiggly function = higher complexity.
+     """
+     roughness_values = []
+
+     def traverse(kernel: Any, path: str = "") -> None:
+         """Recursively collect lengthscale-based roughness."""
+         kernel_name = getattr(kernel, 'name', 'unknown')
+         current_path = f"{path}.{kernel_name}" if path else kernel_name
+
+         # Check for lengthscale attribute
+         if hasattr(kernel, 'lengthscale'):
+             try:
+                 ls = kernel.lengthscale
+                 if hasattr(ls, 'values'):
+                     ls_val = ls.values
+                 elif hasattr(ls, 'param_array'):
+                     ls_val = ls.param_array
+                 else:
+                     ls_val = ls
+
+                 arr = np.atleast_1d(ls_val)
+
+                 # Use harmonic mean for ARD (penalizes small lengthscales more)
+                 if len(arr) > 1:
+                     # Filter out invalid values
+                     valid = arr[np.isfinite(arr) & (arr > 0)]
+                     if len(valid) > 0:
+                         hmean = len(valid) / np.sum(1.0 / valid)
+                         roughness_values.append(1.0 / hmean)
+                 else:
+                     ls_float = float(arr[0])
+                     if np.isfinite(ls_float) and ls_float > 0:
+                         roughness_values.append(1.0 / ls_float)
+
+             except Exception as e:
+                 logger.debug(f"Could not extract lengthscale from {current_path}: {e}")
+
+         # Recurse into composite kernels
+         if hasattr(kernel, 'parts') and kernel.parts:
+             for i, part in enumerate(kernel.parts):
+                 traverse(part, f"{current_path}[{i}]")
+
+     try:
+         traverse(kern)
+     except RecursionError:
+         raise ComplexityError("Kernel structure too deep (possible circular reference)")
+     except Exception as e:
+         raise ComplexityError(f"Roughness computation failed: {e}") from e
+
+     if not roughness_values:
+         logger.debug("No lengthscales found, returning unit roughness")
+         return 1.0
+
+     # Return geometric mean of roughness values
+     log_roughness = np.mean(np.log(roughness_values))
+     return float(np.exp(log_roughness))
+
+
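To make the lengthscale-to-roughness mapping concrete: an ARD kernel with lengthscales [0.5, 2.0, 8.0] has harmonic mean 3 / (2.0 + 0.5 + 0.125) ~= 1.14 and contributes roughness ~= 0.875, a second kernel with a single lengthscale of 4.0 contributes 0.25, and the reported score is their geometric mean, ~= 0.47. The same arithmetic in plain NumPy (independent of any kernel object; the lengthscales are made up):

ard_lengthscales = np.array([0.5, 2.0, 8.0])
hmean = len(ard_lengthscales) / np.sum(1.0 / ard_lengthscales)   # ~1.14
roughness_values = [1.0 / hmean, 1.0 / 4.0]                      # [~0.875, 0.25]
score = float(np.exp(np.mean(np.log(roughness_values))))         # ~0.47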
+ def compute_noise_ratio(model: Any) -> float:
+     """
+     Compute signal-to-noise ratio (variance_signal / variance_noise).
+
+     Returns:
+         SNR value (>1 means signal dominates, <1 means noise dominates)
+     """
+     try:
+         # Extract signal variance from kernel
+         signal_var = _extract_signal_variance(model)
+
+         # Extract noise variance from likelihood
+         noise_var = _extract_noise_variance(model)
+
+         if not np.isfinite(signal_var) or not np.isfinite(noise_var):
+             logger.warning("Non-finite variance values detected")
+             return 1.0
+
+         if noise_var <= 0:
+             logger.warning("Zero or negative noise variance")
+             return 10.0  # Assume high SNR if no noise
+
+         snr = signal_var / noise_var
+
+         # Sanity bounds
+         return float(np.clip(snr, 1e-6, 1e6))
+
+     except Exception as e:
+         logger.debug(f"SNR computation failed: {e}")
+         return 1.0
+
+
+ def _extract_signal_variance(model: Any) -> float:
+     """Extract signal variance from model kernel."""
+     if not hasattr(model, 'kern'):
+         raise ComplexityError("Model has no kernel")
+
+     kern = model.kern
+
+     # Try to get variance from kernel
+     if hasattr(kern, 'variance'):
+         try:
+             return float(kern.variance)
+         except (TypeError, ValueError):
+             pass
+
+     # For composite kernels, sum variances
+     if hasattr(kern, 'parts') and kern.parts:
+         total_var = 0.0
+         for part in kern.parts:
+             if hasattr(part, 'variance'):
+                 try:
+                     total_var += float(part.variance)
+                 except (TypeError, ValueError):
+                     pass
+         if total_var > 0:
+             return total_var
+
+     # Fallback: estimate from kernel diagonal
+     try:
+         # Sample variance from kernel matrix diagonal
+         x_dummy = np.zeros((10, 1))  # Dummy input
+         K = kern.K(x_dummy, x_dummy)
+         return float(np.mean(np.diag(K)))
+     except Exception:
+         pass
+
+     logger.debug("Could not extract signal variance, using default")
+     return 1.0
+
+
+ def _extract_noise_variance(model: Any) -> float:
+     """Extract noise variance from model likelihood."""
+     # Try Gaussian_noise attribute (GPy style)
+     if hasattr(model, 'Gaussian_noise'):
+         if hasattr(model.Gaussian_noise, 'variance'):
+             try:
+                 return float(model.Gaussian_noise.variance)
+             except (TypeError, ValueError):
+                 pass
+
+     # Try likelihood attribute (general)
+     if hasattr(model, 'likelihood'):
+         lik = model.likelihood
+         if hasattr(lik, 'variance'):
+             try:
+                 return float(lik.variance)
+             except (TypeError, ValueError):
+                 pass
+
+     logger.debug("Could not extract noise variance, using default")
+     return 0.1
+
+
+ def _generate_recommendations(metrics: ComplexityMetrics) -> List[str]:
+     """Generate actionable recommendations based on metrics."""
+     recs = []
+
+     if metrics.category == ComplexityCategory.TOO_SIMPLE:
+         recs.extend([
+             "Add more kernel components (e.g., RBF + Linear)",
+             "Increase kernel flexibility (reduce lengthscale)",
+             "Check if model captures all data trends",
+         ])
+     elif metrics.category == ComplexityCategory.SIMPLE:
+         recs.extend([
+             "Consider more expressive kernel structure",
+             "Verify that lengthscales are appropriate for data",
+         ])
+     elif metrics.category == ComplexityCategory.COMPLEX:
+         recs.extend([
+             "Monitor validation performance for overfitting",
+             "Consider kernel simplification or regularization",
+             "Collect more training data if possible",
+         ])
+     elif metrics.category == ComplexityCategory.TOO_COMPLEX:
+         recs.extend([
+             "Simplify kernel structure (remove components)",
+             "Add strong priors on hyperparameters",
+             "Increase noise variance to regularize",
+             "Use sparse approximation methods",
+         ])
+
+     # SNR-specific recommendations
+     if metrics.signal_noise_ratio < 0.1:
+         recs.append("High noise level: consider denoising preprocessing")
+     elif metrics.signal_noise_ratio > 100:
+         recs.append("Very clean signal: can use simpler model")
+
+     # Capacity recommendations
+     if metrics.capacity_ratio > 0.9:
+         recs.append("Model has capacity to interpolate: risk of overfitting")
+
+     return recs
+
+
+ # Backwards compatibility wrappers
+ def check_variance_reasonable(variance: float, max_val: float = 1e6, min_val: float = 0.0) -> bool:
+     """Check if variance is within reasonable bounds."""
+     return min_val < variance < max_val
+
+
+ # Convenience function for quick assessment
+ def quick_complexity_check(model: Any, X: np.ndarray) -> str:
+     """
+     One-line complexity assessment.
+
+     Returns:
+         Human-readable complexity assessment string
+     """
+     try:
+         result = compute_complexity_score(model, X)
+         cat = result.get('category', 'UNKNOWN')
+         interp = result.get('interpretation', '')
+         score = result.get('log_score', 0)
+         return f"{cat}: {interp} (log-score={score:.2f})"
+     except Exception as e:
+         return f"Could not assess complexity: {e}"