gpclarity 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,285 @@
1
+ """
2
+ Kernel interpretation and summarization tools for GPy models.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ import GPy
9
+ import numpy as np
10
+
11
+
12
@dataclass
class LengthscaleThresholds:
    """Lengthscale cut-offs used to label how smooth a kernel is.

    Lengthscales below ``rapid_variation`` read as rapid variation, those
    above ``smooth_trend`` as smooth trends, and everything in between as
    moderate flexibility.
    """

    rapid_variation: float = 0.5  # below this: "rapid variation"
    smooth_trend: float = 2.0  # above this: "smooth trends"

    def validate(self):
        """Check that ``0 < rapid_variation < smooth_trend``.

        Returns:
            ``self``, so construction and validation can be chained.

        Raises:
            ValueError: If the thresholds are not strictly increasing from 0.
        """
        low, high = self.rapid_variation, self.smooth_trend
        correctly_ordered = 0 < low < high
        if correctly_ordered:
            return self
        raise ValueError(
            f"Thresholds must satisfy 0 < rapid_variation ({low}) "
            f"< smooth_trend ({high})"
        )
26
+
27
+
28
@dataclass
class VarianceThresholds:
    """Variance cut-offs used to describe a signal or noise level.

    A variance below ``very_low`` is described as very low, one above
    ``high`` as high, and anything in between as moderate.
    """

    very_low: float = 0.01  # exclusive upper bound of the "very low" band
    high: float = 10.0  # exclusive lower bound of the "high" band

    def validate(self):
        """Check both thresholds are positive and ``very_low < high``.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            ValueError: If either threshold is <= 0, or if very_low >= high.
        """
        lower, upper = self.very_low, self.high
        if any(bound <= 0 for bound in (lower, upper)):
            raise ValueError("Variance thresholds must be positive")
        if lower >= upper:
            raise ValueError(f"very_low ({lower}) must be < high ({upper})")
        return self
41
+
42
+
43
@dataclass
class InterpretationConfig:
    """Complete configuration for kernel interpretation.

    Sections left as ``None`` are replaced with default thresholds in
    ``__post_init__``. Both sections are validated on construction, so an
    invalid configuration fails immediately rather than at interpretation
    time.
    """

    # Lengthscale cut-offs; None means "use LengthscaleThresholds()".
    lengthscale: Optional[LengthscaleThresholds] = None
    # Variance cut-offs; None means "use VarianceThresholds()".
    variance: Optional[VarianceThresholds] = None

    def __post_init__(self):
        """Fill in default thresholds and validate both sections."""
        if self.lengthscale is None:
            self.lengthscale = LengthscaleThresholds()
        if self.variance is None:
            self.variance = VarianceThresholds()
        self.lengthscale.validate()
        self.variance.validate()
56
+
57
+
58
def get_kernel_structure(kern: GPy.kern.Kern) -> Union[str, List[Any]]:
    """
    Describe a (possibly composite) kernel as a nested list of names.

    A kernel with a non-empty ``parts`` attribute is expanded into a list
    holding the structure of each part, in order. Any other kernel is
    represented by its ``name``.

    Args:
        kern: GPy kernel object

    Returns:
        The kernel's name, or a (nested) list of names for composite kernels
    """
    children = getattr(kern, "parts", None)
    if not children:
        return kern.name
    return [get_kernel_structure(child) for child in children]
71
+
72
+
73
def extract_kernel_params(kern: GPy.kern.Kern) -> Dict[str, Union[float, List[float]]]:
    """
    Extract the current values of a kernel's direct parameters.

    Each entry of ``kern.parameters`` that exposes a ``values`` attribute is
    reported under its ``name``. The entry's ``param_array`` is converted
    to plain Python floats: a 0-d value becomes a single ``float``, and
    anything else is flattened (row-major) into a ``list`` of floats.

    Args:
        kern: GPy kernel object

    Returns:
        Dictionary mapping parameter names to a float or a list of floats.
        If two entries share a name, the later one wins.
    """
    params: Dict[str, Union[float, List[float]]] = {}
    for param in kern.parameters:
        if not hasattr(param, "values"):
            continue
        # NOTE(review): in paramz, ``param_array`` holds the current
        # (constrained, model-space) values rather than the optimizer's
        # untransformed ones -- confirm against the installed GPy/paramz.
        values = np.asarray(param.param_array, dtype=float)
        if values.ndim == 0:
            params[param.name] = float(values)
        else:
            # ravel() lets multi-dimensional parameters flatten instead of
            # failing in float(row); 1-D arrays give the same list as before.
            params[param.name] = values.ravel().tolist()
    return params
93
+
94
+
95
def interpret_lengthscale(
    lengthscale: Union[float, np.ndarray],
    config: Optional[LengthscaleThresholds] = None
) -> str:
    """
    Interpret lengthscale magnitude against fixed thresholds.

    For sequences (list, tuple or ndarray, e.g. ARD lengthscales) the mean
    is compared against the thresholds, and the min/max range is reported.

    Args:
        lengthscale: Single number, or a non-empty sequence/array of lengthscales
        config: Threshold configuration (uses defaults if None)

    Returns:
        Human-readable interpretation string

    Raises:
        ValueError: If the thresholds are invalid or ``lengthscale`` is an
            empty sequence.
    """
    cfg = (config or LengthscaleThresholds()).validate()

    # Handle ARD / multi-valued lengthscales (tuples included).
    if isinstance(lengthscale, (list, tuple, np.ndarray)):
        values = np.asarray(lengthscale, dtype=float)
        if values.size == 0:
            raise ValueError("lengthscale must contain at least one value")
        ls_mean = float(np.mean(values))
        ls_range = f"range[{values.min():.2f}, {values.max():.2f}]"
    else:
        ls_mean = lengthscale
        ls_range = f"{ls_mean:.2f}"

    if ls_mean < cfg.rapid_variation:
        return f"Rapid variation (high frequency, {ls_range})"
    if ls_mean > cfg.smooth_trend:
        return f"Smooth trends (low frequency, {ls_range})"
    return f"Moderate flexibility ({ls_range})"
125
+
126
+
127
def interpret_variance(
    variance: float,
    name: str = "Signal",
    config: Optional[VarianceThresholds] = None
) -> str:
    """
    Describe a variance value as very low, moderate or high.

    Args:
        variance: Variance value to classify
        name: Label used in the message, lower-cased (e.g. "Signal", "Noise")
        config: Thresholds to compare against (defaults are used if None)

    Returns:
        A string such as "Moderate signal (≈1.00)". The value is shown with
        3, 2 or 1 decimals for the very-low, moderate and high bands
        respectively.
    """
    thresholds = config if config else VarianceThresholds()
    thresholds = thresholds.validate()

    if variance < thresholds.very_low:
        band, shown = "Very low", f"{variance:.3f}"
    elif variance > thresholds.high:
        band, shown = "High", f"{variance:.1f}"
    else:
        band, shown = "Moderate", f"{variance:.2f}"
    return f"{band} {name.lower()} (≈{shown})"
150
+
151
+
152
def summarize_kernel(
    model: GPy.models.GPRegression,
    X: Optional[np.ndarray] = None,
    verbose: bool = True,
    config: Optional[InterpretationConfig] = None
) -> Dict[str, Any]:
    """
    Generate comprehensive human-readable kernel interpretation.

    Walks ``model.kern``. Every leaf kernel (one without a non-empty
    ``parts`` attribute) becomes an entry in ``components`` that records its
    lengthscale, variance and periodicity where present, together with a
    textual interpretation of each.

    Args:
        model: Trained GPy model (anything exposing a ``kern`` attribute)
        X: Training data. Currently unused and kept for API compatibility;
            the thresholds come from ``config``, not from the data.
        verbose: Whether to print formatted summary
        config: Interpretation configuration (uses defaults if None)

    Returns:
        Dictionary with keys ``kernel_structure``, ``components``,
        ``composite``, ``config`` and ``overall``.

    Raises:
        ValueError: If model lacks a ``kern`` attribute
    """
    if not hasattr(model, "kern"):
        raise ValueError("Model must have a 'kern' attribute")

    cfg = config or InterpretationConfig()
    kernel = model.kern
    structure = get_kernel_structure(kernel)

    # Build interpretation structure
    interpretation: Dict[str, Any] = {
        "kernel_structure": structure,
        "components": [],
        "composite": isinstance(structure, list),
        "config": {
            "lengthscale_thresholds": {
                "rapid_variation": cfg.lengthscale.rapid_variation,
                "smooth_trend": cfg.lengthscale.smooth_trend,
            },
            "variance_thresholds": {
                "very_low": cfg.variance.very_low,
                "high": cfg.variance.high,
            },
        },
    }

    def as_float(value: Any) -> float:
        """Convert a scalar or single-element array (e.g. a GPy Param) to float."""
        arr = np.asarray(value, dtype=float)
        # Indexing a size-1 array avoids NumPy's deprecated float() on
        # ndim>0 arrays; other sizes still raise TypeError as float() did.
        return float(arr.reshape(-1)[0]) if arr.size == 1 else float(arr)

    # Parse each component recursively
    def parse_component(kern: GPy.kern.Kern, path: str = "") -> None:
        if hasattr(kern, "parts") and kern.parts:
            for i, part in enumerate(kern.parts):
                parse_component(part, f"{path}.parts[{i}]" if path else f"parts[{i}]")
            return

        comp = {"type": kern.name, "path": path, "params": {}, "interpretation": {}}

        if hasattr(kern, "lengthscale"):
            ls = kern.lengthscale
            comp["params"]["lengthscale"] = (
                ls.tolist() if hasattr(ls, "__iter__") else float(ls)
            )
            comp["interpretation"]["smoothness"] = interpret_lengthscale(ls, cfg.lengthscale)

        if hasattr(kern, "variance"):
            var = as_float(kern.variance)
            comp["params"]["variance"] = var
            # GPy kernel names are lowercase (e.g. "white"), so match
            # case-insensitively; the old "White" check never matched them.
            kind = kern.name.lower()
            is_noise = "white" in kind or "noise" in kind
            comp["interpretation"]["strength"] = interpret_variance(
                var, "Noise" if is_noise else "Signal", cfg.variance
            )

        # NOTE(review): GPy's StdPeriodic exposes ``period`` rather than
        # ``periodicity``; confirm which kernels this branch targets.
        if hasattr(kern, "periodicity"):
            period = as_float(kern.periodicity)
            comp["params"]["periodicity"] = period
            # Format the scalar: ``f"{ndarray:.2f}"`` raises TypeError for
            # shape-(1,) Params.
            comp["interpretation"]["pattern"] = f"Periodic with period {period:.2f}"

        interpretation["components"].append(comp)

    parse_component(kernel)

    # Overall assessment
    if interpretation["composite"]:
        n_components = len(interpretation["components"])
        interpretation["overall"] = f"Composite kernel with {n_components} components"
    else:
        interpretation["overall"] = "Single kernel"

    # Print formatted summary if requested
    if verbose:
        _print_kernel_summary(interpretation)

    return interpretation
243
+
244
+
245
+ def _print_kernel_summary(interpretation: Dict[str, Any]):
246
+ """Pretty print kernel summary."""
247
+ print("\n KERNEL SUMMARY")
248
+ print("=" * 50)
249
+ print(f"Structure: {interpretation['kernel_structure']}")
250
+
251
+ # Print config info
252
+ cfg = interpretation.get("config", {})
253
+ if cfg:
254
+ print(f"\nThresholds:")
255
+ ls_cfg = cfg.get("lengthscale_thresholds", {})
256
+ print(f" Lengthscale: rapid<{ls_cfg.get('rapid_variation', 0.5)}, "
257
+ f"smooth>{ls_cfg.get('smooth_trend', 2.0)}")
258
+ var_cfg = cfg.get("variance_thresholds", {})
259
+ print(f" Variance: very_low<{var_cfg.get('very_low', 0.01)}, "
260
+ f"high>{var_cfg.get('high', 10.0)}")
261
+ print()
262
+
263
+ for comp in interpretation["components"]:
264
+ print(f" {comp['type']} ({comp['path']})")
265
+ for key, val in comp["params"].items():
266
+ print(f" └─ {key}: {val}")
267
+ for key, val in comp["interpretation"].items():
268
+ print(f" {val}")
269
+ print()
270
+
271
+
272
def format_kernel_tree(model: GPy.models.GPRegression) -> str:
    """
    Render the kernel tree of ``model`` as an indented text outline.

    Each leaf kernel appears as a ``└─ <name>`` line, using the name exactly
    as reported. Each level of list nesting adds two spaces of indentation.
    Composite nodes themselves get no line of their own.
    """
    def render(node, depth=0):
        if not isinstance(node, list):
            # Use the name as it comes in (do not update names).
            return f"{' ' * depth}└─ {node}"
        return "\n".join(render(child, depth + 2) for child in node)

    return render(get_kernel_structure(model.kern))