gpclarity 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpclarity/__init__.py +190 -0
- gpclarity/_version.py +3 -0
- gpclarity/data_influence.py +501 -0
- gpclarity/exceptions.py +46 -0
- gpclarity/hyperparam_tracker.py +718 -0
- gpclarity/kernel_summary.py +285 -0
- gpclarity/model_complexity.py +619 -0
- gpclarity/plotting.py +337 -0
- gpclarity/uncertainty_analysis.py +647 -0
- gpclarity/utils.py +411 -0
- gpclarity-0.0.2.dist-info/METADATA +248 -0
- gpclarity-0.0.2.dist-info/RECORD +14 -0
- gpclarity-0.0.2.dist-info/WHEEL +4 -0
- gpclarity-0.0.2.dist-info/licenses/LICENSE +37 -0
gpclarity/plotting.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization utilities for GP interpretability.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List, Optional, Tuple

import numpy as np
|
|
8
|
+
|
|
9
|
+
def plot_influence_map(
|
|
10
|
+
X_train: np.ndarray,
|
|
11
|
+
influence_scores: np.ndarray,
|
|
12
|
+
ax: Optional["plt.Axes"] = None,
|
|
13
|
+
title: str = "Data Point Influence Map",
|
|
14
|
+
**scatter_kwargs,
|
|
15
|
+
):
|
|
16
|
+
"""
|
|
17
|
+
Visualize data point influence in input space.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
X_train: Training inputs (shape: [n, d] where d <= 2)
|
|
21
|
+
influence_scores: Influence scores per point
|
|
22
|
+
ax: Matplotlib axes
|
|
23
|
+
title: Plot title
|
|
24
|
+
**scatter_kwargs: Passed to ax.scatter
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
Matplotlib axes
|
|
28
|
+
"""
|
|
29
|
+
try:
|
|
30
|
+
import matplotlib.pyplot as plt
|
|
31
|
+
from matplotlib.colors import Normalize
|
|
32
|
+
except ImportError as e:
|
|
33
|
+
raise ImportError(
|
|
34
|
+
"plotting requires matplotlib. Install: pip install matplotlib"
|
|
35
|
+
) from e
|
|
36
|
+
|
|
37
|
+
# Dimensionality check
|
|
38
|
+
if X_train.shape[1] > 2:
|
|
39
|
+
raise ValueError(
|
|
40
|
+
f"Cannot plot {X_train.shape[1]}D data directly. "
|
|
41
|
+
"Use PCA reduction or select 2 dimensions."
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
if ax is None:
|
|
45
|
+
_, ax = plt.subplots(figsize=(10, 6))
|
|
46
|
+
|
|
47
|
+
# Handle non-finite scores
|
|
48
|
+
safe_scores = np.where(np.isfinite(influence_scores), influence_scores, 0.0)
|
|
49
|
+
|
|
50
|
+
# Normalize sizes
|
|
51
|
+
max_score = np.max(safe_scores) if np.max(safe_scores) > 0 else 1.0
|
|
52
|
+
sizes = 50 + (safe_scores / (max_score + 1e-10)) * 500
|
|
53
|
+
|
|
54
|
+
# Color normalization
|
|
55
|
+
norm = Normalize(vmin=np.min(safe_scores), vmax=max_score)
|
|
56
|
+
|
|
57
|
+
scatter = ax.scatter(
|
|
58
|
+
X_train[:, 0],
|
|
59
|
+
X_train[:, 1] if X_train.shape[1] > 1 else np.zeros(X_train.shape[0]),
|
|
60
|
+
c=safe_scores,
|
|
61
|
+
s=sizes,
|
|
62
|
+
cmap=scatter_kwargs.get("cmap", "viridis"),
|
|
63
|
+
alpha=scatter_kwargs.get("alpha", 0.7),
|
|
64
|
+
norm=norm,
|
|
65
|
+
edgecolors="black",
|
|
66
|
+
linewidth=0.5,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
ax.set_xlabel("Dimension 1", fontsize=11)
|
|
70
|
+
if X_train.shape[1] > 1:
|
|
71
|
+
ax.set_ylabel("Dimension 2", fontsize=11)
|
|
72
|
+
else:
|
|
73
|
+
ax.set_ylabel("Zero baseline (1D projection)", fontsize=11)
|
|
74
|
+
|
|
75
|
+
ax.set_title(title + "\n(size ∝ influence)", fontsize=12, fontweight="bold")
|
|
76
|
+
|
|
77
|
+
cbar = plt.colorbar(scatter, ax=ax)
|
|
78
|
+
cbar.set_label("Influence Score", fontsize=10)
|
|
79
|
+
|
|
80
|
+
ax.grid(True, alpha=0.3)
|
|
81
|
+
|
|
82
|
+
return ax
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def plot_optimization_trajectory(
    tracker: "HyperparameterTracker",
    params: Optional[List[str]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    show_convergence: bool = True,
    show_ll: bool = True,
    n_cols: int = 2,
) -> "plt.Figure":
    """
    Plot parameter trajectories from optimization history.

    One subplot is drawn per parameter (multi-dimensional parameters get one
    line per dimension), plus an optional log-likelihood subplot, laid out
    on a grid with ``n_cols`` columns. Unused grid cells are hidden.

    Args:
        tracker: HyperparameterTracker with recorded history. Read through
            ``tracker.history`` (snapshots exposing ``iteration``,
            ``parameters`` and ``log_likelihood``) and
            ``tracker.get_parameter_trajectory(name)``.
        params: Specific parameters to plot (all parameters of the first
            snapshot if None)
        figsize: Figure dimensions; defaults to 4 x 3 per subplot cell.
        show_convergence: For scalar parameters with more than 10 recorded
            values, draw the final value and a +/-1 std band computed over
            the trailing max(5, n // 10) values.
        show_ll: Include log-likelihood subplot (only if any snapshot has one)
        n_cols: Number of columns in subplot grid (must be >= 1)

    Returns:
        Matplotlib figure

    Raises:
        ValueError: If the tracker has no history, ``n_cols < 1``, or there
            is nothing to plot.
        ImportError: If matplotlib is not installed.
    """
    history = tracker.history

    # Validate before importing matplotlib so bad input fails fast.
    if len(history) == 0:
        raise ValueError("Tracker has no recorded history to plot.")
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}.")

    # Determine parameters to plot
    if params is None:
        params = list(history[0].parameters.keys())

    # Sentinel item standing for the log-likelihood subplot.
    ll_key = "__log_likelihood__"
    plot_items = list(params)
    if show_ll and any(s.log_likelihood is not None for s in history):
        plot_items.append(ll_key)
    if not plot_items:
        raise ValueError(
            "Nothing to plot: no parameters selected and no "
            "log-likelihood recorded."
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "plotting requires matplotlib. Install: pip install matplotlib"
        ) from e

    # Calculate grid layout (ceiling division for the row count)
    n_plots = len(plot_items)
    n_rows = (n_plots + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (4 * n_cols, 3 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes_flat = axes.flatten()

    for ax, item in zip(axes_flat, plot_items):
        if item == ll_key:
            iterations = [s.iteration for s in history]
            # Snapshots without a log-likelihood become gaps (NaN) in the line.
            values = [
                np.nan if s.log_likelihood is None else s.log_likelihood
                for s in history
            ]
            ax.plot(iterations, values, linewidth=2, color="green")
            ax.set_title("Log Likelihood", fontweight="bold")
            ax.set_ylabel("LL")
        else:
            iterations, values = tracker.get_parameter_trajectory(item)
            values = np.asarray(values)

            if values.ndim > 1:
                # One line per dimension of a vector-valued parameter.
                for dim in range(values.shape[1]):
                    ax.plot(iterations, values[:, dim],
                            label=f"Dim {dim}", alpha=0.8)
                ax.legend(fontsize=8)
            else:
                ax.plot(iterations, values, linewidth=2, color="#2E86AB")

                # Only scalar parameters get the convergence overlay: a vector
                # final value can be neither a single horizontal line nor
                # formatted with ":.3f".
                if show_convergence and len(values) > 10:
                    final_val = float(values[-1])
                    ax.axhline(y=final_val, color="green", linestyle="--",
                               alpha=0.5, label=f"Final: {final_val:.3f}")
                    # Convergence band: +/-1 std of the trailing window
                    window = max(5, len(values) // 10)
                    recent_std = np.std(values[-window:])
                    ax.fill_between(iterations,
                                    final_val - recent_std,
                                    final_val + recent_std,
                                    alpha=0.1, color="green")
                    ax.legend(fontsize=8)

            ax.set_title(f"{item.replace('_', ' ').title()}",
                         fontweight="bold")
            ax.set_ylabel("Value")

        ax.set_xlabel("Iteration")
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for ax in axes_flat[n_plots:]:
        ax.set_visible(False)

    fig.suptitle("Optimization Trajectories", fontsize=14, fontweight="bold")
    fig.tight_layout()

    return fig
|
|
187
|
+
|
|
188
|
+
def plot_uncertainty_profile(
|
|
189
|
+
profiler: "UncertaintyProfiler",
|
|
190
|
+
X_test: np.ndarray,
|
|
191
|
+
*,
|
|
192
|
+
X_train: Optional[np.ndarray] = None,
|
|
193
|
+
y_train: Optional[np.ndarray] = None,
|
|
194
|
+
y_test: Optional[np.ndarray] = None,
|
|
195
|
+
ax: Optional["plt.Axes"] = None,
|
|
196
|
+
confidence_levels: Tuple[float, ...] = (1.0, 2.0),
|
|
197
|
+
plot_std: bool = False,
|
|
198
|
+
fill_alpha: float = 0.2,
|
|
199
|
+
color_mean: str = "#1f77b4",
|
|
200
|
+
color_fill: str = "#1f77b4",
|
|
201
|
+
color_train: str = "red",
|
|
202
|
+
color_test: str = "green",
|
|
203
|
+
show_regions: bool = False,
|
|
204
|
+
**kwargs,
|
|
205
|
+
) -> "plt.Axes":
|
|
206
|
+
"""
|
|
207
|
+
Comprehensive uncertainty profile visualization.
|
|
208
|
+
|
|
209
|
+
Args:
|
|
210
|
+
profiler: UncertaintyProfiler instance
|
|
211
|
+
X_test: Test locations
|
|
212
|
+
X_train: Training inputs
|
|
213
|
+
y_train: Training outputs
|
|
214
|
+
y_test: Test ground truth (optional)
|
|
215
|
+
ax: Matplotlib axes
|
|
216
|
+
confidence_levels: Sigma levels for confidence bands
|
|
217
|
+
plot_std: Overlay standard deviation as line
|
|
218
|
+
fill_alpha: Opacity of confidence bands
|
|
219
|
+
color_mean: Color for mean line
|
|
220
|
+
color_fill: Color for uncertainty bands
|
|
221
|
+
color_train: Color for training points
|
|
222
|
+
color_test: Color for test ground truth
|
|
223
|
+
show_regions: Highlight extrapolation regions
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
Matplotlib axes
|
|
227
|
+
"""
|
|
228
|
+
try:
|
|
229
|
+
import matplotlib.pyplot as plt
|
|
230
|
+
from matplotlib.patches import Patch
|
|
231
|
+
except ImportError as e:
|
|
232
|
+
raise ImportError(
|
|
233
|
+
"plotting requires matplotlib. Install: pip install matplotlib"
|
|
234
|
+
) from e
|
|
235
|
+
|
|
236
|
+
if ax is None:
|
|
237
|
+
_, ax = plt.subplots(figsize=(12, 6))
|
|
238
|
+
|
|
239
|
+
# Get predictions
|
|
240
|
+
pred = profiler.predict(X_test)
|
|
241
|
+
mean, std = pred.mean, pred.std
|
|
242
|
+
|
|
243
|
+
X_flat = X_test.flatten()
|
|
244
|
+
mean_flat = mean.flatten()
|
|
245
|
+
std_flat = std.flatten()
|
|
246
|
+
|
|
247
|
+
# Sort for clean plotting (if 1D)
|
|
248
|
+
if X_test.shape[1] == 1:
|
|
249
|
+
sort_idx = np.argsort(X_flat)
|
|
250
|
+
X_flat = X_flat[sort_idx]
|
|
251
|
+
mean_flat = mean_flat[sort_idx]
|
|
252
|
+
std_flat = std_flat[sort_idx]
|
|
253
|
+
|
|
254
|
+
# Plot confidence bands (outer first)
|
|
255
|
+
for level in sorted(confidence_levels, reverse=True):
|
|
256
|
+
alpha = fill_alpha * (0.5 ** (list(confidence_levels).index(level)))
|
|
257
|
+
ax.fill_between(
|
|
258
|
+
X_flat,
|
|
259
|
+
mean_flat - level * std_flat,
|
|
260
|
+
mean_flat + level * std_flat,
|
|
261
|
+
alpha=alpha,
|
|
262
|
+
color=color_fill,
|
|
263
|
+
label=f"±{level}σ" if level == max(confidence_levels) else None,
|
|
264
|
+
zorder=1,
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
# Mean prediction
|
|
268
|
+
ax.plot(
|
|
269
|
+
X_flat,
|
|
270
|
+
mean_flat,
|
|
271
|
+
color=color_mean,
|
|
272
|
+
linewidth=2.5,
|
|
273
|
+
label="GP Mean",
|
|
274
|
+
zorder=3,
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
# Standard deviation line (optional)
|
|
278
|
+
if plot_std:
|
|
279
|
+
ax.plot(
|
|
280
|
+
X_flat,
|
|
281
|
+
std_flat,
|
|
282
|
+
color=color_fill,
|
|
283
|
+
linestyle="--",
|
|
284
|
+
alpha=0.7,
|
|
285
|
+
label="Std Dev",
|
|
286
|
+
zorder=2,
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
# Training data
|
|
290
|
+
if X_train is not None and y_train is not None:
|
|
291
|
+
ax.scatter(
|
|
292
|
+
X_train.flatten(),
|
|
293
|
+
y_train.flatten(),
|
|
294
|
+
color=color_train,
|
|
295
|
+
s=60,
|
|
296
|
+
zorder=5,
|
|
297
|
+
label="Training Data",
|
|
298
|
+
edgecolors="white",
|
|
299
|
+
linewidth=1,
|
|
300
|
+
alpha=0.9,
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
# Test ground truth
|
|
304
|
+
if y_test is not None:
|
|
305
|
+
ax.scatter(
|
|
306
|
+
X_test.flatten(),
|
|
307
|
+
y_test.flatten(),
|
|
308
|
+
color=color_test,
|
|
309
|
+
s=40,
|
|
310
|
+
marker="x",
|
|
311
|
+
zorder=4,
|
|
312
|
+
label="Ground Truth",
|
|
313
|
+
alpha=0.7,
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
# Region highlighting
|
|
317
|
+
if show_regions and X_train is not None:
|
|
318
|
+
regions = profiler.classify_regions(X_test)
|
|
319
|
+
|
|
320
|
+
# Highlight extrapolation regions
|
|
321
|
+
ext_mask = regions == UncertaintyRegion.EXTRAPOLATION
|
|
322
|
+
if np.any(ext_mask):
|
|
323
|
+
ax.axvspan(
|
|
324
|
+
np.min(X_test[ext_mask].flatten()),
|
|
325
|
+
np.max(X_test[ext_mask].flatten()),
|
|
326
|
+
alpha=0.1,
|
|
327
|
+
color="red",
|
|
328
|
+
label="Extrapolation",
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
ax.set_xlabel("Input", fontsize=12)
|
|
332
|
+
ax.set_ylabel("Output", fontsize=12)
|
|
333
|
+
ax.set_title("Uncertainty Profile", fontsize=14, fontweight="bold")
|
|
334
|
+
ax.legend(loc="best", fontsize=9)
|
|
335
|
+
ax.grid(True, alpha=0.3, linestyle="--")
|
|
336
|
+
|
|
337
|
+
return ax
|