ai-metacognition-toolkit 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_metacognition/__init__.py +123 -0
- ai_metacognition/analyzers/__init__.py +24 -0
- ai_metacognition/analyzers/base.py +39 -0
- ai_metacognition/analyzers/counterfactual_cot.py +579 -0
- ai_metacognition/analyzers/model_api.py +39 -0
- ai_metacognition/detectors/__init__.py +40 -0
- ai_metacognition/detectors/base.py +42 -0
- ai_metacognition/detectors/observer_effect.py +651 -0
- ai_metacognition/detectors/sandbagging_detector.py +1438 -0
- ai_metacognition/detectors/situational_awareness.py +526 -0
- ai_metacognition/integrations/__init__.py +16 -0
- ai_metacognition/integrations/anthropic_api.py +230 -0
- ai_metacognition/integrations/base.py +113 -0
- ai_metacognition/integrations/openai_api.py +300 -0
- ai_metacognition/probing/__init__.py +24 -0
- ai_metacognition/probing/extraction.py +176 -0
- ai_metacognition/probing/hooks.py +200 -0
- ai_metacognition/probing/probes.py +186 -0
- ai_metacognition/probing/vectors.py +133 -0
- ai_metacognition/utils/__init__.py +48 -0
- ai_metacognition/utils/feature_extraction.py +534 -0
- ai_metacognition/utils/statistical_tests.py +317 -0
- ai_metacognition/utils/text_processing.py +98 -0
- ai_metacognition/visualizations/__init__.py +22 -0
- ai_metacognition/visualizations/plotting.py +523 -0
- ai_metacognition_toolkit-0.3.0.dist-info/METADATA +621 -0
- ai_metacognition_toolkit-0.3.0.dist-info/RECORD +30 -0
- ai_metacognition_toolkit-0.3.0.dist-info/WHEEL +5 -0
- ai_metacognition_toolkit-0.3.0.dist-info/licenses/LICENSE +21 -0
- ai_metacognition_toolkit-0.3.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
"""Publication-ready plotting functions for metacognition analysis.
|
|
2
|
+
|
|
3
|
+
This module provides high-quality visualization functions for:
|
|
4
|
+
- Awareness probability time series with confidence bands
|
|
5
|
+
- Causal attribution bar charts
|
|
6
|
+
- Feature divergence heatmaps
|
|
7
|
+
- Distribution comparison plots
|
|
8
|
+
|
|
9
|
+
All plots are customizable and can be saved at publication quality.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
import numpy as np
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
import matplotlib.dates as mdates
|
|
17
|
+
from matplotlib.figure import Figure
|
|
18
|
+
from matplotlib.axes import Axes
|
|
19
|
+
import seaborn as sns
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Publication-ready default settings
|
|
23
|
+
DEFAULT_DPI = 300  # default for the `dpi` argument of every plot_* function in this module
|
|
24
|
+
DEFAULT_FIGSIZE = (10, 6)  # (width, height) in inches; default `figsize` of plot_awareness_over_time
|
|
25
|
+
DEFAULT_STYLE = "seaborn-v0_8-darkgrid"  # style sheet _setup_plot_style applies when no style is passed
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _setup_plot_style(style: Optional[str] = None) -> None:
    """Setup matplotlib style for publication-ready plots.

    Applies ``style`` through ``plt.style.use``; a falsy ``style`` (``None``
    or an empty string) selects ``DEFAULT_STYLE``. If matplotlib raises
    ``OSError`` for the requested style, the built-in ``"default"`` style is
    applied instead and a handful of rcParams (white figure/axes background,
    black axes edges, faint grid, 10 pt font) are set on top of it.

    The function works on ``plt.style`` / ``plt.rcParams`` directly, so the
    settings remain in effect after it returns.

    Args:
        style: Matplotlib style name. If None, uses default.
    """
    try:
        plt.style.use(style or DEFAULT_STYLE)
    except OSError:
        # Fallback if style not available
        plt.style.use("default")
        # NOTE(review): the source this was recovered from had its indentation
        # stripped, so it is not certain whether these overrides belong to the
        # fallback branch only (as reconstructed here) or run unconditionally
        # after the try/except -- confirm against the upstream file.
        plt.rcParams.update({
            'figure.facecolor': 'white',
            'axes.facecolor': 'white',
            'axes.edgecolor': 'black',
            'grid.alpha': 0.3,
            'font.size': 10,
        })
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def plot_awareness_over_time(
    timestamps: List[Union[datetime, float]],
    probabilities: List[float],
    confidence_intervals: Optional[List[Tuple[float, float]]] = None,
    title: str = "Situational Awareness Over Time",
    xlabel: str = "Time",
    ylabel: str = "Awareness Probability",
    threshold: Optional[float] = None,
    figsize: Tuple[float, float] = DEFAULT_FIGSIZE,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
) -> Figure:
    """Plot time series of situational awareness probabilities.

    Creates a line plot showing awareness probability over time with optional
    confidence intervals shown as shaded bands. Optionally displays a threshold
    line for decision-making.

    Args:
        timestamps: List of timestamps (datetime objects or numeric values).
            The type of the first element decides whether the x-axis is
            formatted as clock time (``%H:%M``).
        probabilities: Awareness probability at each timestamp
        confidence_intervals: Optional sequence of (lower, upper) confidence
            bounds, one pair per probability. ``None`` or an empty sequence
            disables the band.
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        threshold: Optional decision threshold to display as horizontal line
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and the saved file
        style: Matplotlib style to use
        show: Whether to display the plot

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``timestamps`` and ``probabilities`` differ in length,
            if they are empty, or if ``confidence_intervals`` is non-empty
            and its length does not match ``probabilities``.

    Examples:
        >>> from datetime import datetime, timedelta
        >>> base = datetime.now()
        >>> timestamps = [base + timedelta(hours=i) for i in range(24)]
        >>> probabilities = [0.1 + 0.03*i for i in range(24)]
        >>> confidence_intervals = [(p-0.05, p+0.05) for p in probabilities]
        >>> fig = plot_awareness_over_time(
        ...     timestamps, probabilities, confidence_intervals,
        ...     threshold=0.5, save_path="awareness_trend.png"
        ... )
    """
    # Validate first: invalid input should neither change the global
    # matplotlib style nor leave a half-built figure behind.
    if len(timestamps) != len(probabilities):
        raise ValueError(
            f"Timestamps ({len(timestamps)}) and probabilities ({len(probabilities)}) "
            "must have the same length"
        )

    # An empty series previously failed with a bare IndexError on
    # ``timestamps[0]``; report it the same way the other validators do.
    if len(timestamps) == 0:
        raise ValueError("timestamps and probabilities cannot be empty")

    # Use len() instead of truthiness so array-like inputs work too
    # (truth-testing a multi-element array raises).
    has_intervals = (
        confidence_intervals is not None and len(confidence_intervals) > 0
    )
    if has_intervals and len(confidence_intervals) != len(probabilities):
        raise ValueError(
            f"Confidence intervals ({len(confidence_intervals)}) must match "
            f"probabilities ({len(probabilities)})"
        )

    _setup_plot_style(style)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    is_datetime = isinstance(timestamps[0], datetime)

    # Main probability line
    ax.plot(
        timestamps, probabilities,
        linewidth=2, color='#2E86AB', label='Awareness Probability',
        marker='o', markersize=4, markerfacecolor='white', markeredgewidth=1.5
    )

    # Confidence band
    if has_intervals:
        lower_bounds = [ci[0] for ci in confidence_intervals]
        upper_bounds = [ci[1] for ci in confidence_intervals]
        ax.fill_between(
            timestamps, lower_bounds, upper_bounds,
            alpha=0.3, color='#2E86AB', label='95% Confidence Interval'
        )

    # Threshold line
    if threshold is not None:
        ax.axhline(
            y=threshold, color='#E63946', linestyle='--',
            linewidth=2, alpha=0.7, label=f'Threshold ({threshold:.2f})'
        )

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.legend(loc='best', framealpha=0.9, fontsize=10)

    # Format x-axis for datetime. Work on this figure's own axes instead of
    # pyplot's implicit "current axes".
    if is_datetime:
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def plot_causal_attribution(
    intervention_results: List[Dict],
    title: str = "Causal Attribution Analysis",
    xlabel: str = "Divergence Score",
    ylabel: str = "Intervention Type",
    figsize: Tuple[float, float] = (10, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
    highlight_threshold: Optional[float] = None,
) -> Figure:
    """Plot causal attribution from counterfactual interventions.

    Creates a horizontal bar chart showing the causal importance of different
    interventions, sorted by divergence score. Optionally highlights interventions
    above a significance threshold.

    Args:
        intervention_results: List of dicts with 'intervention_type' and 'divergence' keys
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and the saved file
        style: Matplotlib style to use
        show: Whether to display the plot
        highlight_threshold: If provided, bars at or above this value are
            drawn in the highlight colour and a dashed threshold line with a
            legend entry is added

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``intervention_results`` is empty.
        KeyError: If an entry lacks 'intervention_type' or 'divergence'.

    Examples:
        >>> results = [
        ...     {"intervention_type": "remove_eval_awareness", "divergence": 0.342},
        ...     {"intervention_type": "amplify_eval_awareness", "divergence": 0.156},
        ...     {"intervention_type": "remove_goal_reasoning", "divergence": 0.089},
        ... ]
        >>> fig = plot_causal_attribution(
        ...     results, highlight_threshold=0.2, save_path="causal_attribution.png"
        ... )
    """
    # Validate before changing the global style or creating a figure.
    if not intervention_results:
        raise ValueError("intervention_results cannot be empty")

    # Sort by divergence (descending). Both keys are required: a missing
    # 'divergence' raised KeyError right after sorting before, so index
    # directly rather than suggesting a default of 0 is supported.
    sorted_results = sorted(
        intervention_results,
        key=lambda r: r['divergence'],
        reverse=True
    )

    intervention_types = [
        r['intervention_type'].replace('_', ' ').title()
        for r in sorted_results
    ]
    divergences = [r['divergence'] for r in sorted_results]

    _setup_plot_style(style)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Color bars based on threshold
    if highlight_threshold is not None:
        colors = [
            '#E63946' if d >= highlight_threshold else '#2E86AB'
            for d in divergences
        ]
    else:
        colors = ['#2E86AB'] * len(divergences)

    # Create horizontal bar chart
    y_pos = np.arange(len(intervention_types))
    bars = ax.barh(y_pos, divergences, color=colors, alpha=0.8, edgecolor='black', linewidth=1.2)

    # Gap between bar end and its label: 2% of the largest magnitude.
    # Computed once (it used to be recomputed per bar) and from absolute
    # values so it stays positive when divergences are negative.
    label_pad = max(abs(d) for d in divergences) * 0.02

    # Add value labels just outside the end of each bar
    for bar, div in zip(bars, divergences):
        width = bar.get_width()
        if width >= 0:
            label_x, align = width + label_pad, 'left'
        else:
            # Bars growing to the left get their label on the left side.
            label_x, align = width - label_pad, 'right'
        ax.text(
            label_x,
            bar.get_y() + bar.get_height() / 2,
            f'{div:.3f}',
            ha=align, va='center', fontsize=10, fontweight='bold'
        )

    # Threshold line (the legend only exists to explain this line)
    if highlight_threshold is not None:
        ax.axvline(
            x=highlight_threshold, color='#E63946', linestyle='--',
            linewidth=2, alpha=0.7, label=f'Significance Threshold ({highlight_threshold:.2f})'
        )
        ax.legend(loc='best', framealpha=0.9, fontsize=10)

    # Formatting
    ax.set_yticks(y_pos)
    ax.set_yticklabels(intervention_types, fontsize=10)
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, axis='x', linestyle='-', linewidth=0.5)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def plot_divergence_heatmap(
    features: List[str],
    timepoints: List[Union[datetime, str, float]],
    z_scores: np.ndarray,
    title: str = "Feature Divergence Over Time",
    xlabel: str = "Time",
    ylabel: str = "Feature",
    figsize: Tuple[float, float] = (12, 8),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
    cmap: str = "RdYlGn_r",
    vmin: float = -5.0,
    vmax: float = 5.0,
) -> Figure:
    """Plot heatmap of feature divergence over time.

    Creates a color-coded heatmap showing z-scores for multiple features across
    timepoints. Uses diverging colormap with red indicating high divergence
    (potential issues) and green indicating normal behavior. Cells whose
    absolute z-score is at least 3 are annotated with their value.

    Args:
        features: List of feature names
        timepoints: List of timepoint labels (datetime, strings, or numeric)
        z_scores: 2D array (or nested sequence) of z-scores with shape
            (n_features, n_timepoints)
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and the saved file
        style: Matplotlib style to use
        show: Whether to display the plot
        cmap: Colormap name (diverging recommended)
        vmin: Minimum z-score for colormap
        vmax: Maximum z-score for colormap

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``z_scores`` is not two-dimensional or its shape does
            not match ``(len(features), len(timepoints))``.

    Examples:
        >>> features = ["response_length", "refusal_rate", "hedging_ratio"]
        >>> timepoints = ["00:00", "01:00", "02:00", "03:00"]
        >>> z_scores = np.array([
        ...     [0.5, 1.2, 2.8, 3.5],   # response_length
        ...     [-0.2, 0.1, 0.3, 4.2],  # refusal_rate
        ...     [0.1, 0.3, 0.2, 0.4],   # hedging_ratio
        ... ])
        >>> fig = plot_divergence_heatmap(
        ...     features, timepoints, z_scores, save_path="divergence_heatmap.png"
        ... )
    """
    # Accept nested lists as well as arrays, and reject wrong ranks with a
    # clear message instead of an AttributeError/IndexError on ``.shape``.
    z_scores = np.asarray(z_scores, dtype=float)
    if z_scores.ndim != 2:
        raise ValueError(
            "z_scores must be 2D with shape (n_features, n_timepoints), "
            f"got {z_scores.ndim}D"
        )

    if z_scores.shape[0] != len(features):
        raise ValueError(
            f"z_scores rows ({z_scores.shape[0]}) must match features ({len(features)})"
        )

    if z_scores.shape[1] != len(timepoints):
        raise ValueError(
            f"z_scores columns ({z_scores.shape[1]}) must match timepoints ({len(timepoints)})"
        )

    _setup_plot_style(style)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Create heatmap
    im = ax.imshow(
        z_scores, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax,
        interpolation='nearest'
    )

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label('Z-Score (σ)', rotation=270, labelpad=20, fontsize=12, fontweight='bold')

    # Severity threshold lines on the colorbar. Only draw those that fall
    # inside the colour range: a line outside [vmin, vmax] can trigger a
    # rescale of the colorbar axis.
    for level, line_color in ((2.0, 'yellow'), (3.0, 'orange'), (4.0, 'red')):
        if vmin <= level <= vmax:
            cbar.ax.axhline(
                y=level, color=line_color, linestyle='--', linewidth=1.5, alpha=0.7
            )

    # Set ticks
    ax.set_xticks(np.arange(len(timepoints)))
    ax.set_yticks(np.arange(len(features)))

    # Format timepoint labels: datetimes as clock time, everything else via str()
    timepoint_labels = [
        tp.strftime('%H:%M') if isinstance(tp, datetime) else str(tp)
        for tp in timepoints
    ]

    ax.set_xticklabels(timepoint_labels, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(
        [name.replace('_', ' ').title() for name in features],
        fontsize=10
    )

    # Annotate significant divergences (|z| >= 3). argwhere walks the cells
    # in row-major order, same as the nested loops it replaces; NaN cells
    # compare False and are skipped.
    with np.errstate(invalid='ignore'):
        significant = np.argwhere(np.abs(z_scores) >= 3.0)
    for row, col in significant:
        z = z_scores[row, col]
        # White text stays readable on the darkest ends of the colormap.
        text_color = 'white' if abs(z) >= 4.0 else 'black'
        ax.text(
            col, row, f'{z:.1f}',
            ha='center', va='center',
            color=text_color, fontsize=8, fontweight='bold'
        )

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _overlay_kde(ax, values: np.ndarray, color: str) -> bool:
    """Draw a Gaussian KDE curve for ``values`` on ``ax``; return True if drawn.

    The curve is optional decoration on top of the histogram, so every reason
    it cannot be computed results in it being skipped rather than an error.
    """
    try:
        from scipy import stats
    except ImportError:
        return False  # Skip KDE if scipy not available

    try:
        kde = stats.gaussian_kde(values)
    except (ValueError, np.linalg.LinAlgError):
        # gaussian_kde rejects a single sample (ValueError) and zero-variance
        # data such as a constant feature (LinAlgError).
        return False

    grid = np.linspace(values.min(), values.max(), 100)
    ax.plot(grid, kde(grid), color=color, linewidth=2.5, alpha=0.8)
    return True


def plot_feature_distributions(
    eval_dist: Union[np.ndarray, List[float]],
    prod_dist: Union[np.ndarray, List[float]],
    feature_name: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Density",
    figsize: Tuple[float, float] = (10, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
    bins: int = 30,
    alpha: float = 0.6,
) -> Figure:
    """Plot overlaid distributions for evaluation vs production.

    Creates overlaid histograms and KDE plots comparing a feature's distribution
    in evaluation and production contexts. Helps visualize behavioral divergence.
    The KDE curve of a distribution is omitted when scipy is not installed or
    when the data cannot support one (a single value, or all values equal);
    the histogram and mean line are still drawn.

    Args:
        eval_dist: Feature values from evaluation context
        prod_dist: Feature values from production context
        feature_name: Name of the feature being plotted
        title: Plot title (auto-generated if None)
        xlabel: X-axis label (auto-generated if None)
        ylabel: Y-axis label
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and the saved file
        style: Matplotlib style to use
        show: Whether to display the plot
        bins: Number of histogram bins
        alpha: Transparency for histogram bars

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If either distribution is empty.

    Examples:
        >>> eval_dist = np.random.normal(100, 15, 200)
        >>> prod_dist = np.random.normal(120, 20, 200)
        >>> fig = plot_feature_distributions(
        ...     eval_dist, prod_dist, "response_length",
        ...     save_path="response_length_comparison.png"
        ... )
    """
    # Normalise to flat float arrays up front; validation happens before the
    # global style is changed or a figure is created.
    eval_dist = np.asarray(eval_dist, dtype=float).ravel()
    prod_dist = np.asarray(prod_dist, dtype=float).ravel()

    if eval_dist.size == 0 or prod_dist.size == 0:
        raise ValueError("Both distributions must contain data")

    # Auto-generate labels if not provided
    pretty_name = feature_name.replace('_', ' ').title()
    if title is None:
        title = f"{pretty_name} Distribution Comparison"
    if xlabel is None:
        xlabel = pretty_name

    _setup_plot_style(style)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Compute statistics
    eval_mean, eval_std = np.mean(eval_dist), np.std(eval_dist)
    prod_mean, prod_std = np.mean(prod_dist), np.std(prod_dist)

    # Histograms
    ax.hist(
        eval_dist, bins=bins, alpha=alpha, color='#2E86AB',
        label=f'Evaluation (μ={eval_mean:.2f}, σ={eval_std:.2f})',
        density=True, edgecolor='black', linewidth=0.5
    )
    ax.hist(
        prod_dist, bins=bins, alpha=alpha, color='#E63946',
        label=f'Production (μ={prod_mean:.2f}, σ={prod_std:.2f})',
        density=True, edgecolor='black', linewidth=0.5
    )

    # KDE plots for smooth overlay; each one is attempted independently so a
    # degenerate distribution does not suppress the other curve.
    _overlay_kde(ax, eval_dist, '#2E86AB')
    _overlay_kde(ax, prod_dist, '#E63946')

    # Mean lines
    ax.axvline(
        eval_mean, color='#2E86AB', linestyle='--',
        linewidth=2, alpha=0.7, label=f'Eval Mean ({eval_mean:.2f})'
    )
    ax.axvline(
        prod_mean, color='#E63946', linestyle='--',
        linewidth=2, alpha=0.7, label=f'Prod Mean ({prod_mean:.2f})'
    )

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.legend(loc='best', framealpha=0.9, fontsize=9)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|