ai-metacognition-toolkit 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. ai_metacognition/__init__.py +123 -0
  2. ai_metacognition/analyzers/__init__.py +24 -0
  3. ai_metacognition/analyzers/base.py +39 -0
  4. ai_metacognition/analyzers/counterfactual_cot.py +579 -0
  5. ai_metacognition/analyzers/model_api.py +39 -0
  6. ai_metacognition/detectors/__init__.py +40 -0
  7. ai_metacognition/detectors/base.py +42 -0
  8. ai_metacognition/detectors/observer_effect.py +651 -0
  9. ai_metacognition/detectors/sandbagging_detector.py +1438 -0
  10. ai_metacognition/detectors/situational_awareness.py +526 -0
  11. ai_metacognition/integrations/__init__.py +16 -0
  12. ai_metacognition/integrations/anthropic_api.py +230 -0
  13. ai_metacognition/integrations/base.py +113 -0
  14. ai_metacognition/integrations/openai_api.py +300 -0
  15. ai_metacognition/probing/__init__.py +24 -0
  16. ai_metacognition/probing/extraction.py +176 -0
  17. ai_metacognition/probing/hooks.py +200 -0
  18. ai_metacognition/probing/probes.py +186 -0
  19. ai_metacognition/probing/vectors.py +133 -0
  20. ai_metacognition/utils/__init__.py +48 -0
  21. ai_metacognition/utils/feature_extraction.py +534 -0
  22. ai_metacognition/utils/statistical_tests.py +317 -0
  23. ai_metacognition/utils/text_processing.py +98 -0
  24. ai_metacognition/visualizations/__init__.py +22 -0
  25. ai_metacognition/visualizations/plotting.py +523 -0
  26. ai_metacognition_toolkit-0.3.0.dist-info/METADATA +621 -0
  27. ai_metacognition_toolkit-0.3.0.dist-info/RECORD +30 -0
  28. ai_metacognition_toolkit-0.3.0.dist-info/WHEEL +5 -0
  29. ai_metacognition_toolkit-0.3.0.dist-info/licenses/LICENSE +21 -0
  30. ai_metacognition_toolkit-0.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,523 @@
1
+ """Publication-ready plotting functions for metacognition analysis.
2
+
3
+ This module provides high-quality visualization functions for:
4
+ - Awareness probability time series with confidence bands
5
+ - Causal attribution bar charts
6
+ - Feature divergence heatmaps
7
+ - Distribution comparison plots
8
+
9
+ All plots are customizable and can be saved at publication quality.
10
+ """
11
+
12
+ from typing import Dict, List, Optional, Tuple, Union
13
+ from datetime import datetime
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ import matplotlib.dates as mdates
17
+ from matplotlib.figure import Figure
18
+ from matplotlib.axes import Axes
19
+ import seaborn as sns
20
+
21
+
22
# Module-wide defaults for publication-quality output. The plotting functions
# below use these as parameter defaults or as the style to try first.
DEFAULT_DPI = 300  # default for the ``dpi`` parameter of the plot functions
DEFAULT_FIGSIZE = (10, 6)  # (width, height) in inches
DEFAULT_STYLE = "seaborn-v0_8-darkgrid"  # style applied when none is requested
26
+
27
+
28
def _setup_plot_style(style: Optional[str] = None) -> None:
    """Activate the matplotlib style used by the plotting helpers.

    Args:
        style: Name of the matplotlib style to apply. When it is None (or
            otherwise falsy), ``DEFAULT_STYLE`` is applied instead.
    """
    requested = style or DEFAULT_STYLE
    try:
        plt.style.use(requested)
    except OSError:
        # The requested style could not be loaded: start from matplotlib's
        # stock style and layer a plain white theme on top of it.
        fallback_params = {
            'figure.facecolor': 'white',
            'axes.facecolor': 'white',
            'axes.edgecolor': 'black',
            'grid.alpha': 0.3,
            'font.size': 10,
        }
        plt.style.use("default")
        plt.rcParams.update(fallback_params)
46
+
47
+
48
def plot_awareness_over_time(
    timestamps: List[Union[datetime, float]],
    probabilities: List[float],
    confidence_intervals: Optional[List[Tuple[float, float]]] = None,
    title: str = "Situational Awareness Over Time",
    xlabel: str = "Time",
    ylabel: str = "Awareness Probability",
    threshold: Optional[float] = None,
    figsize: Tuple[float, float] = DEFAULT_FIGSIZE,
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
) -> Figure:
    """Plot time series of situational awareness probabilities.

    Creates a line plot showing awareness probability over time with optional
    confidence intervals shown as shaded bands. Optionally displays a threshold
    line for decision-making.

    Args:
        timestamps: List of timestamps (datetime objects or numeric values).
            The x-axis is formatted as clock time (``%H:%M``) when the first
            element is a ``datetime``. An empty list produces an empty plot.
        probabilities: Awareness probability at each timestamp
        confidence_intervals: Optional list of (lower, upper) confidence bounds
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        threshold: Optional decision threshold to display as horizontal line
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and for the saved file
        style: Matplotlib style to use
        show: Whether to display the plot

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``timestamps`` and ``probabilities`` differ in length,
            or if non-empty ``confidence_intervals`` do not match
            ``probabilities`` in length.

    Examples:
        >>> from datetime import datetime, timedelta
        >>> base = datetime.now()
        >>> timestamps = [base + timedelta(hours=i) for i in range(24)]
        >>> probabilities = [0.1 + 0.03*i for i in range(24)]
        >>> confidence_intervals = [(p-0.05, p+0.05) for p in probabilities]
        >>> fig = plot_awareness_over_time(
        ...     timestamps, probabilities, confidence_intervals,
        ...     threshold=0.5, save_path="awareness_trend.png"
        ... )
    """
    _setup_plot_style(style)

    if len(timestamps) != len(probabilities):
        raise ValueError(
            f"Timestamps ({len(timestamps)}) and probabilities ({len(probabilities)}) "
            "must have the same length"
        )

    if confidence_intervals and len(confidence_intervals) != len(probabilities):
        raise ValueError(
            f"Confidence intervals ({len(confidence_intervals)}) must match "
            f"probabilities ({len(probabilities)})"
        )

    # Only peek at the first element when there is one, so an empty series
    # yields an empty plot instead of an IndexError. This is decided before the
    # figure is created so no half-built figure is left open on failure.
    is_datetime = len(timestamps) > 0 and isinstance(timestamps[0], datetime)

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Main probability line
    ax.plot(
        timestamps, probabilities,
        linewidth=2, color='#2E86AB', label='Awareness Probability',
        marker='o', markersize=4, markerfacecolor='white', markeredgewidth=1.5
    )

    # Confidence intervals as a shaded band around the line
    if confidence_intervals:
        lower_bounds = [ci[0] for ci in confidence_intervals]
        upper_bounds = [ci[1] for ci in confidence_intervals]
        ax.fill_between(
            timestamps, lower_bounds, upper_bounds,
            alpha=0.3, color='#2E86AB', label='95% Confidence Interval'
        )

    # Threshold line
    if threshold is not None:
        ax.axhline(
            y=threshold, color='#E63946', linestyle='--',
            linewidth=2, alpha=0.7, label=f'Threshold ({threshold:.2f})'
        )

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.set_ylim(-0.05, 1.05)  # small margin around the [0, 1] probability range
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.legend(loc='best', framealpha=0.9, fontsize=10)

    # Format x-axis for datetime
    if is_datetime:
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        # Rotate the labels of this axes explicitly rather than relying on
        # pyplot's notion of the "current" axes.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
161
+
162
+
163
def plot_causal_attribution(
    intervention_results: List[Dict],
    title: str = "Causal Attribution Analysis",
    xlabel: str = "Divergence Score",
    ylabel: str = "Intervention Type",
    figsize: Tuple[float, float] = (10, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
    highlight_threshold: Optional[float] = None,
) -> Figure:
    """Plot causal attribution from counterfactual interventions.

    Creates a horizontal bar chart showing the causal importance of different
    interventions, sorted by divergence score. Optionally highlights interventions
    above a significance threshold.

    Args:
        intervention_results: List of dicts with 'intervention_type' and
            'divergence' keys. A missing 'divergence' is treated as 0, both
            for ordering and for the plotted bar.
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and for the saved file
        style: Matplotlib style to use
        show: Whether to display the plot
        highlight_threshold: If provided, highlight bars at or above this
            value and draw a vertical marker line at it

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``intervention_results`` is empty.
        KeyError: If an entry has no 'intervention_type' key.

    Examples:
        >>> results = [
        ...     {"intervention_type": "remove_eval_awareness", "divergence": 0.342},
        ...     {"intervention_type": "amplify_eval_awareness", "divergence": 0.156},
        ...     {"intervention_type": "remove_goal_reasoning", "divergence": 0.089},
        ... ]
        >>> fig = plot_causal_attribution(
        ...     results, highlight_threshold=0.2, save_path="causal_attribution.png"
        ... )
    """
    _setup_plot_style(style)

    if not intervention_results:
        raise ValueError("intervention_results cannot be empty")

    def _divergence_of(result: Dict) -> float:
        # Single accessor so sorting and plotting agree on how a missing
        # 'divergence' entry is handled.
        return result.get('divergence', 0)

    # Sort by divergence (descending)
    sorted_results = sorted(intervention_results, key=_divergence_of, reverse=True)

    intervention_types = [
        r['intervention_type'].replace('_', ' ').title()
        for r in sorted_results
    ]
    divergences = [_divergence_of(r) for r in sorted_results]

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Color bars based on threshold
    if highlight_threshold is not None:
        colors = [
            '#E63946' if d >= highlight_threshold else '#2E86AB'
            for d in divergences
        ]
    else:
        colors = ['#2E86AB'] * len(divergences)

    # Create horizontal bar chart
    y_pos = np.arange(len(intervention_types))
    bars = ax.barh(y_pos, divergences, color=colors, alpha=0.8, edgecolor='black', linewidth=1.2)

    # Gap between a bar's end and its value label: 2% of the largest magnitude.
    # Computed once, and from absolute values so it never turns negative when
    # no divergence is positive.
    label_gap = max(abs(d) for d in divergences) * 0.02

    # Add value labels just beyond the end of each bar
    for bar, div in zip(bars, divergences):
        width = bar.get_width()
        if width >= 0:
            label_x, align = width + label_gap, 'left'
        else:
            # Negative bars extend to the left, so the label goes on that side.
            label_x, align = width - label_gap, 'right'
        ax.text(
            label_x,
            bar.get_y() + bar.get_height() / 2,
            f'{div:.3f}',
            ha=align, va='center', fontsize=10, fontweight='bold'
        )

    # Threshold line
    if highlight_threshold is not None:
        ax.axvline(
            x=highlight_threshold, color='#E63946', linestyle='--',
            linewidth=2, alpha=0.7, label=f'Significance Threshold ({highlight_threshold:.2f})'
        )
        ax.legend(loc='best', framealpha=0.9, fontsize=10)

    # Formatting
    ax.set_yticks(y_pos)
    ax.set_yticklabels(intervention_types, fontsize=10)
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, axis='x', linestyle='-', linewidth=0.5)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
274
+
275
+
276
def plot_divergence_heatmap(
    features: List[str],
    timepoints: List[Union[datetime, str, float]],
    z_scores: np.ndarray,
    title: str = "Feature Divergence Over Time",
    xlabel: str = "Time",
    ylabel: str = "Feature",
    figsize: Tuple[float, float] = (12, 8),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
    cmap: str = "RdYlGn_r",
    vmin: float = -5.0,
    vmax: float = 5.0,
) -> Figure:
    """Plot heatmap of feature divergence over time.

    Creates a color-coded heatmap showing z-scores for multiple features across
    timepoints. Uses diverging colormap with red indicating high divergence
    (potential issues) and green indicating normal behavior. Cells whose
    absolute z-score is at least 3.0 are annotated with their value.

    Args:
        features: List of feature names
        timepoints: List of timepoint labels (datetime, strings, or numeric);
            datetimes are rendered as ``%H:%M``, everything else via ``str``
        z_scores: 2D array (or nested sequence) of z-scores with shape
            (n_features, n_timepoints)
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and for the saved file
        style: Matplotlib style to use
        show: Whether to display the plot
        cmap: Colormap name (diverging recommended)
        vmin: Minimum z-score for colormap
        vmax: Maximum z-score for colormap

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If ``z_scores`` is not 2-dimensional, or its shape does
            not match ``(len(features), len(timepoints))``.

    Examples:
        >>> features = ["response_length", "refusal_rate", "hedging_ratio"]
        >>> timepoints = ["00:00", "01:00", "02:00", "03:00"]
        >>> z_scores = np.array([
        ...     [0.5, 1.2, 2.8, 3.5],   # response_length
        ...     [-0.2, 0.1, 0.3, 4.2],  # refusal_rate
        ...     [0.1, 0.3, 0.2, 0.4],   # hedging_ratio
        ... ])
        >>> fig = plot_divergence_heatmap(
        ...     features, timepoints, z_scores, save_path="divergence_heatmap.png"
        ... )
    """
    _setup_plot_style(style)

    # Accept nested lists as well as arrays; asanyarray keeps ndarray
    # subclasses (e.g. masked arrays) intact.
    z_scores = np.asanyarray(z_scores)

    if z_scores.ndim != 2:
        raise ValueError(
            f"z_scores must be 2-dimensional, got {z_scores.ndim} dimension(s)"
        )

    if z_scores.shape[0] != len(features):
        raise ValueError(
            f"z_scores rows ({z_scores.shape[0]}) must match features ({len(features)})"
        )

    if z_scores.shape[1] != len(timepoints):
        raise ValueError(
            f"z_scores columns ({z_scores.shape[1]}) must match timepoints ({len(timepoints)})"
        )

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Create heatmap
    im = ax.imshow(
        z_scores, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax,
        interpolation='nearest'
    )

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label('Z-Score (σ)', rotation=270, labelpad=20, fontsize=12, fontweight='bold')

    # Add severity threshold lines on colorbar. A level outside [vmin, vmax]
    # has no position on the bar, so it is skipped rather than drawn.
    severity_levels = ((2.0, 'yellow'), (3.0, 'orange'), (4.0, 'red'))
    for level, line_color in severity_levels:
        if vmin <= level <= vmax:
            cbar.ax.axhline(
                y=level, color=line_color, linestyle='--', linewidth=1.5, alpha=0.7
            )

    # Set ticks
    ax.set_xticks(np.arange(len(timepoints)))
    ax.set_yticks(np.arange(len(features)))

    # Format timepoint labels
    timepoint_labels = [
        tp.strftime('%H:%M') if isinstance(tp, datetime) else str(tp)
        for tp in timepoints
    ]

    ax.set_xticklabels(timepoint_labels, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(
        [f.replace('_', ' ').title() for f in features],
        fontsize=10
    )

    # Add text annotations for significant divergences (|z| >= 3). argwhere
    # yields (row, column) pairs in row-major order.
    for row, col in np.argwhere(np.abs(z_scores) >= 3.0):
        z = z_scores[row, col]
        # White text on the most saturated cells, black elsewhere.
        text_color = 'white' if abs(z) >= 4.0 else 'black'
        ax.text(
            col, row, f'{z:.1f}',
            ha='center', va='center',
            color=text_color, fontsize=8, fontweight='bold'
        )

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig
402
+
403
+
404
def _overlay_kde(ax: Axes, values: np.ndarray, color: str) -> None:
    """Draw a Gaussian KDE curve for ``values`` on ``ax``, or skip it quietly.

    The curve is optional decoration: it is skipped when scipy is not
    installed or when a density cannot be estimated from ``values``.
    """
    try:
        from scipy import stats
    except ImportError:
        return  # Skip KDE if scipy not available

    try:
        kde = stats.gaussian_kde(values)
    except (ValueError, np.linalg.LinAlgError):
        # gaussian_kde rejects data it cannot estimate a density from (for
        # example a single sample or constant values); the histogram and
        # mean line are still drawn by the caller.
        return

    grid = np.linspace(values.min(), values.max(), 100)
    ax.plot(grid, kde(grid), color=color, linewidth=2.5, alpha=0.8)


def plot_feature_distributions(
    eval_dist: Union[np.ndarray, List[float]],
    prod_dist: Union[np.ndarray, List[float]],
    feature_name: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: str = "Density",
    figsize: Tuple[float, float] = (10, 6),
    save_path: Optional[str] = None,
    dpi: int = DEFAULT_DPI,
    style: Optional[str] = None,
    show: bool = True,
    bins: int = 30,
    alpha: float = 0.6,
) -> Figure:
    """Plot overlaid distributions for evaluation vs production.

    Creates overlaid histograms and KDE plots comparing a feature's distribution
    in evaluation and production contexts. Helps visualize behavioral divergence.
    The KDE curve of a distribution is omitted when scipy is unavailable or
    when no density can be estimated for it (e.g. a single sample or constant
    values); histograms and mean lines are always drawn.

    Args:
        eval_dist: Feature values from evaluation context
        prod_dist: Feature values from production context
        feature_name: Name of the feature being plotted
        title: Plot title (auto-generated if None)
        xlabel: X-axis label (auto-generated if None)
        ylabel: Y-axis label
        figsize: Figure size as (width, height) in inches
        save_path: If provided, save figure to this path
        dpi: Resolution for the figure and for the saved file
        style: Matplotlib style to use
        show: Whether to display the plot
        bins: Number of histogram bins
        alpha: Transparency for histogram bars

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If either distribution contains no data.

    Examples:
        >>> eval_dist = np.random.normal(100, 15, 200)
        >>> prod_dist = np.random.normal(120, 20, 200)
        >>> fig = plot_feature_distributions(
        ...     eval_dist, prod_dist, "response_length",
        ...     save_path="response_length_comparison.png"
        ... )
    """
    _setup_plot_style(style)

    eval_dist = np.asarray(eval_dist)
    prod_dist = np.asarray(prod_dist)

    if eval_dist.size == 0 or prod_dist.size == 0:
        raise ValueError("Both distributions must contain data")

    # Auto-generate labels if not provided
    pretty_name = feature_name.replace('_', ' ').title()
    if title is None:
        title = f"{pretty_name} Distribution Comparison"
    if xlabel is None:
        xlabel = pretty_name

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Compute statistics
    eval_mean, eval_std = np.mean(eval_dist), np.std(eval_dist)
    prod_mean, prod_std = np.mean(prod_dist), np.std(prod_dist)

    # Histograms (density-normalised so the two samples are comparable)
    ax.hist(
        eval_dist, bins=bins, alpha=alpha, color='#2E86AB',
        label=f'Evaluation (μ={eval_mean:.2f}, σ={eval_std:.2f})',
        density=True, edgecolor='black', linewidth=0.5
    )
    ax.hist(
        prod_dist, bins=bins, alpha=alpha, color='#E63946',
        label=f'Production (μ={prod_mean:.2f}, σ={prod_std:.2f})',
        density=True, edgecolor='black', linewidth=0.5
    )

    # KDE plots for smooth overlay; each is attempted independently so a
    # failure for one distribution does not suppress the other.
    _overlay_kde(ax, eval_dist, '#2E86AB')
    _overlay_kde(ax, prod_dist, '#E63946')

    # Mean lines
    ax.axvline(
        eval_mean, color='#2E86AB', linestyle='--',
        linewidth=2, alpha=0.7, label=f'Eval Mean ({eval_mean:.2f})'
    )
    ax.axvline(
        prod_mean, color='#E63946', linestyle='--',
        linewidth=2, alpha=0.7, label=f'Prod Mean ({prod_mean:.2f})'
    )

    # Formatting
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
    ax.legend(loc='best', framealpha=0.9, fontsize=9)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()

    return fig