gpclarity-0.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpclarity/__init__.py +190 -0
- gpclarity/_version.py +3 -0
- gpclarity/data_influence.py +501 -0
- gpclarity/exceptions.py +46 -0
- gpclarity/hyperparam_tracker.py +718 -0
- gpclarity/kernel_summary.py +285 -0
- gpclarity/model_complexity.py +619 -0
- gpclarity/plotting.py +337 -0
- gpclarity/uncertainty_analysis.py +647 -0
- gpclarity/utils.py +411 -0
- gpclarity-0.0.2.dist-info/METADATA +248 -0
- gpclarity-0.0.2.dist-info/RECORD +14 -0
- gpclarity-0.0.2.dist-info/WHEEL +4 -0
- gpclarity-0.0.2.dist-info/licenses/LICENSE +37 -0
gpclarity/kernel_summary.py
@@ -0,0 +1,285 @@
"""
Kernel interpretation and summarization tools for GPy models.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import GPy
import numpy as np


@dataclass
class LengthscaleThresholds:
    """Configurable thresholds for lengthscale interpretation."""
    rapid_variation: float = 0.5
    smooth_trend: float = 2.0

    def validate(self):
        """Ensure thresholds are logically ordered."""
        if not 0 < self.rapid_variation < self.smooth_trend:
            raise ValueError(
                f"Thresholds must satisfy 0 < rapid_variation ({self.rapid_variation}) "
                f"< smooth_trend ({self.smooth_trend})"
            )
        return self


@dataclass
class VarianceThresholds:
    """Configurable thresholds for variance interpretation."""
    very_low: float = 0.01
    high: float = 10.0

    def validate(self):
        """Ensure thresholds are positive."""
        if self.very_low <= 0 or self.high <= 0:
            raise ValueError("Variance thresholds must be positive")
        if self.very_low >= self.high:
            raise ValueError(f"very_low ({self.very_low}) must be < high ({self.high})")
        return self


@dataclass
class InterpretationConfig:
    """Complete configuration for kernel interpretation."""
    lengthscale: Optional[LengthscaleThresholds] = None
    variance: Optional[VarianceThresholds] = None

    def __post_init__(self):
        if self.lengthscale is None:
            self.lengthscale = LengthscaleThresholds()
        if self.variance is None:
            self.variance = VarianceThresholds()
        self.lengthscale.validate()
        self.variance.validate()

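For illustration (this snippet is not part of the module), a custom configuration can be built from the dataclasses above; validation runs automatically in __post_init__:

# Illustrative usage sketch, not part of the package source.
custom_cfg = InterpretationConfig(
    lengthscale=LengthscaleThresholds(rapid_variation=0.2, smooth_trend=5.0),
    variance=VarianceThresholds(very_low=0.001, high=50.0),
)
# Out-of-order thresholds raise ValueError, e.g.:
# LengthscaleThresholds(rapid_variation=3.0, smooth_trend=1.0).validate()
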
def get_kernel_structure(kern: GPy.kern.Kern) -> Union[str, List[Any]]:
    """
    Recursively parse composite kernel trees into a nested structure.

    Args:
        kern: GPy kernel object

    Returns:
        Kernel name string, or a nested list for composite kernels
    """
    if hasattr(kern, "parts") and kern.parts:
        return [get_kernel_structure(k) for k in kern.parts]
    return kern.name

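As a quick sketch (not part of the module), a composite kernel such as RBF + White yields a nested list of the kernels' own names rather than a single string; the exact names ('rbf', 'white') come from GPy's defaults:

# Illustrative usage sketch, not part of the package source.
k = GPy.kern.RBF(input_dim=1) + GPy.kern.White(input_dim=1)
print(get_kernel_structure(k))                 # e.g. ['rbf', 'white']
print(get_kernel_structure(GPy.kern.RBF(1)))   # 'rbf'
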
def extract_kernel_params(kern: GPy.kern.Kern) -> Dict[str, Union[float, List[float]]]:
    """
    Extract all hyperparameters with proper handling of constraints and transformations.

    Args:
        kern: GPy kernel object

    Returns:
        Dictionary mapping parameter names to values (a list for vector parameters)
    """
    params = {}
    for param in kern.parameters:
        # Get raw value before transformation
        if hasattr(param, "values"):
            values = param.param_array
            if hasattr(values, "__iter__") and not isinstance(values, str):
                params[param.name] = [float(v) for v in values]
            else:
                params[param.name] = float(values)
    return params

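For example (illustrative only, not part of the module), a 2-D ARD RBF kernel exposes its variance and both lengthscales; note that scalar parameters may come back as one-element lists because GPy stores them as arrays:

# Illustrative usage sketch, not part of the package source.
k = GPy.kern.RBF(input_dim=2, ARD=True)
print(extract_kernel_params(k))
# e.g. {'variance': [1.0], 'lengthscale': [1.0, 1.0]}
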
def interpret_lengthscale(
    lengthscale: Union[float, np.ndarray],
    config: Optional[LengthscaleThresholds] = None
) -> str:
    """
    Interpret lengthscale magnitude with data-aware thresholds.

    Args:
        lengthscale: Single float or array of lengthscales
        config: Threshold configuration (uses defaults if None)

    Returns:
        Human-readable interpretation string
    """
    cfg = (config or LengthscaleThresholds()).validate()

    # Handle ARD lengthscales
    if isinstance(lengthscale, (list, np.ndarray)):
        ls_mean = np.mean(lengthscale)
        ls_range = f"range[{np.min(lengthscale):.2f}, {np.max(lengthscale):.2f}]"
    else:
        ls_mean = lengthscale
        ls_range = f"{ls_mean:.2f}"

    if ls_mean < cfg.rapid_variation:
        return f"Rapid variation (high frequency, {ls_range})"
    elif ls_mean > cfg.smooth_trend:
        return f"Smooth trends (low frequency, {ls_range})"
    else:
        return f"Moderate flexibility ({ls_range})"


def interpret_variance(
    variance: float,
    name: str = "Signal",
    config: Optional[VarianceThresholds] = None
) -> str:
    """
    Interpret variance magnitude with context-aware messaging.

    Args:
        variance: Variance value
        name: Type of variance ("Signal" or "Noise")
        config: Threshold configuration (uses defaults if None)

    Returns:
        Human-readable interpretation string
    """
    cfg = (config or VarianceThresholds()).validate()

    if variance < cfg.very_low:
        return f"Very low {name.lower()} (≈{variance:.3f})"
    elif variance > cfg.high:
        return f"High {name.lower()} (≈{variance:.1f})"
    return f"Moderate {name.lower()} (≈{variance:.2f})"

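A brief sketch of how the default thresholds map magnitudes to descriptions (illustrative only, not part of the module):

# Illustrative usage sketch, not part of the package source.
print(interpret_lengthscale(0.1))                    # "Rapid variation (high frequency, 0.10)"
print(interpret_lengthscale(np.array([0.8, 3.0])))   # mean is used: "Moderate flexibility (range[0.80, 3.00])"
print(interpret_variance(25.0, name="Noise"))        # "High noise (≈25.0)"
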
def summarize_kernel(
    model: GPy.models.GPRegression,
    X: Optional[np.ndarray] = None,
    verbose: bool = True,
    config: Optional[InterpretationConfig] = None
) -> Dict[str, Any]:
    """
    Generate a comprehensive, human-readable kernel interpretation.

    Args:
        model: Trained GPy model
        X: Training data (optional, for context-aware thresholds)
        verbose: Whether to print a formatted summary
        config: Interpretation configuration (uses defaults if None)

    Returns:
        Dictionary with the structured interpretation

    Raises:
        ValueError: If the model lacks required attributes
    """
    if not hasattr(model, "kern"):
        raise ValueError("Model must have a 'kern' attribute")

    cfg = config or InterpretationConfig()
    kernel = model.kern
    structure = get_kernel_structure(kernel)
    params = extract_kernel_params(kernel)

    # Build interpretation structure
    interpretation = {
        "kernel_structure": structure,
        "components": [],
        "composite": isinstance(structure, list),
        "config": {
            "lengthscale_thresholds": {
                "rapid_variation": cfg.lengthscale.rapid_variation,
                "smooth_trend": cfg.lengthscale.smooth_trend,
            },
            "variance_thresholds": {
                "very_low": cfg.variance.very_low,
                "high": cfg.variance.high,
            },
        },
    }

    # Parse each component recursively
    def parse_component(kern: GPy.kern.Kern, path: str = ""):
        if hasattr(kern, "parts") and kern.parts:
            for i, part in enumerate(kern.parts):
                parse_component(part, f"{path}.parts[{i}]" if path else f"parts[{i}]")
        else:
            comp = {"type": kern.name, "path": path, "params": {}, "interpretation": {}}

            if hasattr(kern, "lengthscale"):
                ls = kern.lengthscale
                comp["params"]["lengthscale"] = (
                    ls.tolist() if hasattr(ls, "__iter__") else float(ls)
                )
                comp["interpretation"]["smoothness"] = interpret_lengthscale(ls, cfg.lengthscale)

            if hasattr(kern, "variance"):
                var = float(kern.variance)
                comp["params"]["variance"] = var
                # Match noise-type kernels case-insensitively (GPy kernel names are lowercase)
                is_noise = "white" in kern.name.lower() or "noise" in kern.name.lower()
                comp["interpretation"]["strength"] = interpret_variance(
                    var, "Noise" if is_noise else "Signal", cfg.variance
                )

            if hasattr(kern, "periodicity"):
                period = float(kern.periodicity)
                comp["params"]["periodicity"] = period
                comp["interpretation"]["pattern"] = f"Periodic with period {period:.2f}"

            interpretation["components"].append(comp)

    parse_component(kernel)

    # Overall assessment
    if interpretation["composite"]:
        n_components = len(interpretation["components"])
        interpretation["overall"] = f"Composite kernel with {n_components} components"
    else:
        interpretation["overall"] = "Single kernel"

    # Print formatted summary if requested
    if verbose:
        _print_kernel_summary(interpretation)

    return interpretation

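A minimal end-to-end sketch (not part of the module), using synthetic data only to show the call; set verbose=True to print the formatted summary produced by _print_kernel_summary below:

# Illustrative usage sketch, not part of the package source.
X = np.linspace(0, 10, 50)[:, None]
y = np.sin(X) + 0.1 * np.random.randn(50, 1)
m = GPy.models.GPRegression(X, y, GPy.kern.RBF(1) + GPy.kern.White(1))
m.optimize()
summary = summarize_kernel(m, verbose=False)
print(summary["overall"])   # e.g. "Composite kernel with 2 components"
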
def _print_kernel_summary(interpretation: Dict[str, Any]):
    """Pretty-print the kernel summary."""
    print("\n KERNEL SUMMARY")
    print("=" * 50)
    print(f"Structure: {interpretation['kernel_structure']}")

    # Print config info
    cfg = interpretation.get("config", {})
    if cfg:
        print("\nThresholds:")
        ls_cfg = cfg.get("lengthscale_thresholds", {})
        print(f"  Lengthscale: rapid<{ls_cfg.get('rapid_variation', 0.5)}, "
              f"smooth>{ls_cfg.get('smooth_trend', 2.0)}")
        var_cfg = cfg.get("variance_thresholds", {})
        print(f"  Variance: very_low<{var_cfg.get('very_low', 0.01)}, "
              f"high>{var_cfg.get('high', 10.0)}")
        print()

    for comp in interpretation["components"]:
        print(f" {comp['type']} ({comp['path']})")
        for key, val in comp["params"].items():
            print(f"   └─ {key}: {val}")
        for val in comp["interpretation"].values():
            print(f"      {val}")
        print()


def format_kernel_tree(model: GPy.models.GPRegression) -> str:
    """
    Format the kernel tree structure as an indented string, using the original kernel names.
    """
    structure = get_kernel_structure(model.kern)

    def format_node(node, indent=0):
        if isinstance(node, list):
            return "\n".join(format_node(n, indent + 2) for n in node)
        # Use the name as it comes in (do not rewrite kernel names)
        return " " * indent + f"└─ {node}"

    return format_node(structure)