datatypical 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
datatypical_viz.py ADDED
@@ -0,0 +1,912 @@
1
+ """
2
+ DataTypical v0.7 - Visualization Module
3
+ ========================================
4
+
5
+ Publication-quality visualizations for dual-perspective analysis:
6
+ 1. significance_plot: Dual-perspective scatter (hero visualization)
7
+ - Automatic discrete/continuous color detection
8
+ - Binary: purple (low) and green (high)
9
+ - Discrete (3-12 categories): viridis palette + legend
10
+ - Continuous (>12 values): viridis colormap + colorbar
11
+ 2. heatmap: Feature attribution heatmaps
12
+ 3. profile_plot: Feature importance profiles for individual samples
13
+
14
+ Design specifications:
15
+ - Viridis colormap (default)
16
+ - figsize (6,5) for scatter and heatmaps
17
+ - figsize (12,5) for profile plots
18
+ - Font size 12 for ticks, 14 for axis labels
19
+ - Clean, professional style matching existing software
20
+
21
+ Author: Amanda S. Barnard
22
+ Date: January 2026
23
+ """
24
+
25
# Third-party dependencies: numpy/pandas for data handling, matplotlib/seaborn
# for rendering. `mpatches` and `LinearSegmentedColormap` are imported for
# callers/extensions even though this module does not reference them directly.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from typing import Optional, Dict, List, Tuple, Union

# Configure plotting defaults
# Global matplotlib defaults for the whole module: 12 pt ticks/legend,
# 14 pt axis labels (matches the design spec in the module docstring).
# NOTE(review): these mutate process-wide rcParams on import — any figure
# created after importing this module inherits them.
plt.rcParams['font.size'] = 12
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 12
39
+
40
+
41
+ # ============================================================================
42
+ # HELPER function
43
+ # ============================================================================
44
+
45
def get_top_sample(
    results: pd.DataFrame,
    rank_column: str,
    n: int = 1,
    mode: str = 'max'
) -> Union[int, List[int], None]:
    """
    Safely get top sample(s) from results, handling missing formative data.

    Parameters
    ----------
    results : pd.DataFrame
        Results from DataTypical.fit_transform()
    rank_column : str
        Column to rank by (e.g., 'archetypal_rank', 'archetypal_shapley_rank')
    n : int
        Number of top samples to return (default: 1)
    mode : str
        'max' for highest values, 'min' for lowest (default: 'max')

    Returns
    -------
    sample_idx : int, list, or None
        Top sample index/indices, or None if data not available
        Returns single int if n=1, list if n>1

    Examples
    --------
    >>> # Get top archetypal sample
    >>> top_idx = get_top_sample(results, 'archetypal_rank')
    >>>
    >>> # Get top formative sample (handles NaN gracefully)
    >>> top_formative = get_top_sample(results, 'archetypal_shapley_rank')
    >>> if top_formative is not None:
    ...     ax = profile_plot(dt, top_formative, order='global')
    >>>
    >>> # Get top 5 samples
    >>> top_5 = get_top_sample(results, 'prototypical_rank', n=5)
    """
    if rank_column not in results.columns:
        # FIX: repaired mojibake ("âš ") to the intended warning sign,
        # consistent with the messages printed elsewhere in this module.
        print(f"⚠ Column '{rank_column}' not found in results")
        return None

    # Check if column is all NaN (formative data not available)
    if results[rank_column].isna().all():
        print(f"⚠ Column '{rank_column}' has no data (likely fast_mode=True)")

        # Provide helpful message based on column name
        if 'shapley_rank' in rank_column:
            print(f"  Formative data requires: DataTypical(shapley_mode=True, fast_mode=False)")
        if 'stereotypical' in rank_column:
            print(f"  Also requires: stereotype_column='<your_column>'")

        return None

    # Get top sample(s): idxmax/idxmin for a single winner, nlargest/nsmallest
    # (which preserve rank order) when n > 1.
    if mode == 'max':
        if n == 1:
            return results[rank_column].idxmax()
        else:
            return results.nlargest(n, rank_column).index.tolist()
    else:  # min
        if n == 1:
            return results[rank_column].idxmin()
        else:
            return results.nsmallest(n, rank_column).index.tolist()
111
+
112
+
113
+ # ============================================================================
114
+ # HERO VISUALIZATION: Dual-Perspective Scatter
115
+ # ============================================================================
116
+
117
def significance_plot(
    results: pd.DataFrame,
    significance: str = 'archetypal',
    color_by: Optional[str] = None,
    size_by: Optional[str] = None,
    labels: Optional[pd.Series] = None,
    label_top: int = 0,
    quadrant_lines: bool = True,
    quadrant_threshold: Tuple[float, float] = (0.5, 0.5),
    figsize: Tuple[int, int] = (6, 5),
    cmap: str = 'viridis',
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Create dual-perspective significance scatter plot (hero visualization).

    This plot reveals the relationship between actual significance (samples that
    ARE archetypal/prototypical/stereotypical) and formative significance
    (samples that CREATE the structure).

    Parameters
    ----------
    results : pd.DataFrame
        Results from DataTypical.fit_transform() with shapley_mode=True
    significance : str
        Which significance to plot: 'archetypal', 'prototypical', or 'stereotypical'
        Automatically uses {significance}_rank (x-axis) and {significance}_shapley_rank (y-axis)
    color_by : str, optional
        Column name to color points by. Automatically detects:
        - Continuous variables (>12 unique values): colored with viridis colormap + colorbar
        - Binary variables (2 unique values): purple (low) and green (high) + legend
        - Discrete variables (3-5 unique): discrete viridis colors + legend
        - Discrete variables (6-12 unique): discrete colors + different markers + legend
        - Non-numeric variables: discrete viridis colors + legend
    size_by : str, optional
        Column name to size points by
    labels : pd.Series, optional
        Labels for points (e.g., compound IDs)
    label_top : int
        Number of top points to label (by actual+formative)
    quadrant_lines : bool
        Draw quadrant division lines
    quadrant_threshold : tuple of float
        (x, y) thresholds for quadrant lines
    figsize : tuple of int
        Figure size (width, height)
    cmap : str
        Colormap name (default: 'viridis')
    title : str, optional
        Plot title
    ax : plt.Axes, optional
        Existing axes to plot on
    **kwargs
        Additional arguments passed to scatter()

    Returns
    -------
    ax : plt.Axes
        Matplotlib axes object

    Examples
    --------
    >>> dt = DataTypical(shapley_mode=True)
    >>> results = dt.fit_transform(data)
    >>>
    >>> # Basic archetypal plot
    >>> ax = significance_plot(results, significance='archetypal')
    >>>
    >>> # Prototypical with continuous color
    >>> ax = significance_plot(
    ...     results,
    ...     significance='prototypical',
    ...     color_by='solubility',  # continuous variable
    ...     label_top=5
    ... )
    >>>
    >>> # Archetypal with binary color (auto-detected)
    >>> ax = significance_plot(
    ...     results,
    ...     significance='archetypal',
    ...     color_by='treatment'  # binary: 0/1 or control/treatment
    ... )
    >>>
    >>> # Stereotypical with discrete multi-class color
    >>> ax = significance_plot(
    ...     results,
    ...     significance='stereotypical',
    ...     color_by='cell_type'  # 4 categories with different colors
    ... )
    """

    # Validate significance parameter
    valid_significance = ['archetypal', 'prototypical', 'stereotypical']
    if significance not in valid_significance:
        raise ValueError(f"significance must be one of {valid_significance}, got '{significance}'")

    # Auto-determine column names
    actual_col = f'{significance}_rank'
    formative_col = f'{significance}_shapley_rank'

    if actual_col not in results.columns:
        raise ValueError(f"Column '{actual_col}' not in results")
    if formative_col not in results.columns:
        raise ValueError(f"Column '{formative_col}' not in results")

    # Check if formative data is available (could be None in fast_mode)
    if results[formative_col].isna().all():
        # Print informative message and skip plot
        print(f"\n⚠ Skipping significance plot:")
        print(f"  Formative data ('{formative_col}') not available (fast_mode=True)")
        print(f"  This plot requires fast_mode=False to compute formative Shapley values")

        # Return None or empty axes depending on whether axes was provided
        if ax is None:
            # Create empty figure with message
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, 'Formative data not available\n(fast_mode=True)',
                    ha='center', va='center', fontsize=14, color='gray')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_xlabel(f'{significance.capitalize()} Rank (Actual)', fontsize=14)
            ax.set_ylabel(f'{significance.capitalize()} Rank (Formative)', fontsize=14)
            if title:
                ax.set_title(title, fontsize=14)
            return ax
        return ax

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Prepare size values (needed for all plot types)
    if size_by is not None:
        if size_by not in results.columns:
            raise ValueError(f"size_by column '{size_by}' not found in results")
        s_values = results[size_by]
        # Normalize to reasonable range (10-200)
        s_min, s_max = s_values.min(), s_values.max()
        if s_max > s_min:
            s_normalized = 10 + 190 * (s_values - s_min) / (s_max - s_min)
        else:
            # Degenerate case: all sizes equal -> constant marker size
            s_normalized = 50
    else:
        s_normalized = 50

    # Handle coloring logic with automatic discrete/continuous detection
    if color_by is not None:
        if color_by not in results.columns:
            raise ValueError(f"color_by column '{color_by}' not found in results")

        c_values = results[color_by]

        # Detect variable type: discrete vs continuous
        is_numeric = pd.api.types.is_numeric_dtype(c_values)
        n_unique = c_values.nunique()

        # Determine if discrete or continuous
        if not is_numeric:
            # Non-numeric -> discrete
            is_discrete = True
            use_markers = False
        elif n_unique < 6:
            # Numeric with <6 values -> discrete
            is_discrete = True
            use_markers = False
        elif 6 <= n_unique <= 12:
            # Numeric with 6-12 values -> discrete with markers
            is_discrete = True
            use_markers = True
        else:
            # Numeric with >12 values -> continuous
            is_discrete = False
            use_markers = False

        if is_discrete:
            # DISCRETE COLORING: Use legend instead of colorbar

            # Get unique values in sorted order for consistent mapping
            if is_numeric:
                unique_values = sorted(c_values.dropna().unique())
            else:
                unique_values = sorted(c_values.dropna().unique(), key=str)

            n_categories = len(unique_values)

            if n_categories == 0:
                raise ValueError(f"color_by column '{color_by}' has no valid values")

            # Create color mapping
            # FIX: plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed
            # in 3.9; plt.get_cmap is the supported equivalent.
            if n_categories == 2:
                # BINARY: purple (viridis[0.0]) for lower, green (viridis[0.66]) for higher
                viridis_cmap = plt.get_cmap('viridis')
                color_map = {
                    unique_values[0]: viridis_cmap(0.0),   # purple for lower value
                    unique_values[1]: viridis_cmap(0.66)   # green for higher value
                }
            else:
                # MULTI-CLASS: discrete viridis palette (lut = n_categories)
                viridis_cmap = plt.get_cmap('viridis', n_categories)
                color_map = {val: viridis_cmap(i) for i, val in enumerate(unique_values)}

            # Create marker mapping if needed
            if use_markers:
                # 6-12 categories: use different markers for each category
                marker_list = ['o', 's', '^', 'v', 'D', 'p', '*', 'h', 'X', 'P', '<', '>']
                marker_map = {val: marker_list[i % len(marker_list)]
                              for i, val in enumerate(unique_values)}

            # Plot each category separately for legend
            # NOTE: rows with NaN in color_by match no category and are not drawn.
            for category in unique_values:
                mask = c_values == category

                # Get marker for this category
                if use_markers:
                    marker = marker_map[category]
                else:
                    marker = 'o'

                # Get size values for this subset
                if isinstance(s_normalized, (int, float)):
                    s_subset = s_normalized
                else:
                    s_subset = s_normalized[mask]

                # Plot this category
                ax.scatter(
                    results.loc[mask, actual_col],
                    results.loc[mask, formative_col],
                    c=[color_map[category]],
                    s=s_subset,
                    marker=marker,
                    edgecolors='black',
                    linewidth=0.5,
                    label=str(category),
                    **kwargs
                )

            # Add legend
            ax.legend(
                title=color_by,
                fontsize=12,
                title_fontsize=12,
                loc='center left',
                bbox_to_anchor=(1.0, 0.5),
                frameon=False
            )

        else:
            # CONTINUOUS COLORING: Use colorbar (existing behavior)
            scatter = ax.scatter(
                results[actual_col],
                results[formative_col],
                c=c_values,
                s=s_normalized,
                cmap=cmap,
                edgecolors='black',
                linewidth=0.5,
                **kwargs
            )
            # Add colorbar
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(color_by, fontsize=14)
            cbar.ax.tick_params(labelsize=12)

    else:
        # NO COLORING: Use default steelblue
        ax.scatter(
            results[actual_col],
            results[formative_col],
            c='steelblue',
            s=s_normalized,
            alpha=0.7,
            edgecolors='black',
            linewidth=0.5,
            **kwargs
        )

    # Add quadrant lines if requested
    if quadrant_lines:
        x_thresh, y_thresh = quadrant_threshold
        ax.axvline(x_thresh, color='gray', linestyle='--', alpha=0.5, linewidth=1)
        ax.axhline(y_thresh, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    # Label top points if requested
    if label_top > 0 and labels is not None:
        # Rank by sum of actual + formative
        combined_rank = results[actual_col] + results[formative_col]
        top_indices = combined_rank.nlargest(label_top).index

        for idx in top_indices:
            ax.annotate(
                labels[idx],
                xy=(results.loc[idx, actual_col], results.loc[idx, formative_col]),
                xytext=(5, 5),
                textcoords='offset points',
                fontsize=10,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7, edgecolor='gray'),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0', color='gray', lw=0.5)
            )

    # Labels and title
    ax.set_xlabel(f'{significance.capitalize()} Rank (Actual)', fontsize=14)
    ax.set_ylabel(f'{significance.capitalize()} Rank (Formative)', fontsize=14)

    if title:
        ax.set_title(title, fontsize=14)
    else:
        ax.set_title(f'{significance.capitalize()} Dual-Perspective Analysis', fontsize=14)

    # Grid for readability
    ax.grid(alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Ensure limits include full range (ranks are assumed normalized to [0, 1])
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)

    plt.tight_layout()

    return ax
439
+
440
+
441
+ # ============================================================================
442
+ # FEATURE ATTRIBUTION HEATMAP
443
+ # ============================================================================
444
+
445
def heatmap(
    dt_fitted,
    results: Optional[pd.DataFrame] = None,
    samples: Optional[Union[List[int], np.ndarray]] = None,
    significance: str = 'archetypal',
    order: str = 'actual',
    top_n: Optional[int] = None,
    top_features: Optional[int] = None,
    figsize: Tuple[int, int] = (6, 5),
    cmap: str = 'viridis',
    center: Optional[float] = None,
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Create feature attribution heatmap for Shapley explanations.

    Always shows explanations (why samples ARE significant). Features are always
    ordered by global importance (average across all samples). Instances can be
    ordered by actual or formative ranks.

    Parameters
    ----------
    dt_fitted : DataTypical
        Fitted DataTypical instance with shapley_mode=True
    results : pd.DataFrame, optional
        Results dataframe from fit_transform(). Required if samples not specified.
    samples : list or array, optional
        Indices of samples to show. If None, uses top_n samples from results.
    significance : str
        Which significance to visualize: 'archetypal', 'prototypical', 'stereotypical'
    order : str
        How to order instances (rows): 'actual' or 'formative'
        'actual': Order by {significance}_rank (samples that ARE significant)
        'formative': Order by {significance}_shapley_rank (samples that CREATE structure)
    top_n : int, optional
        Number of top samples to show (if samples not specified).
        Defaults to shapley_top_n from fitted model, or 20 if not set.
    top_features : int, optional
        Number of top features to show (by absolute attribution)
    figsize : tuple of int
        Figure size (width, height)
    cmap : str
        Colormap name (default: 'viridis')
    center : float, optional
        Value to center colormap at (for diverging colormaps)
    title : str, optional
        Plot title
    ax : plt.Axes, optional
        Existing axes to plot on
    **kwargs
        Additional arguments passed to seaborn.heatmap()

    Returns
    -------
    ax : plt.Axes
        Matplotlib axes object

    Examples
    --------
    >>> dt = DataTypical(shapley_mode=True)
    >>> results = dt.fit_transform(data)
    >>>
    >>> # Show explanations for top 10 archetypal samples
    >>> ax = heatmap(dt, results, top_n=10, significance='archetypal')
    >>>
    >>> # Order by formative ranks
    >>> ax = heatmap(dt, results, order='formative', significance='archetypal')
    >>>
    >>> # Show top 20 prototypical with only top 15 features
    >>> ax = heatmap(dt, results, significance='prototypical', top_n=20, top_features=15)
    """

    if not dt_fitted.shapley_mode:
        raise RuntimeError("Shapley mode not enabled. Set shapley_mode=True when fitting.")

    # Default top_n to shapley_top_n from fitted model
    if top_n is None:
        if hasattr(dt_fitted, 'shapley_top_n') and dt_fitted.shapley_top_n is not None:
            # Convert fraction to absolute count if needed
            if isinstance(dt_fitted.shapley_top_n, float) and 0 < dt_fitted.shapley_top_n < 1:
                n_samples = len(dt_fitted.train_index_) if dt_fitted.train_index_ is not None else 100
                top_n = max(1, int(dt_fitted.shapley_top_n * n_samples))
            else:
                top_n = int(dt_fitted.shapley_top_n)
        else:
            top_n = 20  # Fallback default

    # Get explanations matrix (always use explanations, not formative)
    if significance == 'archetypal':
        Phi = dt_fitted.Phi_archetypal_explanations_
    elif significance == 'prototypical':
        Phi = dt_fitted.Phi_prototypical_explanations_
    elif significance == 'stereotypical':
        Phi = dt_fitted.Phi_stereotypical_explanations_
    else:
        raise ValueError(f"Unknown significance: {significance}")

    # Check if data is available
    if Phi is None:
        print(f"\n⚠ Skipping {significance} explanations heatmap:")
        print(f"  Explanations data not available")

        if significance == 'stereotypical':
            print(f"  Note: Stereotypical also requires stereotype_column to be set")

        print(f"\n  To enable this plot, refit with:")
        if significance == 'stereotypical':
            print(f"    DataTypical(shapley_mode=True, stereotype_column='<column>')")
        else:
            print(f"    DataTypical(shapley_mode=True)")

        # Return empty axes with message
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)

        ax.text(0.5, 0.5, f'{significance.capitalize()} explanations\nnot available',
                ha='center', va='center', fontsize=14, color='gray')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

        return ax

    # Select samples
    if samples is None:
        if results is None:
            raise ValueError("Must provide either 'samples' or 'results' dataframe")

        # Determine rank column based on order
        if order == 'actual':
            rank_col = f"{significance}_rank"
        elif order == 'formative':
            rank_col = f"{significance}_shapley_rank"
        else:
            raise ValueError(f"order must be 'actual' or 'formative', got '{order}'")

        # FIX: validate the column exists BEFORE accessing it, so a missing
        # column raises the intended RuntimeError rather than a raw KeyError.
        if rank_col not in results.columns:
            raise RuntimeError(f"Cannot find {rank_col} in results")

        if order == 'formative':
            # Check if formative data is available
            if results[rank_col].isna().all():
                print(f"\n⚠ Warning: order='formative' requested but formative data not available")
                # FIX: typo in message ("ordes" -> "order")
                print(f"  Falling back to order='actual'")
                rank_col = f"{significance}_rank"
                order = 'actual'

        # Get DataFrame indices of top samples (already ordered by rank)
        top_samples_df_indices = results.nlargest(top_n, rank_col).index

        # Use get_shapley_explanations to handle shapley_top_n subsampling
        Phi_subset_list = []
        sample_labels = []

        for df_idx in top_samples_df_indices:
            try:
                # Get explanations using the API (handles index mapping internally)
                explanations = dt_fitted.get_shapley_explanations(df_idx)
                shapley_values = explanations[significance]

                # Only include if we have valid data
                if np.any(shapley_values != 0) or not hasattr(dt_fitted, 'shapley_top_n') or dt_fitted.shapley_top_n is None:
                    Phi_subset_list.append(shapley_values)
                    sample_labels.append(str(df_idx))
            except (IndexError, KeyError):
                # This sample doesn't have explanations computed
                pass

        if len(Phi_subset_list) == 0:
            print(f"\n⚠ Error: None of the top {top_n} {significance} instances have explanations")
            print(f"  This can happen when shapley_top_n is too small")
            if ax is None:
                fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, f'No explanations available\nfor top {top_n} instances',
                    ha='center', va='center', fontsize=14, color='gray')
            ax.axis('off')
            return ax

        Phi_subset = np.array(Phi_subset_list)

        # WARNING: Check for zero Shapley values when using formative ordering
        if order == 'formative':
            zero_count = (Phi_subset == 0).all(axis=1).sum()
            if zero_count > 0:
                print(f"\n⚠ Warning: {zero_count}/{len(sample_labels)} top formative {significance} instances have zero Shapley values")
                print(f"  This occurs when a formative instance is not in the top {significance} instances")
                print(f"  (determined by shapley_top_n parameter)")
                print(f"  These instances CREATE structure but are not themselves highly {significance}")

    else:
        # Custom sample list provided - convert to positional indices
        samples = np.asarray(samples)

        # Extract Shapley values for selected samples
        # CRITICAL: Keep rank order - DO NOT re-sort!
        # Top-ranked sample should appear at TOP of heatmap
        Phi_subset = Phi[samples, :]

        # Create DataFrame for heatmap with actual sample IDs
        if dt_fitted.train_index_ is not None:
            # Use actual DataFrame indices from training
            sample_labels = [str(dt_fitted.train_index_[s]) for s in samples]
        else:
            sample_labels = [f"Sample {s}" for s in samples]

    # Get feature names
    feature_names = [dt_fitted.feature_columns_[i] for i, keep in enumerate(dt_fitted.keep_mask_) if keep]

    # ALWAYS order features by global importance (average across ALL samples)
    global_importance = np.abs(Phi).mean(axis=0)
    feature_order = np.argsort(global_importance)[::-1]
    Phi_subset = Phi_subset[:, feature_order]
    feature_names = [feature_names[i] for i in feature_order]

    # Select top features if requested
    if top_features is not None:
        Phi_subset = Phi_subset[:, :top_features]
        feature_names = feature_names[:top_features]

    # Create DataFrame for heatmap
    df_heatmap = pd.DataFrame(
        Phi_subset,
        index=sample_labels,
        columns=feature_names
    )

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Create heatmap (caller kwargs override these defaults)
    heatmap_kwargs = {
        'cmap': cmap,
        'center': center,
        'cbar_kws': {'label': 'Shapley Value', 'shrink': 0.5},
        'linewidths': 0.5,
        'annot': False,
        'fmt': '.3f',
        'square': False
    }
    heatmap_kwargs.update(kwargs)

    sns.heatmap(df_heatmap, ax=ax, **heatmap_kwargs)

    # Draw a black frame around both the heatmap and its colorbar
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_edgecolor('black')
    cbar = ax.collections[0].colorbar
    cbar.outline.set_linewidth(1)
    cbar.outline.set_edgecolor('black')

    # Labels
    ax.set_xlabel('Features (Ordered by Global Importance)', fontsize=12)
    if order == 'actual':
        ylabel = f'Samples (Ordered by {significance.title()} Rank)'
    else:
        ylabel = f'Samples (Ordered by Formative {significance.title()} Rank)'
    ax.set_ylabel(ylabel, fontsize=12)
    ax.tick_params(axis='x', labelrotation=90, labelsize=12)
    ax.tick_params(axis='y', labelrotation=0, labelsize=12)

    if title:
        ax.set_title(title, fontsize=14)
    else:
        ax.set_title(f'{significance.title()} Explanations', fontsize=14)

    plt.tight_layout()

    return ax
716
+
717
+ # ============================================================================
718
+ # INDIVIDUAL SAMPLE PROFILE
719
+ # ============================================================================
720
+
721
def profile_plot(
    dt_fitted,
    sample_idx: int,
    significance: str = 'archetypal',
    order: str = 'local',
    figsize: Tuple[int, int] = (12, 5),
    cmap: str = 'viridis',
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Create feature importance profile for a single sample.

    Shows Shapley explanations (y-axis) for each feature (x-axis), with bars colored by
    the normalized feature value for this sample. Features can be ordered locally (by
    this sample's importance) or globally (by average importance across all samples).

    Parameters
    ----------
    dt_fitted : DataTypical
        Fitted DataTypical instance with shapley_mode=True
    sample_idx : int
        Index of sample to profile
    significance : str
        Which significance to visualize: 'archetypal', 'prototypical', 'stereotypical'
    order : str
        Feature ordering method: 'local' or 'global'
        'local': Order by this sample's Shapley values
        'global': Order by average importance across all samples (uses explanations)
    figsize : tuple of int
        Figure size (width, height)
    cmap : str
        Colormap name (default: 'viridis')
    title : str, optional
        Plot title
    ax : plt.Axes, optional
        Existing axes to plot on
    **kwargs
        Additional arguments passed to bar()

    Returns
    -------
    ax : plt.Axes
        Matplotlib axes object

    Examples
    --------
    >>> dt = DataTypical(shapley_mode=True)
    >>> dt.fit(data)
    >>>
    >>> # Profile top archetypal sample (local ordering)
    >>> top_idx = results['archetypal_rank'].idxmax()
    >>> ax = profile_plot(dt, top_idx, significance='archetypal', order='local')
    >>>
    >>> # Profile top formative sample (global ordering)
    >>> top_formative = results['archetypal_shapley_rank'].idxmax()
    >>> ax = profile_plot(dt, top_formative, significance='archetypal', order='global')
    """

    if not dt_fitted.shapley_mode:
        raise RuntimeError("Shapley mode not enabled. Set shapley_mode=True when fitting.")

    # Validate parameters
    valid_significance = ['archetypal', 'prototypical', 'stereotypical']
    if significance not in valid_significance:
        raise ValueError(f"significance must be one of {valid_significance}, got '{significance}'")

    valid_order = ['local', 'global']
    if order not in valid_order:
        raise ValueError(f"order must be one of {valid_order}, got '{order}'")

    # Get explanations for this sample
    explanations = dt_fitted.get_shapley_explanations(sample_idx)
    shapley_values = explanations[significance]

    # Get feature names
    feature_names = [dt_fitted.feature_columns_[i] for i, keep in enumerate(dt_fitted.keep_mask_) if keep]

    # Determine feature ordering
    if order == 'local':
        # Order by THIS sample's Shapley values
        importance = np.abs(shapley_values)
        ordering_type = "local"
    else:  # global
        # Order by average importance across ALL samples (using explanations)
        if significance == 'archetypal':
            Phi_explanations = dt_fitted.Phi_archetypal_explanations_
        elif significance == 'prototypical':
            Phi_explanations = dt_fitted.Phi_prototypical_explanations_
        else:  # stereotypical
            Phi_explanations = dt_fitted.Phi_stereotypical_explanations_

        # Calculate global importance (average across all samples)
        importance = np.mean(np.abs(Phi_explanations), axis=0)
        ordering_type = "global"

    # Sort features by importance (descending by absolute value)
    sorted_idx = np.argsort(importance)[::-1]

    # Get sorted Shapley values (keep original signs for plotting)
    shapley_sorted = shapley_values[sorted_idx]
    features_sorted = [feature_names[i] for i in sorted_idx]

    # Get normalized feature values for coloring
    # Need to get the actual feature values for this sample
    original_data = dt_fitted._df_original_fit
    numeric_cols = [dt_fitted.feature_columns_[i] for i, keep in enumerate(dt_fitted.keep_mask_) if keep]

    # Get sample's feature values
    sample_feature_values = original_data.loc[sample_idx, numeric_cols].values.astype(np.float64)

    # Get dataset min/max for normalization
    dataset_values = original_data[numeric_cols].values.astype(np.float64)
    feat_min = dataset_values.min(axis=0)
    feat_max = dataset_values.max(axis=0)

    # Normalize to [0, 1]
    feat_range = feat_max - feat_min
    feat_range[feat_range == 0] = 1.0  # Avoid division by zero
    normalized_values = (sample_feature_values - feat_min) / feat_range

    # Sort by same order as Shapley values
    normalized_sorted = normalized_values[sorted_idx]

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Create color array from normalized feature values
    # FIX: plt.cm.get_cmap was removed in matplotlib 3.9; use plt.get_cmap.
    colors = plt.get_cmap(cmap)(normalized_sorted)

    # Create bar plot - use signed Shapley values (can be negative)
    x_pos = np.arange(len(features_sorted))
    bars = ax.bar(x_pos, shapley_sorted, color=colors, edgecolor='black', linewidth=0.5, **kwargs)

    # Add zero reference line
    ax.axhline(0, color='black', linewidth=0.8, linestyle='-', alpha=0.5)

    # Add colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label('Normalized Feature Value', fontsize=14)
    cbar.ax.tick_params(labelsize=12)

    # Labels
    ax.set_xticks(x_pos)
    ax.set_xticklabels(features_sorted, rotation=90, ha='center', fontsize=12)

    # Update xlabel based on ordering type
    if ordering_type == "local":
        ax.set_xlabel('Features (Ordered by Local Importance)', fontsize=14)
    else:  # global
        ax.set_xlabel('Features (Ordered by Global Importance)', fontsize=14)

    ax.set_ylabel('Shapley Value', fontsize=14)

    if title:
        ax.set_title(title, fontsize=14)
    else:
        # Get sample label if available
        sample_label = f"Sample {sample_idx}"
        if hasattr(dt_fitted, 'label_df_') and dt_fitted.label_df_ is not None:
            if len(dt_fitted.label_df_.columns) > 0:
                first_label_col = dt_fitted.label_df_.columns[0]
                if sample_idx in dt_fitted.label_df_.index:
                    sample_label = dt_fitted.label_df_.loc[sample_idx, first_label_col]

        # Add significance to title
        # FIX: the computed sample_label was previously dead code — the title
        # interpolated sample_idx instead; use the resolved label as intended.
        sig_name = significance.capitalize()
        ax.set_title(f'{sig_name} Explanations: {sample_label}', fontsize=14)

    # Grid for readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    plt.tight_layout()

    return ax
901
+
902
+
903
+ # ============================================================================
904
+ # Export
905
+ # ============================================================================
906
+
907
# Public API of the visualization module (controls `from datatypical_viz import *`).
__all__ = [
    'significance_plot',
    'heatmap',
    'profile_plot',
    'get_top_sample'
]