gpclarity 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpclarity/__init__.py +190 -0
- gpclarity/_version.py +3 -0
- gpclarity/data_influence.py +501 -0
- gpclarity/exceptions.py +46 -0
- gpclarity/hyperparam_tracker.py +718 -0
- gpclarity/kernel_summary.py +285 -0
- gpclarity/model_complexity.py +619 -0
- gpclarity/plotting.py +337 -0
- gpclarity/uncertainty_analysis.py +647 -0
- gpclarity/utils.py +411 -0
- gpclarity-0.0.2.dist-info/METADATA +248 -0
- gpclarity-0.0.2.dist-info/RECORD +14 -0
- gpclarity-0.0.2.dist-info/WHEEL +4 -0
- gpclarity-0.0.2.dist-info/licenses/LICENSE +37 -0
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Kernel interpretation and summarization tools for GPy models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations

import json
import logging
from dataclasses import asdict, dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

import numpy as np

from gpclarity.exceptions import KernelError
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SmoothnessCategory(Enum):
    """Qualitative smoothness buckets that a lengthscale can fall into."""

    RAPID_VARIATION = auto()
    MODERATE = auto()
    SMOOTH_TREND = auto()

    def describe(self) -> str:
        """Return the human-readable label for this category."""
        if self is SmoothnessCategory.RAPID_VARIATION:
            return "Rapid variation (high frequency)"
        if self is SmoothnessCategory.MODERATE:
            return "Moderate flexibility"
        # Only SMOOTH_TREND remains.
        return "Smooth trends (low frequency)"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class VarianceCategory(Enum):
    """Qualitative magnitude buckets that a variance can fall into."""

    VERY_LOW = auto()
    MODERATE = auto()
    HIGH = auto()

    def describe(self, name: str = "signal") -> str:
        """Return '<Level> <name>', e.g. 'Very Low signal'.

        The level is the member name with underscores turned into spaces and
        title-cased; ``name`` is lower-cased before being appended.
        """
        level = self.name.replace("_", " ").title()
        subject = name.lower()
        return f"{level} {subject}"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass(frozen=True)
class LengthscaleThresholds:
    """Immutable cut-offs used to bucket a lengthscale value.

    Raises:
        ValueError: On construction, unless
            ``0 < rapid_variation < smooth_trend``.
    """

    # Values strictly below this are categorised as RAPID_VARIATION.
    rapid_variation: float = 0.5
    # Values strictly above this are categorised as SMOOTH_TREND.
    smooth_trend: float = 2.0

    def __post_init__(self):
        """Validate that the two thresholds are positive and ordered."""
        thresholds_ordered = 0 < self.rapid_variation < self.smooth_trend
        if thresholds_ordered:
            return
        raise ValueError(
            f"Thresholds must satisfy 0 < rapid_variation ({self.rapid_variation}) "
            f"< smooth_trend ({self.smooth_trend})"
        )

    def categorize(self, lengthscale: float) -> SmoothnessCategory:
        """Map ``lengthscale`` to a smoothness category.

        Values equal to either threshold fall into MODERATE.
        """
        if lengthscale < self.rapid_variation:
            bucket = SmoothnessCategory.RAPID_VARIATION
        elif lengthscale > self.smooth_trend:
            bucket = SmoothnessCategory.SMOOTH_TREND
        else:
            bucket = SmoothnessCategory.MODERATE
        return bucket
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass(frozen=True)
class VarianceThresholds:
    """Configurable thresholds for variance interpretation.

    Raises:
        ValueError: On construction, if either threshold is not a positive
            number (NaN included) or if ``very_low`` is not strictly below
            ``high``.
    """

    # Values strictly below this are categorised as VERY_LOW.
    very_low: float = 0.01
    # Values strictly above this are categorised as HIGH.
    high: float = 10.0

    def __post_init__(self):
        """Validate that both thresholds are positive and ordered."""
        # Phrased as positive comparisons so that NaN, for which every
        # comparison is False, is rejected instead of slipping through.
        if not (self.very_low > 0 and self.high > 0):
            raise ValueError("Variance thresholds must be positive")
        if not self.very_low < self.high:
            raise ValueError(
                f"very_low ({self.very_low}) must be < high ({self.high})"
            )

    def categorize(self, variance: float) -> VarianceCategory:
        """Categorize a variance value.

        Values equal to either threshold fall into MODERATE.
        """
        if variance < self.very_low:
            return VarianceCategory.VERY_LOW
        elif variance > self.high:
            return VarianceCategory.HIGH
        return VarianceCategory.MODERATE
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class InterpretationConfig:
    """Bundle of the lengthscale and variance thresholds used for interpretation.

    Either field may be supplied as a plain ``dict`` of keyword arguments; it
    is converted to the matching thresholds dataclass after initialisation.
    """

    lengthscale: LengthscaleThresholds = field(default_factory=LengthscaleThresholds)
    variance: VarianceThresholds = field(default_factory=VarianceThresholds)

    def __post_init__(self):
        """Promote dict-valued fields to their dataclass form."""
        raw_lengthscale = self.lengthscale
        if isinstance(raw_lengthscale, dict):
            self.lengthscale = LengthscaleThresholds(**raw_lengthscale)

        raw_variance = self.variance
        if isinstance(raw_variance, dict):
            self.variance = VarianceThresholds(**raw_variance)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class KernelComponent:
    """One node of an interpreted kernel tree."""

    name: str
    kernel_type: str
    path: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    interpretations: Dict[str, str] = field(default_factory=dict)
    is_composite: bool = False
    children: List["KernelComponent"] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a nested plain-dict view of this node and its children.

        ``path`` and ``is_composite`` are not part of the output, and the
        ``children`` key is present only when the node has children.
        """
        payload: Dict[str, Any] = dict(
            name=self.name,
            type=self.kernel_type,
            parameters=self.parameters,
            interpretations=self.interpretations,
        )
        if self.children:
            payload["children"] = [child.to_dict() for child in self.children]
        return payload
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class KernelVisitor(Protocol):
    """Protocol for kernel tree visitors.

    Structural interface: any object with a compatible ``visit`` method
    satisfies it.
    """

    def visit(self, kernel: Any, path: str = "") -> Optional[KernelComponent]:
        """Visit ``kernel`` at ``path``; return a component or ``None``."""
        ...
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class KernelInterpreter:
    """
    Interpretable kernel analysis with pluggable strategies.

    This class provides the core interpretation logic, separated from
    the high-level API functions.
    """

    # Registry of kernel-specific interpreters, keyed by kernel class name.
    # It lives on the class, so registrations are shared by all instances.
    _interpreters: Dict[str, Callable[[Any, InterpretationConfig], Dict[str, Any]]] = {}

    def __init__(self, config: Optional[InterpretationConfig] = None):
        """Store ``config``, or a default ``InterpretationConfig`` if omitted."""
        self.config = config or InterpretationConfig()

    @classmethod
    def register_kernel(cls, kernel_type: str):
        """Decorator to register interpreter for specific kernel type.

        Args:
            kernel_type: Class name (``type(kernel).__name__``) to dispatch on.

        Returns:
            A decorator that records the function and returns it unchanged.
        """

        def decorator(func: Callable[[Any, InterpretationConfig], Dict[str, Any]]):
            cls._interpreters[kernel_type] = func
            return func

        return decorator

    def interpret(self, kernel: Any, path: str = "") -> KernelComponent:
        """
        Interpret a kernel or kernel component.

        A handler registered for the kernel's class name is used when
        available; otherwise the generic attribute-based fallback runs.

        Args:
            kernel: GPy kernel object
            path: Hierarchical path string

        Returns:
            KernelComponent with interpretations
        """
        kernel_name = getattr(kernel, 'name', 'unknown')
        kernel_type = type(kernel).__name__

        # Check for registered handler
        if kernel_type in self._interpreters:
            params = self._interpreters[kernel_type](kernel, self.config)
        else:
            params = self._interpret_generic(kernel)

        # Build component
        component = KernelComponent(
            name=kernel_name,
            kernel_type=kernel_type,
            path=path,
            parameters=params.get("parameters", {}),
            interpretations=params.get("interpretations", {}),
            is_composite=params.get("is_composite", False),
        )

        # Handle composite kernels recursively
        if component.is_composite and hasattr(kernel, "parts"):
            for i, part in enumerate(kernel.parts):
                child_path = f"{path}.parts[{i}]" if path else f"parts[{i}]"
                component.children.append(self.interpret(part, child_path))

        return component

    def _interpret_generic(self, kernel: Any) -> Dict[str, Any]:
        """Generic interpretation fallback.

        Inspects the ``lengthscale``, ``variance``, ``periodicity`` and
        ``parts`` attributes when present.

        Returns:
            Dict with ``parameters``, ``interpretations`` and ``is_composite``.
        """
        result = {"parameters": {}, "interpretations": {}, "is_composite": False}

        # Extract common parameters
        if hasattr(kernel, "lengthscale"):
            ls_values = self._extract_values(kernel.lengthscale)
            result["parameters"]["lengthscale"] = ls_values

            ls_array = np.atleast_1d(np.asarray(ls_values, dtype=float)).ravel()
            if ls_array.size > 1:
                # ARD lengthscales: categorise by the mean, report the range.
                mean_ls = float(np.mean(ls_array))
                range_str = f"[{np.min(ls_array):.2f}, {np.max(ls_array):.2f}]"
                category = self.config.lengthscale.categorize(mean_ls)
                result["interpretations"]["smoothness"] = (
                    f"{category.describe()} (ARD, range: {range_str})"
                )
            elif ls_array.size == 1:
                category = self.config.lengthscale.categorize(float(ls_array[0]))
                result["interpretations"]["smoothness"] = category.describe()
            # An empty lengthscale has nothing to categorise; the (empty)
            # parameter is still reported, but no smoothness text is added.

        if hasattr(kernel, "variance"):
            var = float(kernel.variance)
            result["parameters"]["variance"] = var
            category = self.config.variance.categorize(var)
            name = "Noise" if self._is_noise_kernel(kernel) else "Signal"
            result["interpretations"]["strength"] = category.describe(name)

        if hasattr(kernel, "periodicity"):
            per = float(kernel.periodicity)
            result["parameters"]["periodicity"] = per
            result["interpretations"]["pattern"] = f"Periodic (period={per:.2f})"

        # Check if composite
        if hasattr(kernel, "parts") and kernel.parts:
            result["is_composite"] = True

        return result

    @staticmethod
    def _extract_values(param: Any) -> Union[float, List[float], np.ndarray]:
        """Safely extract values from GPy parameter.

        Looks for a ``values`` attribute, then ``param_array``, then uses the
        object itself. The data is flattened, so multi-dimensional parameters
        yield a flat list rather than nested lists.

        Returns:
            ``0.0`` for ``None``, a float for a single value, otherwise a
            flat list (empty for an empty parameter).
        """
        if param is None:
            return 0.0

        if hasattr(param, "values"):
            val = param.values
        elif hasattr(param, "param_array"):
            val = param.param_array
        else:
            val = param

        # ravel() makes shape (1, k) or (n, k) parameters behave like 1-D
        # ones; scalars become a one-element array.
        arr = np.asarray(val).ravel()
        if arr.size == 1:
            return float(arr[0])
        return arr.tolist()

    @staticmethod
    def _is_noise_kernel(kernel: Any) -> bool:
        """Determine if kernel represents noise.

        True when the lower-cased ``name`` attribute contains 'white',
        'noise' or 'bias'.
        """
        name = getattr(kernel, 'name', '').lower()
        return any(n in name for n in ['white', 'noise', 'bias'])
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class KernelSummaryFormatter:
    """Formats kernel summaries for different output formats.

    Args:
        component: Root of the interpreted kernel tree.
        config: Optional interpretation configuration. When given, its
            thresholds are what the text summary's "Configuration" section
            reports. When omitted, the section falls back to values stored
            in the root's interpretations (or to 0.5 / 2.0 / 0.01 / 10.0).
    """

    def __init__(
        self,
        component: KernelComponent,
        config: Optional[InterpretationConfig] = None,
    ):
        self.root = component
        self.config = config

    def _displayed_thresholds(self) -> tuple:
        """Return (rapid, smooth, variance_low, variance_high) for display."""
        if self.config is not None:
            return (
                self.config.lengthscale.rapid_variation,
                self.config.lengthscale.smooth_trend,
                self.config.variance.very_low,
                self.config.variance.high,
            )
        stored = self.root.interpretations
        return (
            stored.get('lengthscale_rapid', 0.5),
            stored.get('lengthscale_smooth', 2.0),
            stored.get('variance_low', 0.01),
            stored.get('variance_high', 10.0),
        )

    def to_text(self, verbose: bool = True) -> str:
        """Generate human-readable text summary.

        Args:
            verbose: Accepted for interface compatibility; it currently has
                no effect on the output.

        Returns:
            Banner, configuration, tree structure and per-component details
            joined by newlines.
        """
        rapid, smooth, var_low, var_high = self._displayed_thresholds()
        lines = [
            "\n╔" + "═" * 58 + "╗",
            "║" + " KERNEL SUMMARY".center(58) + "║",
            "╚" + "═" * 58 + "╝\n",
            # Configuration
            "Configuration:",
            f" Lengthscale: rapid<{rapid}, smooth>{smooth}",
            f" Variance: very_low<{var_low}, high>{var_high}\n",
            # Tree structure
            "Structure:",
            self._format_tree(self.root),
            "",
            # Detailed components
            "Components:",
            self._format_components(self.root),
        ]
        return "\n".join(lines)

    def _format_tree(self, component: KernelComponent, prefix: str = "", is_last: bool = True) -> str:
        """Format tree structure with box-drawing characters.

        The node's name is appended in parentheses when it differs from its
        kernel type.
        """
        connector = "└── " if is_last else "├── "
        line = prefix + connector + component.kernel_type
        if component.name != component.kernel_type:
            line += f" ({component.name})"

        lines = [line]

        if component.children:
            new_prefix = prefix + (" " if is_last else "│ ")
            last_index = len(component.children) - 1
            for i, child in enumerate(component.children):
                lines.append(self._format_tree(child, new_prefix, i == last_index))

        return "\n".join(lines)

    def _format_components(self, component: KernelComponent, depth: int = 0) -> str:
        """Format component details.

        Composite nodes are listed only at the root (``depth == 0``); their
        children are always visited.
        """
        lines = []
        indent = " " * depth

        if not component.is_composite or depth == 0:
            lines.append(f"{indent}【{component.kernel_type}】 {component.path}")
            for key, val in component.parameters.items():
                lines.append(f"{indent} ├─ {key}: {val}")
            # Only the interpretation text is shown, not its key.
            for val in component.interpretations.values():
                lines.append(f"{indent} └─ {val}")
            lines.append("")

        for child in component.children:
            lines.append(self._format_components(child, depth + 1))

        return "\n".join(lines)

    def to_markdown(self) -> str:
        """Generate Markdown formatted summary.

        Each component becomes a heading (the root at level 2, children one
        level deeper per generation), followed by a parameter table and an
        interpretation list when those are non-empty.
        """
        lines = ["# Kernel Summary\n"]

        def add_component(comp: KernelComponent, level: int = 2):
            header = "#" * level
            lines.append(f"{header} {comp.kernel_type}\n")

            if comp.parameters:
                lines.append("| Parameter | Value |")
                lines.append("|-----------|-------|")
                for k, v in comp.parameters.items():
                    lines.append(f"| {k} | {v} |")
                lines.append("")

            if comp.interpretations:
                lines.append("**Interpretations:**")
                for k, v in comp.interpretations.items():
                    lines.append(f"- **{k}**: {v}")
                lines.append("")

            for child in comp.children:
                add_component(child, level + 1)

        add_component(self.root)
        return "\n".join(lines)

    def to_json(self, indent: int = 2) -> str:
        """Generate JSON representation of the root's ``to_dict()`` output."""
        return json.dumps(self.root.to_dict(), indent=indent)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# Register specific kernel interpreters
|
|
363
|
+
@KernelInterpreter.register_kernel("RBF")
def _interpret_rbf(kernel: Any, config: InterpretationConfig) -> Dict[str, Any]:
    """Specialized RBF kernel interpretation.

    Supports both a single lengthscale and ARD (one lengthscale per input
    dimension). For ARD, the category is chosen from the mean lengthscale
    and the reported parameter is the full list.

    Args:
        kernel: Object exposing ``lengthscale`` and ``variance``.
        config: Thresholds used to categorise both quantities.

    Returns:
        Dict with ``parameters``, ``interpretations`` and ``is_composite``.

    Raises:
        ValueError: If the kernel's lengthscale holds no values.
    """
    result = {"parameters": {}, "interpretations": {}, "is_composite": False}

    # RBF-specific: lengthscale is crucial
    ls_values = np.atleast_1d(
        np.asarray(KernelInterpreter._extract_values(kernel.lengthscale), dtype=float)
    ).ravel()
    if ls_values.size == 0:
        raise ValueError("RBF kernel has no lengthscale values")

    is_ard = ls_values.size > 1
    if is_ard:
        # float() on a multi-element lengthscale would raise, so ARD kernels
        # are summarised by their mean instead.
        ls = float(np.mean(ls_values))
        result["parameters"]["lengthscale"] = ls_values.tolist()
    else:
        ls = float(ls_values[0])
        result["parameters"]["lengthscale"] = ls

    category = config.lengthscale.categorize(ls)
    if category == SmoothnessCategory.RAPID_VARIATION:
        advice = "Model will fit noise - consider increasing"
    elif category == SmoothnessCategory.SMOOTH_TREND:
        advice = "Model may underfit - consider decreasing"
    else:
        advice = "Well-balanced flexibility"

    description = category.describe()
    if is_ard:
        range_str = f"[{np.min(ls_values):.2f}, {np.max(ls_values):.2f}]"
        description += f" (ARD, range: {range_str})"
    result["interpretations"]["smoothness"] = f"{description}. {advice}"

    # Variance
    var = float(kernel.variance)
    result["parameters"]["variance"] = var
    var_cat = config.variance.categorize(var)
    result["interpretations"]["signal_strength"] = var_cat.describe("Signal")

    return result
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@KernelInterpreter.register_kernel("Linear")
def _interpret_linear(kernel: Any, config: InterpretationConfig) -> Dict[str, Any]:
    """Interpret a Linear kernel via its per-dimension ``variances``.

    A dimension counts as active when its variance exceeds 0.1. Kernels
    without a ``variances`` attribute yield an empty interpretation.
    ``config`` is accepted for the handler signature but not used here.
    """
    summary = {"parameters": {}, "interpretations": {}, "is_composite": False}
    if not hasattr(kernel, "variances"):
        return summary

    variances = KernelInterpreter._extract_values(kernel.variances)
    summary["parameters"]["ARD_variances"] = variances

    per_dimension = np.atleast_1d(variances)
    active_dims = len([v for v in per_dimension if v > 0.1])
    total_dims = len(per_dimension)
    summary["interpretations"]["relevance"] = (
        f"Linear trend in {active_dims}/{total_dims} dimensions"
    )
    return summary
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
@KernelInterpreter.register_kernel("PeriodicExponential")
@KernelInterpreter.register_kernel("PeriodicMatern32")
@KernelInterpreter.register_kernel("PeriodicMatern52")
def _interpret_periodic(kernel: Any, config: InterpretationConfig) -> Dict[str, Any]:
    """Specialized periodic kernel interpretation.

    The period is read from ``kernel.periodicity`` when that attribute
    exists, otherwise from ``kernel.period``.

    Args:
        kernel: Object exposing a period attribute and ``lengthscale``.
        config: Accepted for the handler signature; not used here.

    Returns:
        Dict with ``parameters``, ``interpretations`` and ``is_composite``.

    Raises:
        AttributeError: If the kernel has neither ``periodicity`` nor
            ``period``, or has no ``lengthscale``.
    """
    result = {"parameters": {}, "interpretations": {}, "is_composite": False}

    # GPy's periodic kernels name this parameter `period`; `periodicity` is
    # still tried first so objects using that name keep working.
    period_param = getattr(kernel, "periodicity", None)
    if period_param is None:
        period_param = kernel.period
    period = float(period_param)
    result["parameters"]["period"] = period
    result["interpretations"]["pattern"] = f"Repeating pattern every {period:.2f} units"

    ls = float(kernel.lengthscale)
    result["parameters"]["decay_lengthscale"] = ls
    # A lengthscale above two periods is reported as long-range correlation.
    if ls > period * 2:
        result["interpretations"]["stability"] = "Long-range periodic correlations"
    else:
        result["interpretations"]["stability"] = "Local periodic patterns only"

    return result
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# High-level API functions
|
|
429
|
+
def summarize_kernel(
    model: Any,
    X: Optional[np.ndarray] = None,
    verbose: bool = True,
    config: Optional[InterpretationConfig] = None,
    format: str = "text",
) -> Union[str, Dict[str, Any]]:
    """
    Build an interpretation of ``model.kern`` and render it.

    Args:
        model: Object with a ``kern`` attribute.
        X: Optional training data; when given, the configuration is passed
            through ``_adapt_config_to_data`` before interpretation.
        verbose: Print the rendered summary (text and markdown only).
        config: Interpretation configuration; defaults are used if omitted.
        format: 'dict', 'json' or 'markdown'; any other value renders text.

    Returns:
        A dictionary for ``format='dict'``, otherwise the rendered string.

    Raises:
        KernelError: If ``model`` has no ``kern`` attribute or interpreting
            the kernel fails.
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    # Resolve the effective configuration, rescaled to the data if supplied.
    cfg = config or InterpretationConfig()
    if X is not None:
        cfg = _adapt_config_to_data(cfg, X)

    interpreter = KernelInterpreter(cfg)
    try:
        root = interpreter.interpret(model.kern)
    except Exception as e:
        raise KernelError(f"Failed to interpret kernel: {e}") from e

    formatter = KernelSummaryFormatter(root)

    # Structured formats are returned directly and never printed.
    if format == "dict":
        return root.to_dict()
    if format == "json":
        return formatter.to_json()

    rendered = formatter.to_markdown() if format == "markdown" else formatter.to_text()
    if verbose and format in ("text", "markdown"):
        print(rendered)
    return rendered
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def interpret_lengthscale(
    lengthscale: Union[float, np.ndarray, List[float]],
    config: Optional[LengthscaleThresholds] = None,
    return_category: bool = False,
) -> Union[str, Tuple[str, SmoothnessCategory]]:
    """
    Interpret lengthscale magnitude against the configured thresholds.

    Any scalar or sequence (list, tuple, array) is accepted. Multiple values
    are treated as ARD lengthscales and categorised by their mean; a single
    value, whether scalar or one-element sequence, is reported as a scalar.

    Args:
        lengthscale: Single value or array of lengthscales
        config: Threshold configuration
        return_category: If True, return (description, category) tuple

    Returns:
        Interpretation string, or tuple if return_category=True

    Raises:
        ValueError: If ``lengthscale`` contains no values.
    """
    cfg = config or LengthscaleThresholds()

    # Normalize input: scalars, lists, tuples and arrays all become 1-D.
    values = np.asarray(lengthscale, dtype=float).ravel()
    if values.size == 0:
        raise ValueError("lengthscale must contain at least one value")

    is_ard = values.size > 1
    mean_ls = float(np.mean(values))
    if is_ard:
        range_str = f"[{np.min(values):.2f}, {np.max(values):.2f}]"
    else:
        range_str = f"{mean_ls:.2f}"

    category = cfg.categorize(mean_ls)
    description = category.describe()

    if is_ard:
        description += f" (ARD, range: {range_str})"
    else:
        description += f" ({range_str})"

    if return_category:
        return description, category
    return description
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def interpret_variance(
    variance: float,
    name: str = "Signal",
    config: Optional[VarianceThresholds] = None,
    return_category: bool = False,
) -> Union[str, Tuple[str, VarianceCategory]]:
    """
    Describe a variance value relative to the configured thresholds.

    Args:
        variance: Variance value.
        name: Label passed to the category description (e.g. "Signal").
        config: Threshold configuration; defaults are used if omitted.
        return_category: If True, return a (description, category) tuple.

    Returns:
        The description with the value appended to three decimals, or the
        (description, category) tuple when ``return_category`` is True.
    """
    thresholds = config or VarianceThresholds()
    category = thresholds.categorize(variance)

    label = category.describe(name)
    description = f"{label} (≈{variance:.3f})"

    return (description, category) if return_category else description
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def format_kernel_tree(model: Any, style: str = "unicode") -> str:
    """
    Pretty-print kernel tree structure.

    Args:
        model: GPy model
        style: Output style. 'minimal' returns only the root kernel type,
            'ascii' draws the tree with ASCII characters, and any other
            value (including 'unicode') uses box-drawing characters.

    Returns:
        Formatted tree string

    Raises:
        KernelError: If ``model`` has no ``kern`` attribute.
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    interpreter = KernelInterpreter()
    root = interpreter.interpret(model.kern)

    if style == "minimal":
        return root.kernel_type

    # Render the tree directly instead of slicing it back out of the full
    # text summary, which depended on section headers never appearing in
    # kernel names.
    tree = KernelSummaryFormatter(root)._format_tree(root).strip()

    if style == "ascii":
        # Map each box-drawing character to an ASCII stand-in.
        tree = tree.translate(
            str.maketrans({"└": "`", "├": "|", "│": "|", "─": "-"})
        )

    return tree
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def count_kernel_components(model: Any) -> int:
    """Return the number of leaf kernels under ``model.kern``.

    A kernel with a non-empty ``parts`` attribute is an inner node; anything
    else counts as one leaf. Models without ``kern`` yield 0.
    """
    if not hasattr(model, "kern"):
        return 0

    leaves = 0
    pending = [model.kern]
    # Explicit work list in place of recursion; visit order does not affect
    # the count.
    while pending:
        node = pending.pop()
        if hasattr(node, "parts") and node.parts:
            pending.extend(node.parts)
        else:
            leaves += 1
    return leaves
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def extract_kernel_params_flat(model: Any) -> Dict[str, float]:
    """
    Extract all kernel parameters as flat dictionary with dotted paths.

    Multi-valued parameters are expanded to one ``path.param[i]`` entry per
    element. Sub-kernels are reached by recursion through ``parts`` and are
    never themselves flattened as if they were parameters.

    Args:
        model: GPy model

    Returns:
        Flat dictionary mapping "path.param" to value

    Raises:
        KernelError: If ``model`` has no ``kern`` attribute.
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    params: Dict[str, float] = {}

    def extract(kernel: Any, path: str = "") -> None:
        current_path = f"{path}.{kernel.name}" if path else kernel.name

        parts = getattr(kernel, "parts", None) or []
        # For combination kernels the `parameters` collection contains the
        # sub-kernels themselves (true of GPy's CombinationKernel; confirm
        # for other backends). Skip them here: the recursion below records
        # their leaf parameters under the proper dotted path.
        part_ids = {id(part) for part in parts}

        for param in getattr(kernel, "parameters", ()):
            if id(param) in part_ids:
                continue
            param_path = f"{current_path}.{param.name}"
            val = KernelInterpreter._extract_values(param)
            if isinstance(val, list):
                for i, v in enumerate(val):
                    params[f"{param_path}[{i}]"] = float(v)
            else:
                params[param_path] = float(val)

        for part in parts:
            extract(part, current_path)

    extract(model.kern)
    return params
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def get_lengthscale(model: Any, as_dict: bool = False) -> Union[float, Dict[str, float]]:
    """
    Extract lengthscale(s) from model kernel.

    Args:
        model: GPy model
        as_dict: Return dictionary with component paths as keys

    Returns:
        With ``as_dict=True``, a dictionary mapping dotted kernel paths to
        their extracted lengthscale values. Otherwise the first lengthscale
        found in a depth-first walk of the kernel tree (root first); for a
        multi-valued lengthscale only its first element is returned.

    Raises:
        KernelError: If ``model`` has no ``kern`` attribute, or (when
            ``as_dict`` is False) no kernel in the tree has a lengthscale.
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    if as_dict:
        result = {}

        def find_lengthscales(kernel: Any, path: str = "") -> None:
            current = f"{path}.{kernel.name}" if path else kernel.name
            if hasattr(kernel, "lengthscale"):
                result[current] = KernelInterpreter._extract_values(kernel.lengthscale)
            if hasattr(kernel, "parts") and kernel.parts:
                for part in kernel.parts:
                    find_lengthscales(part, current)

        find_lengthscales(model.kern)
        return result

    # Sentinel distinguishes "no lengthscale anywhere" from an attribute
    # that exists but is None.
    missing = object()

    def first_lengthscale(kernel: Any) -> Any:
        if hasattr(kernel, "lengthscale"):
            return kernel.lengthscale
        if hasattr(kernel, "parts") and kernel.parts:
            for part in kernel.parts:
                found = first_lengthscale(part)
                if found is not missing:
                    return found
        return missing

    # Return first found lengthscale, descending into composite kernels
    # whose root carries no lengthscale of its own.
    found = first_lengthscale(model.kern)
    if found is missing:
        raise KernelError("Model kernel has no lengthscale attribute")

    val = KernelInterpreter._extract_values(found)
    return val[0] if isinstance(val, list) else val
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def get_noise_variance(model: Any) -> float:
    """Return ``model.likelihood.variance`` as a float.

    Raises:
        KernelError: If ``model`` has no ``likelihood`` attribute, or the
            variance cannot be read or converted to a float.
    """
    likelihood_missing = not hasattr(model, "likelihood")
    if likelihood_missing:
        raise KernelError("Model must have 'likelihood' attribute")

    try:
        noise = float(model.likelihood.variance)
    except Exception as e:
        raise KernelError(f"Could not extract noise variance: {e}") from e
    return noise
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
# Private utilities
|
|
683
|
+
def _adapt_config_to_data(
|
|
684
|
+
config: InterpretationConfig,
|
|
685
|
+
X: np.ndarray,
|
|
686
|
+
) -> InterpretationConfig:
|
|
687
|
+
"""
|
|
688
|
+
Auto-scale thresholds based on data characteristics.
|
|
689
|
+
|
|
690
|
+
Args:
|
|
691
|
+
config: Base configuration
|
|
692
|
+
X: Training data (n_samples, n_dims)
|
|
693
|
+
|
|
694
|
+
Returns:
|
|
695
|
+
Adapted configuration
|
|
696
|
+
"""
|
|
697
|
+
if X.ndim != 2 or X.shape[0] < 2:
|
|
698
|
+
return config
|
|
699
|
+
|
|
700
|
+
# Compute data scale
|
|
701
|
+
ranges = np.ptp(X, axis=0) # Peak-to-peak (max - min)
|
|
702
|
+
median_range = float(np.median(ranges[ranges > 0]))
|
|
703
|
+
|
|
704
|
+
if median_range <= 0:
|
|
705
|
+
return config
|
|
706
|
+
|
|
707
|
+
# Scale thresholds proportionally to data range
|
|
708
|
+
scale_factor = median_range / 2.0 # Assuming standardized data ~2 range
|
|
709
|
+
|
|
710
|
+
new_ls = LengthscaleThresholds(
|
|
711
|
+
rapid_variation=config.lengthscale.rapid_variation * scale_factor,
|
|
712
|
+
smooth_trend=config.lengthscale.smooth_trend * scale_factor,
|
|
713
|
+
)
|
|
714
|
+
|
|
715
|
+
return InterpretationConfig(
|
|
716
|
+
lengthscale=new_ls,
|
|
717
|
+
variance=config.variance,
|
|
718
|
+
)
|