gpclarity 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,718 @@
1
+ """
2
+ Kernel interpretation and summarization tools for GPy models.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ import logging
9
+ from dataclasses import asdict, dataclass, field
10
+ from enum import Enum, auto
11
+ from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Union
12
+
13
+ import numpy as np
14
+
15
+ from gpclarity.exceptions import KernelError
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class SmoothnessCategory(Enum):
    """Categorization of lengthscale interpretations."""
    RAPID_VARIATION = auto()
    MODERATE = auto()
    SMOOTH_TREND = auto()

    def describe(self) -> str:
        """Return the human-readable label for this category."""
        if self is SmoothnessCategory.RAPID_VARIATION:
            return "Rapid variation (high frequency)"
        if self is SmoothnessCategory.SMOOTH_TREND:
            return "Smooth trends (low frequency)"
        return "Moderate flexibility"
33
+
34
+
35
class VarianceCategory(Enum):
    """Categorization of variance interpretations."""
    VERY_LOW = auto()
    MODERATE = auto()
    HIGH = auto()

    def describe(self, name: str = "signal") -> str:
        """Render e.g. ``VERY_LOW`` with name "Signal" as "Very Low signal"."""
        label = self.name.replace('_', ' ').title()
        return f"{label} {name.lower()}"
43
+
44
+
45
@dataclass(frozen=True)
class LengthscaleThresholds:
    """Configurable thresholds for lengthscale interpretation."""
    rapid_variation: float = 0.5  # below this -> rapid variation
    smooth_trend: float = 2.0     # above this -> smooth trend

    def __post_init__(self):
        # Reject non-positive or inverted thresholds up front.
        ordered = 0 < self.rapid_variation < self.smooth_trend
        if not ordered:
            raise ValueError(
                f"Thresholds must satisfy 0 < rapid_variation ({self.rapid_variation}) "
                f"< smooth_trend ({self.smooth_trend})"
            )

    def categorize(self, lengthscale: float) -> SmoothnessCategory:
        """Categorize a lengthscale value."""
        if lengthscale < self.rapid_variation:
            return SmoothnessCategory.RAPID_VARIATION
        if lengthscale > self.smooth_trend:
            return SmoothnessCategory.SMOOTH_TREND
        return SmoothnessCategory.MODERATE
65
+
66
+
67
@dataclass(frozen=True)
class VarianceThresholds:
    """Configurable thresholds for variance interpretation."""
    very_low: float = 0.01  # below this -> very low
    high: float = 10.0      # above this -> high

    def __post_init__(self):
        # Both bounds must be strictly positive and correctly ordered.
        if min(self.very_low, self.high) <= 0:
            raise ValueError("Variance thresholds must be positive")
        if self.very_low >= self.high:
            raise ValueError(
                f"very_low ({self.very_low}) must be < high ({self.high})"
            )

    def categorize(self, variance: float) -> VarianceCategory:
        """Categorize a variance value."""
        if variance < self.very_low:
            return VarianceCategory.VERY_LOW
        if variance > self.high:
            return VarianceCategory.HIGH
        return VarianceCategory.MODERATE
88
+
89
+
90
@dataclass
class InterpretationConfig:
    """Complete configuration for kernel interpretation."""
    lengthscale: LengthscaleThresholds = field(
        default_factory=LengthscaleThresholds
    )
    variance: VarianceThresholds = field(default_factory=VarianceThresholds)

    def __post_init__(self):
        # Coerce plain dicts (e.g. deserialized config) into the proper
        # threshold dataclasses so downstream code can rely on the types.
        coercions = (
            ("lengthscale", LengthscaleThresholds),
            ("variance", VarianceThresholds),
        )
        for attr, threshold_cls in coercions:
            value = getattr(self, attr)
            if isinstance(value, dict):
                setattr(self, attr, threshold_cls(**value))
104
+
105
+
106
@dataclass
class KernelComponent:
    """Represents a single kernel component with interpretation."""
    name: str
    kernel_type: str
    path: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    interpretations: Dict[str, str] = field(default_factory=dict)
    is_composite: bool = False
    children: List["KernelComponent"] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation.

        The "children" key is only present for composite components.
        """
        payload: Dict[str, Any] = {
            "name": self.name,
            "type": self.kernel_type,
            "parameters": self.parameters,
            "interpretations": self.interpretations,
        }
        if self.children:
            payload["children"] = [child.to_dict() for child in self.children]
        return payload
128
+
129
+
130
class KernelVisitor(Protocol):
    """Protocol for kernel tree visitors.

    Structural (duck-typed) interface: any object providing a compatible
    ``visit(kernel, path)`` method conforms; no inheritance is required.
    ``visit`` takes a kernel node plus its hierarchical path string and
    returns an Optional[KernelComponent].
    """
    def visit(self, kernel: Any, path: str = "") -> Optional[KernelComponent]:
        ...
134
+
135
+
136
class KernelInterpreter:
    """
    Interpretable kernel analysis with pluggable strategies.

    This class provides the core interpretation logic, separated from
    the high-level API functions. Kernel-type-specific handlers are
    registered via :meth:`register_kernel` (keyed by the kernel's class
    name); kernels without a registered handler fall back to generic
    attribute probing in :meth:`_interpret_generic`.
    """

    # Registry of kernel-specific interpreters.
    # NOTE: class-level and therefore shared by all instances (and any
    # subclasses) — registrations are process-global.
    _interpreters: Dict[str, Callable[[Any, InterpretationConfig], Dict[str, Any]]] = {}

    def __init__(self, config: Optional[InterpretationConfig] = None):
        # Fall back to default thresholds when no config is supplied.
        self.config = config or InterpretationConfig()

    @classmethod
    def register_kernel(cls, kernel_type: str):
        """Decorator to register interpreter for specific kernel type.

        Args:
            kernel_type: Kernel class name, matched against
                ``type(kernel).__name__`` in :meth:`interpret`.

        Returns:
            A decorator that records the function in the registry and
            returns it unchanged, so decorators can be stacked to register
            one function under several kernel types.
        """
        def decorator(func: Callable[[Any, InterpretationConfig], Dict[str, Any]]):
            cls._interpreters[kernel_type] = func
            return func
        return decorator

    def interpret(self, kernel: Any, path: str = "") -> KernelComponent:
        """
        Interpret a kernel or kernel component.

        Args:
            kernel: GPy kernel object
            path: Hierarchical path string

        Returns:
            KernelComponent with interpretations
        """
        kernel_name = getattr(kernel, 'name', 'unknown')
        kernel_type = type(kernel).__name__

        # Check for registered handler; otherwise fall back to generic
        # attribute probing.
        if kernel_type in self._interpreters:
            params = self._interpreters[kernel_type](kernel, self.config)
        else:
            params = self._interpret_generic(kernel)

        # Build component from the handler's result dict (missing keys
        # default to empty/False).
        component = KernelComponent(
            name=kernel_name,
            kernel_type=kernel_type,
            path=path,
            parameters=params.get("parameters", {}),
            interpretations=params.get("interpretations", {}),
            is_composite=params.get("is_composite", False),
        )

        # Handle composite kernels recursively; composite GPy kernels
        # expose their sub-kernels via `.parts`.
        if component.is_composite and hasattr(kernel, "parts"):
            for i, part in enumerate(kernel.parts):
                child_path = f"{path}.parts[{i}]" if path else f"parts[{i}]"
                child = self.interpret(part, child_path)
                component.children.append(child)

        return component

    def _interpret_generic(self, kernel: Any) -> Dict[str, Any]:
        """Generic interpretation fallback.

        Probes the kernel for common attributes (lengthscale, variance,
        periodicity, parts) and builds parameter/interpretation dicts.
        """
        result = {"parameters": {}, "interpretations": {}, "is_composite": False}

        # Extract common parameters
        if hasattr(kernel, "lengthscale"):
            ls = kernel.lengthscale
            ls_values = self._extract_values(ls)
            result["parameters"]["lengthscale"] = ls_values

            # Interpret ARD lengthscales: categorize by the mean, report
            # the min/max range. (_extract_values returns a list only when
            # there is more than one element.)
            if isinstance(ls_values, (list, np.ndarray)) and len(ls_values) > 1:
                mean_ls = float(np.mean(ls_values))
                range_str = f"[{np.min(ls_values):.2f}, {np.max(ls_values):.2f}]"
                category = self.config.lengthscale.categorize(mean_ls)
                result["interpretations"]["smoothness"] = (
                    f"{category.describe()} (ARD, range: {range_str})"
                )
            else:
                ls_float = float(ls_values[0]) if isinstance(ls_values, list) else float(ls_values)
                category = self.config.lengthscale.categorize(ls_float)
                result["interpretations"]["smoothness"] = category.describe()

        if hasattr(kernel, "variance"):
            # NOTE(review): float() assumes a scalar variance; a
            # multi-element variance (e.g. Linear ARD `variances`) would
            # raise here — Linear is handled by its registered interpreter.
            var = float(kernel.variance)
            result["parameters"]["variance"] = var
            category = self.config.variance.categorize(var)
            name = "Noise" if self._is_noise_kernel(kernel) else "Signal"
            result["interpretations"]["strength"] = category.describe(name)

        if hasattr(kernel, "periodicity"):
            per = float(kernel.periodicity)
            result["parameters"]["periodicity"] = per
            result["interpretations"]["pattern"] = f"Periodic (period={per:.2f})"

        # Check if composite (non-empty `.parts` container).
        if hasattr(kernel, "parts") and kernel.parts:
            result["is_composite"] = True

        return result

    @staticmethod
    def _extract_values(param: Any) -> Union[float, List[float], np.ndarray]:
        """Safely extract values from GPy parameter.

        Returns a plain float for single-element parameters, otherwise a
        list of floats. (The ndarray option in the signature is never
        actually produced: np.atleast_1d always yields an ndarray, so the
        tolist() branch is taken for multi-element values.)
        """
        if param is None:
            return 0.0

        # Prefer GPy's `.values` view, then the raw `param_array`, then
        # the object itself (plain floats / lists / arrays).
        if hasattr(param, "values"):
            val = param.values
        elif hasattr(param, "param_array"):
            val = param.param_array
        else:
            val = param

        arr = np.atleast_1d(val)
        if len(arr) == 1:
            return float(arr[0])
        return arr.tolist() if isinstance(arr, np.ndarray) else list(arr)

    @staticmethod
    def _is_noise_kernel(kernel: Any) -> bool:
        """Determine if kernel represents noise.

        Heuristic: substring match on the lowercased kernel name; 'bias'
        is deliberately treated as noise-like here as well.
        """
        name = getattr(kernel, 'name', '').lower()
        return any(n in name for n in ['white', 'noise', 'bias'])
261
+
262
+
263
class KernelSummaryFormatter:
    """Formats kernel summaries for different output formats.

    Renders a KernelComponent tree (as produced by
    ``KernelInterpreter.interpret``) as plain text, Markdown, or JSON.
    """

    def __init__(
        self,
        component: KernelComponent,
        config: Optional[InterpretationConfig] = None,
    ):
        """
        Args:
            component: Root of the interpreted kernel tree.
            config: Optional interpretation configuration; when supplied,
                the text header reports the thresholds actually used.
                Optional for backward compatibility with the one-argument
                ``KernelSummaryFormatter(root)`` call.
        """
        self.root = component
        self.config = config

    def to_text(self, verbose: bool = True) -> str:
        """Generate human-readable text summary.

        Args:
            verbose: Accepted for API compatibility; currently unused.

        Returns:
            Multi-line string with header, configuration, tree structure,
            and per-component details.
        """
        lines = [
            "\n╔" + "═" * 58 + "╗",
            "║" + " KERNEL SUMMARY".center(58) + "║",
            "╚" + "═" * 58 + "╝\n",
        ]

        # Configuration section.
        # BUG FIX: the original read threshold values out of
        # root.interpretations under keys the interpreter never writes
        # ('lengthscale_rapid', etc.), so it always displayed the
        # hard-coded defaults. Prefer the real config when supplied and
        # keep the old fallback values otherwise.
        if self.config is not None:
            ls_rapid = self.config.lengthscale.rapid_variation
            ls_smooth = self.config.lengthscale.smooth_trend
            var_low = self.config.variance.very_low
            var_high = self.config.variance.high
        else:
            ls_rapid = self.root.interpretations.get('lengthscale_rapid', 0.5)
            ls_smooth = self.root.interpretations.get('lengthscale_smooth', 2.0)
            var_low = self.root.interpretations.get('variance_low', 0.01)
            var_high = self.root.interpretations.get('variance_high', 10.0)

        lines.append("Configuration:")
        lines.append(f"  Lengthscale: rapid<{ls_rapid}, smooth>{ls_smooth}")
        lines.append(f"  Variance: very_low<{var_low}, high>{var_high}\n")

        # Tree structure
        lines.append("Structure:")
        lines.append(self._format_tree(self.root))
        lines.append("")

        # Detailed components
        lines.append("Components:")
        lines.append(self._format_components(self.root))

        return "\n".join(lines)

    def _format_tree(self, component: KernelComponent, prefix: str = "", is_last: bool = True) -> str:
        """Format tree structure with box-drawing characters."""
        connector = "└── " if is_last else "├── "
        line = prefix + connector + component.kernel_type
        if component.name != component.kernel_type:
            line += f" ({component.name})"

        lines = [line]

        if component.children:
            # Children under a last node hang off blank space; otherwise a
            # vertical bar continues the parent branch.
            new_prefix = prefix + ("    " if is_last else "│   ")
            for i, child in enumerate(component.children):
                is_last_child = (i == len(component.children) - 1)
                lines.append(self._format_tree(child, new_prefix, is_last_child))

        return "\n".join(lines)

    def _format_components(self, component: KernelComponent, depth: int = 0) -> str:
        """Format per-component detail blocks, depth-first."""
        lines = []
        indent = "  " * depth

        # Leaves are always printed; a composite only gets its own block
        # at the root level (its parts carry the interesting detail).
        if not component.is_composite or depth == 0:
            lines.append(f"{indent}【{component.kernel_type}】 {component.path}")
            for key, val in component.parameters.items():
                lines.append(f"{indent}  ├─ {key}: {val}")
            for key, val in component.interpretations.items():
                lines.append(f"{indent}  └─ {val}")
            lines.append("")

        for child in component.children:
            lines.append(self._format_components(child, depth + 1))

        return "\n".join(lines)

    def to_markdown(self) -> str:
        """Generate Markdown formatted summary."""
        lines = ["# Kernel Summary\n"]

        def add_component(comp: KernelComponent, level: int = 2):
            # One heading per component; children nest one level deeper.
            header = "#" * level
            lines.append(f"{header} {comp.kernel_type}\n")

            if comp.parameters:
                lines.append("| Parameter | Value |")
                lines.append("|-----------|-------|")
                for k, v in comp.parameters.items():
                    lines.append(f"| {k} | {v} |")
                lines.append("")

            if comp.interpretations:
                lines.append("**Interpretations:**")
                for k, v in comp.interpretations.items():
                    lines.append(f"- **{k}**: {v}")
                lines.append("")

            for child in comp.children:
                add_component(child, level + 1)

        add_component(self.root)
        return "\n".join(lines)

    def to_json(self, indent: int = 2) -> str:
        """Generate JSON representation of the component tree."""
        return json.dumps(self.root.to_dict(), indent=indent)
360
+
361
+
362
+ # Register specific kernel interpreters
363
@KernelInterpreter.register_kernel("RBF")
def _interpret_rbf(kernel: Any, config: InterpretationConfig) -> Dict[str, Any]:
    """Specialized RBF kernel interpretation.

    Args:
        kernel: GPy RBF kernel instance.
        config: Threshold configuration.

    Returns:
        Dict with "parameters", "interpretations" and "is_composite" keys.
    """
    result: Dict[str, Any] = {"parameters": {}, "interpretations": {}, "is_composite": False}

    # RBF-specific: lengthscale is crucial.
    # BUG FIX: float(kernel.lengthscale) raises TypeError for ARD kernels
    # (multi-element lengthscale); extract the values safely and
    # categorize by the mean, as the generic interpreter does.
    ls_values = KernelInterpreter._extract_values(kernel.lengthscale)
    result["parameters"]["lengthscale"] = ls_values
    ls = float(np.mean(np.atleast_1d(ls_values)))

    category = config.lengthscale.categorize(ls)
    if category == SmoothnessCategory.RAPID_VARIATION:
        advice = "Model will fit noise - consider increasing"
    elif category == SmoothnessCategory.SMOOTH_TREND:
        advice = "Model may underfit - consider decreasing"
    else:
        advice = "Well-balanced flexibility"

    result["interpretations"]["smoothness"] = f"{category.describe()}. {advice}"

    # Variance (scalar for RBF).
    var = float(kernel.variance)
    result["parameters"]["variance"] = var
    var_cat = config.variance.categorize(var)
    result["interpretations"]["signal_strength"] = var_cat.describe("Signal")

    return result
389
+
390
+
391
@KernelInterpreter.register_kernel("Linear")
def _interpret_linear(kernel: Any, config: InterpretationConfig) -> Dict[str, Any]:
    """Specialized Linear kernel interpretation."""
    result = {"parameters": {}, "interpretations": {}, "is_composite": False}

    if hasattr(kernel, "variances"):
        raw = KernelInterpreter._extract_values(kernel.variances)
        result["parameters"]["ARD_variances"] = raw
        # Count dimensions whose variance exceeds the relevance cutoff.
        var_arr = np.atleast_1d(raw)
        active_dims = int((var_arr > 0.1).sum())
        result["interpretations"]["relevance"] = (
            f"Linear trend in {active_dims}/{len(var_arr)} dimensions"
        )

    return result
405
+
406
+
407
@KernelInterpreter.register_kernel("PeriodicExponential")
@KernelInterpreter.register_kernel("PeriodicMatern32")
@KernelInterpreter.register_kernel("PeriodicMatern52")
def _interpret_periodic(kernel: Any, config: InterpretationConfig) -> Dict[str, Any]:
    """Specialized periodic kernel interpretation."""
    period = float(kernel.periodicity)
    decay = float(kernel.lengthscale)

    # Decay lengthscales well beyond the period imply correlations that
    # persist across many cycles.
    stability = (
        "Long-range periodic correlations"
        if decay > period * 2
        else "Local periodic patterns only"
    )

    return {
        "parameters": {"period": period, "decay_lengthscale": decay},
        "interpretations": {
            "pattern": f"Repeating pattern every {period:.2f} units",
            "stability": stability,
        },
        "is_composite": False,
    }
426
+
427
+
428
+ # High-level API functions
429
def summarize_kernel(
    model: Any,
    X: Optional[np.ndarray] = None,
    verbose: bool = True,
    config: Optional[InterpretationConfig] = None,
    format: str = "text",
) -> Union[str, Dict[str, Any]]:
    """
    Generate comprehensive kernel interpretation.

    Args:
        model: GPy model with 'kern' attribute
        X: Training data (optional, for context-aware scaling)
        verbose: Print summary if True
        config: Interpretation configuration
        format: Output format ('text', 'markdown', 'json', 'dict')

    Returns:
        Formatted string or dictionary depending on format

    Raises:
        KernelError: If model invalid or kernel uninterpretable
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    # Auto-adjust the thresholds to the data's scale when X is supplied.
    cfg = config or InterpretationConfig()
    if X is not None:
        cfg = _adapt_config_to_data(cfg, X)

    # Build the interpretation tree.
    try:
        root = KernelInterpreter(cfg).interpret(model.kern)
    except Exception as e:
        raise KernelError(f"Failed to interpret kernel: {e}") from e

    formatter = KernelSummaryFormatter(root)

    # 'dict' bypasses the formatter entirely.
    if format == "dict":
        return root.to_dict()

    # Dispatch on format; anything unrecognized falls back to text.
    renderers = {
        "json": formatter.to_json,
        "markdown": formatter.to_markdown,
    }
    result = renderers.get(format, formatter.to_text)()

    if verbose and format in ("text", "markdown"):
        print(result)

    return result
483
+
484
+
485
def interpret_lengthscale(
    lengthscale: Union[float, np.ndarray, List[float]],
    config: Optional[LengthscaleThresholds] = None,
    return_category: bool = False,
) -> Union[str, tuple[str, SmoothnessCategory]]:
    """
    Interpret lengthscale magnitude with data-aware thresholds.

    Args:
        lengthscale: Single value or array of lengthscales
        config: Threshold configuration
        return_category: If True, return (description, category) tuple

    Returns:
        Interpretation string, or tuple if return_category=True
    """
    # BUG FIX: the return annotation used typing.Tuple, which this module
    # never imports; PEP 585's builtin `tuple` works here because the file
    # uses `from __future__ import annotations`.
    cfg = config or LengthscaleThresholds()

    # Normalize input. BUG FIX: a length-1 list/array previously rendered
    # as "[x.xx, x.xx]"; it now formats like the scalar case.
    if isinstance(lengthscale, (list, np.ndarray)):
        arr = np.atleast_1d(lengthscale)
        mean_ls = float(np.mean(arr))
        is_ard = len(arr) > 1
        if is_ard:
            range_str = f"[{np.min(arr):.2f}, {np.max(arr):.2f}]"
        else:
            range_str = f"{mean_ls:.2f}"
    else:
        mean_ls = float(lengthscale)
        range_str = f"{mean_ls:.2f}"
        is_ard = False

    category = cfg.categorize(mean_ls)
    description = category.describe()

    if is_ard:
        description += f" (ARD, range: {range_str})"
    else:
        description += f" ({range_str})"

    if return_category:
        return description, category
    return description
525
+
526
+
527
def interpret_variance(
    variance: float,
    name: str = "Signal",
    config: Optional[VarianceThresholds] = None,
    return_category: bool = False,
) -> Union[str, tuple[str, VarianceCategory]]:
    """
    Interpret variance magnitude with context-aware messaging.

    Args:
        variance: Variance value
        name: Type of variance ("Signal" or "Noise")
        config: Threshold configuration
        return_category: If True, return (description, category) tuple

    Returns:
        Interpretation string, or tuple if return_category=True
    """
    # BUG FIX: the return annotation used typing.Tuple, which this module
    # never imports; the builtin `tuple` (PEP 585) is valid under the
    # file's `from __future__ import annotations`.
    cfg = config or VarianceThresholds()
    category = cfg.categorize(variance)
    description = category.describe(name) + f" (≈{variance:.3f})"

    if return_category:
        return description, category
    return description
552
+
553
+
554
def format_kernel_tree(model: Any, style: str = "unicode") -> str:
    """
    Pretty-print kernel tree structure.

    Args:
        model: GPy model
        style: Output style ('unicode', 'ascii', 'minimal'). Only
            'minimal' changes the output today; 'unicode' and 'ascii'
            both render the box-drawing tree.

    Returns:
        Formatted tree string

    Raises:
        KernelError: If model has no 'kern' attribute
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    root = KernelInterpreter().interpret(model.kern)

    if style == "minimal":
        return root.kernel_type

    # BUG FIX: render the tree directly instead of generating the full
    # text summary and re-parsing it line-by-line to recover the
    # "Structure:" section (which silently fell back to just the root
    # type whenever the section markers were not found).
    return KernelSummaryFormatter(root)._format_tree(root)
588
+
589
+
590
def count_kernel_components(model: Any) -> int:
    """Count total number of kernel components (leaf nodes)."""
    if not hasattr(model, "kern"):
        return 0

    def _leaves(node: Any) -> int:
        # Composite kernels expose a non-empty `parts`; everything else
        # counts as a single leaf.
        parts = getattr(node, "parts", None)
        if parts:
            return sum(_leaves(part) for part in parts)
        return 1

    return _leaves(model.kern)
601
+
602
+
603
def extract_kernel_params_flat(model: Any) -> Dict[str, float]:
    """
    Extract all kernel parameters as flat dictionary with dotted paths.

    Args:
        model: GPy model with a 'kern' attribute

    Returns:
        Flat dictionary mapping "path.param" (with "[i]" suffixes for
        array-valued parameters) to float values

    Raises:
        KernelError: If model has no 'kern' attribute

    Note:
        Paths are built from kernel names only, so sibling parts that
        share a name would produce colliding keys (later siblings
        overwrite earlier ones). GPy appears to disambiguate duplicate
        part names itself — TODO confirm for hand-built kernels.
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    params: Dict[str, float] = {}

    def extract(kernel: Any, path: str = "") -> None:
        current_path = f"{path}.{kernel.name}" if path else kernel.name

        if hasattr(kernel, "parameters"):
            for param in kernel.parameters:
                param_path = f"{current_path}.{param.name}"
                val = KernelInterpreter._extract_values(param)
                if isinstance(val, list):
                    # Array-valued parameter (e.g. ARD): one entry per element.
                    for i, v in enumerate(val):
                        params[f"{param_path}[{i}]"] = float(v)
                else:
                    params[param_path] = float(val)

        # Recurse into composite kernels.
        # (Fix: the original enumerated `parts` but never used the index.)
        if hasattr(kernel, "parts") and kernel.parts:
            for part in kernel.parts:
                extract(part, current_path)

    extract(model.kern)
    return params
637
+
638
+
639
def get_lengthscale(model: Any, as_dict: bool = False) -> Union[float, Dict[str, float]]:
    """
    Extract lengthscale(s) from model kernel.

    Args:
        model: GPy model
        as_dict: Return dictionary with component paths as keys

    Returns:
        Single float (or list for ARD), or dictionary of lengthscales

    Raises:
        KernelError: If model has no 'kern' attribute, or (scalar mode)
            no component of the kernel has a lengthscale
    """
    if not hasattr(model, "kern"):
        raise KernelError("Model must have 'kern' attribute")

    if as_dict:
        result: Dict[str, Any] = {}

        def find_lengthscales(kernel: Any, path: str = "") -> None:
            current = f"{path}.{kernel.name}" if path else kernel.name
            if hasattr(kernel, "lengthscale"):
                result[current] = KernelInterpreter._extract_values(kernel.lengthscale)
            if hasattr(kernel, "parts") and kernel.parts:
                for part in kernel.parts:
                    find_lengthscales(part, current)

        find_lengthscales(model.kern)
        return result

    # Scalar mode: return the first lengthscale found, depth-first.
    # BUG FIX: the original only checked the root kernel, so composite
    # kernels (e.g. RBF + White) raised even though a part had a
    # lengthscale, contradicting the "first found" intent.
    def _first_lengthscale(kernel: Any) -> Optional[Any]:
        if hasattr(kernel, "lengthscale"):
            val = KernelInterpreter._extract_values(kernel.lengthscale)
            return val[0] if isinstance(val, list) else val
        if hasattr(kernel, "parts") and kernel.parts:
            for part in kernel.parts:
                found = _first_lengthscale(part)
                if found is not None:
                    return found
        return None

    found = _first_lengthscale(model.kern)
    if found is None:
        raise KernelError("Model kernel has no lengthscale attribute")
    return found
670
+
671
+
672
def get_noise_variance(model: Any) -> float:
    """Extract noise variance from GP model.

    Raises:
        KernelError: If the model has no likelihood or the likelihood's
            variance cannot be converted to a float.
    """
    if not hasattr(model, "likelihood"):
        raise KernelError("Model must have 'likelihood' attribute")
    try:
        variance = model.likelihood.variance
        return float(variance)
    except Exception as e:
        # Wrap any extraction failure in the package's exception type.
        raise KernelError(f"Could not extract noise variance: {e}") from e
680
+
681
+
682
+ # Private utilities
683
def _adapt_config_to_data(
    config: InterpretationConfig,
    X: np.ndarray,
) -> InterpretationConfig:
    """
    Auto-scale thresholds based on data characteristics.

    Args:
        config: Base configuration
        X: Training data (n_samples, n_dims)

    Returns:
        Adapted configuration, or ``config`` unchanged when the data is
        degenerate (wrong ndim, fewer than two samples, or constant in
        every dimension)
    """
    if X.ndim != 2 or X.shape[0] < 2:
        return config

    # Compute the per-dimension spread of the data.
    ranges = np.ptp(X, axis=0)  # Peak-to-peak (max - min)
    positive_ranges = ranges[ranges > 0]

    # BUG FIX: with all-constant data `positive_ranges` is empty and
    # np.median returns NaN; "NaN <= 0" is False, so the NaN previously
    # leaked into the scaled thresholds and made
    # LengthscaleThresholds.__post_init__ raise ValueError.
    if positive_ranges.size == 0:
        return config

    median_range = float(np.median(positive_ranges))
    if median_range <= 0:
        return config

    # Scale thresholds proportionally to data range.
    scale_factor = median_range / 2.0  # Assuming standardized data ~2 range

    new_ls = LengthscaleThresholds(
        rapid_variation=config.lengthscale.rapid_variation * scale_factor,
        smooth_trend=config.lengthscale.smooth_trend * scale_factor,
    )

    return InterpretationConfig(
        lengthscale=new_ls,
        variance=config.variance,
    )