gpclarity 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpclarity/plotting.py ADDED
@@ -0,0 +1,337 @@
1
+ """
2
+ Visualization utilities for GP interpretability.
3
+ """
4
+
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+
9
def plot_influence_map(
    X_train: np.ndarray,
    influence_scores: np.ndarray,
    ax: Optional["plt.Axes"] = None,
    title: str = "Data Point Influence Map",
    **scatter_kwargs,
):
    """
    Visualize data point influence in input space.

    Both marker size and color encode the influence score; non-finite
    scores are clamped to 0 so a stray NaN/inf cannot break the plot.

    Args:
        X_train: Training inputs (shape: [n] or [n, d] where d <= 2)
        influence_scores: Influence scores per point
        ax: Matplotlib axes (a new figure is created when None)
        title: Plot title
        **scatter_kwargs: Optional "cmap" and "alpha" overrides for ax.scatter

    Returns:
        Matplotlib axes

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If X_train has more than 2 feature dimensions or is empty.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import Normalize
    except ImportError as e:
        raise ImportError(
            "plotting requires matplotlib. Install: pip install matplotlib"
        ) from e

    # Bug fix: a 1-D X_train previously raised IndexError on shape[1].
    # Treat it as a single feature column instead.
    X_train = np.asarray(X_train)
    if X_train.ndim == 1:
        X_train = X_train.reshape(-1, 1)

    # Dimensionality check
    if X_train.shape[1] > 2:
        raise ValueError(
            f"Cannot plot {X_train.shape[1]}D data directly. "
            "Use PCA reduction or select 2 dimensions."
        )

    # Robustness: fail with a clear message instead of a cryptic numpy
    # error from np.max on an empty array.
    if X_train.shape[0] == 0:
        raise ValueError("X_train is empty; nothing to plot.")

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Handle non-finite scores
    safe_scores = np.where(np.isfinite(influence_scores), influence_scores, 0.0)

    # Normalize sizes. Compute the max once (the original evaluated
    # np.max(safe_scores) twice).
    score_max = float(np.max(safe_scores))
    max_score = score_max if score_max > 0 else 1.0
    sizes = 50 + (safe_scores / (max_score + 1e-10)) * 500

    # Color normalization
    norm = Normalize(vmin=float(np.min(safe_scores)), vmax=max_score)

    scatter = ax.scatter(
        X_train[:, 0],
        # 1-D data is projected onto a zero baseline.
        X_train[:, 1] if X_train.shape[1] > 1 else np.zeros(X_train.shape[0]),
        c=safe_scores,
        s=sizes,
        cmap=scatter_kwargs.get("cmap", "viridis"),
        alpha=scatter_kwargs.get("alpha", 0.7),
        norm=norm,
        edgecolors="black",
        linewidth=0.5,
    )

    ax.set_xlabel("Dimension 1", fontsize=11)
    if X_train.shape[1] > 1:
        ax.set_ylabel("Dimension 2", fontsize=11)
    else:
        ax.set_ylabel("Zero baseline (1D projection)", fontsize=11)

    ax.set_title(title + "\n(size ∝ influence)", fontsize=12, fontweight="bold")

    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Influence Score", fontsize=10)

    ax.grid(True, alpha=0.3)

    return ax
83
+
84
+
85
+
86
def plot_optimization_trajectory(
    tracker: "HyperparameterTracker",
    params: "Optional[list[str]]" = None,
    figsize: "Optional[tuple[float, float]]" = None,
    show_convergence: bool = True,
    show_ll: bool = True,
    n_cols: int = 2,
) -> "plt.Figure":
    """
    Plot parameter trajectories from optimization history.

    Bug fix: the original annotations used ``List``/``Tuple``, which this
    module never imports (only ``Optional``), so importing the module
    raised NameError. String annotations avoid evaluating those names.

    Args:
        tracker: HyperparameterTracker with recorded history
        params: Specific parameters to plot (all if None)
        figsize: Figure dimensions (auto-sized from grid when None)
        show_convergence: Show final value and convergence bands
        show_ll: Include log-likelihood subplot
        n_cols: Number of columns in subplot grid

    Returns:
        Matplotlib figure

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If the tracker has no recorded history.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "plotting requires matplotlib. Install: pip install matplotlib"
        ) from e

    history = tracker.history

    # Robustness: fail clearly instead of IndexError on history[0].
    if not history:
        raise ValueError("tracker has no recorded optimization history")

    # Determine parameters to plot
    if params is None:
        params = list(history[0].parameters.keys())

    # Add log-likelihood sentinel to the plot list if requested and present
    plot_items = list(params)
    if show_ll and any(s.log_likelihood is not None for s in history):
        plot_items.append("__log_likelihood__")

    # Calculate grid layout (ceiling division)
    n_plots = len(plot_items)
    n_rows = (n_plots + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (4 * n_cols, 3 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes_flat = axes.flatten()

    for idx, param_name in enumerate(plot_items):
        ax = axes_flat[idx]

        if param_name == "__log_likelihood__":
            iterations = [s.iteration for s in history]
            values = [s.log_likelihood for s in history]
            ax.plot(iterations, values, linewidth=2, color="green")
            ax.set_title("Log Likelihood", fontweight="bold")
            ax.set_ylabel("LL")
        else:
            iterations, values = tracker.get_parameter_trajectory(param_name)

            # Multi-dimensional parameters: one line per dimension
            if values.ndim > 1:
                for dim in range(values.shape[1]):
                    ax.plot(iterations, values[:, dim],
                            label=f"Dim {dim}", alpha=0.8)
                ax.legend(fontsize=8)
            else:
                ax.plot(iterations, values, linewidth=2, color="#2E86AB")

            # Bug fix: the convergence annotation only makes sense for
            # scalar trajectories -- formatting values[-1] with :.3f
            # raised TypeError for multi-dimensional parameters.
            if show_convergence and values.ndim == 1 and len(values) > 10:
                final_val = values[-1]
                ax.axhline(y=final_val, color="green", linestyle="--",
                           alpha=0.5, label=f"Final: {final_val:.3f}")
                # Convergence band (±1 std of last 10%)
                window = max(5, len(values) // 10)
                recent_std = np.std(values[-window:])
                ax.fill_between(iterations,
                                final_val - recent_std,
                                final_val + recent_std,
                                alpha=0.1, color="green")
                ax.legend(fontsize=8)

            ax.set_title(f"{param_name.replace('_', ' ').title()}",
                         fontweight="bold")
            ax.set_ylabel("Value")

        ax.set_xlabel("Iteration")
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for idx in range(len(plot_items), len(axes_flat)):
        axes_flat[idx].set_visible(False)

    plt.suptitle("Optimization Trajectories", fontsize=14, fontweight="bold")
    plt.tight_layout()

    return fig
187
+
188
def plot_uncertainty_profile(
    profiler: "UncertaintyProfiler",
    X_test: np.ndarray,
    *,
    X_train: Optional[np.ndarray] = None,
    y_train: Optional[np.ndarray] = None,
    y_test: Optional[np.ndarray] = None,
    ax: Optional["plt.Axes"] = None,
    confidence_levels: "tuple[float, ...]" = (1.0, 2.0),
    plot_std: bool = False,
    fill_alpha: float = 0.2,
    color_mean: str = "#1f77b4",
    color_fill: str = "#1f77b4",
    color_train: str = "red",
    color_test: str = "green",
    show_regions: bool = False,
    **kwargs,
) -> "plt.Axes":
    """
    Comprehensive uncertainty profile visualization.

    Bug fix: the original annotated ``confidence_levels`` with ``Tuple``,
    which this module never imports, so importing the module raised
    NameError. A string annotation avoids evaluating the name.

    Args:
        profiler: UncertaintyProfiler instance
        X_test: Test locations
        X_train: Training inputs
        y_train: Training outputs
        y_test: Test ground truth (optional)
        ax: Matplotlib axes (a new figure is created when None)
        confidence_levels: Sigma levels for confidence bands
        plot_std: Overlay standard deviation as line
        fill_alpha: Opacity of outermost confidence band
        color_mean: Color for mean line
        color_fill: Color for uncertainty bands
        color_train: Color for training points
        color_test: Color for test ground truth
        show_regions: Highlight extrapolation regions

    Returns:
        Matplotlib axes

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        # Removed the unused `matplotlib.patches.Patch` import.
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "plotting requires matplotlib. Install: pip install matplotlib"
        ) from e

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    # Get predictions
    pred = profiler.predict(X_test)
    mean, std = pred.mean, pred.std

    X_flat = X_test.flatten()
    mean_flat = mean.flatten()
    std_flat = std.flatten()

    # Sort for clean plotting when effectively 1D. Bug fix: the original
    # indexed X_test.shape[1] unconditionally, which raises IndexError
    # for 1-D arrays.
    if X_test.ndim == 1 or X_test.shape[1] == 1:
        sort_idx = np.argsort(X_flat)
        X_flat = X_flat[sort_idx]
        mean_flat = mean_flat[sort_idx]
        std_flat = std_flat[sort_idx]

    # Plot confidence bands (outer first). Hoist the widest level so
    # max() is not recomputed on every loop iteration; only that band
    # gets a legend entry.
    widest = max(confidence_levels) if confidence_levels else None
    for level in sorted(confidence_levels, reverse=True):
        # Each inner band is drawn at half the alpha of the one before it
        # (by position in the caller-supplied tuple).
        alpha = fill_alpha * (0.5 ** (list(confidence_levels).index(level)))
        ax.fill_between(
            X_flat,
            mean_flat - level * std_flat,
            mean_flat + level * std_flat,
            alpha=alpha,
            color=color_fill,
            label=f"±{level}σ" if level == widest else None,
            zorder=1,
        )

    # Mean prediction
    ax.plot(
        X_flat,
        mean_flat,
        color=color_mean,
        linewidth=2.5,
        label="GP Mean",
        zorder=3,
    )

    # Standard deviation line (optional)
    if plot_std:
        ax.plot(
            X_flat,
            std_flat,
            color=color_fill,
            linestyle="--",
            alpha=0.7,
            label="Std Dev",
            zorder=2,
        )

    # Training data
    if X_train is not None and y_train is not None:
        ax.scatter(
            X_train.flatten(),
            y_train.flatten(),
            color=color_train,
            s=60,
            zorder=5,
            label="Training Data",
            edgecolors="white",
            linewidth=1,
            alpha=0.9,
        )

    # Test ground truth
    if y_test is not None:
        ax.scatter(
            X_test.flatten(),
            y_test.flatten(),
            color=color_test,
            s=40,
            marker="x",
            zorder=4,
            label="Ground Truth",
            alpha=0.7,
        )

    # Region highlighting
    if show_regions and X_train is not None:
        regions = profiler.classify_regions(X_test)

        # NOTE(review): UncertaintyRegion is not imported anywhere in this
        # module, so this branch raises NameError unless the enum happens
        # to be in scope -- confirm and import it from the profiler's
        # module.
        ext_mask = regions == UncertaintyRegion.EXTRAPOLATION
        if np.any(ext_mask):
            # Shade one span covering the extreme extrapolation inputs.
            ax.axvspan(
                np.min(X_test[ext_mask].flatten()),
                np.max(X_test[ext_mask].flatten()),
                alpha=0.1,
                color="red",
                label="Extrapolation",
            )

    ax.set_xlabel("Input", fontsize=12)
    ax.set_ylabel("Output", fontsize=12)
    ax.set_title("Uncertainty Profile", fontsize=14, fontweight="bold")
    ax.legend(loc="best", fontsize=9)
    ax.grid(True, alpha=0.3, linestyle="--")

    return ax