datatypical 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datatypical-0.7.0.dist-info/METADATA +302 -0
- datatypical-0.7.0.dist-info/RECORD +7 -0
- datatypical-0.7.0.dist-info/WHEEL +5 -0
- datatypical-0.7.0.dist-info/licenses/LICENSE +21 -0
- datatypical-0.7.0.dist-info/top_level.txt +2 -0
- datatypical.py +3417 -0
- datatypical_viz.py +912 -0
datatypical_viz.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DataTypical v0.7 - Visualization Module
|
|
3
|
+
========================================
|
|
4
|
+
|
|
5
|
+
Publication-quality visualizations for dual-perspective analysis:
|
|
6
|
+
1. significance_plot: Dual-perspective scatter (hero visualization)
|
|
7
|
+
- Automatic discrete/continuous color detection
|
|
8
|
+
- Binary: purple (low) and green (high)
|
|
9
|
+
- Discrete (3-12 categories): viridis palette + legend
|
|
10
|
+
- Continuous (>12 values): viridis colormap + colorbar
|
|
11
|
+
2. heatmap: Feature attribution heatmaps
|
|
12
|
+
3. profile_plot: Feature importance profiles for individual samples
|
|
13
|
+
|
|
14
|
+
Design specifications:
|
|
15
|
+
- Viridis colormap (default)
|
|
16
|
+
- figsize (6,5) for scatter and heatmaps
|
|
17
|
+
- figsize (12,5) for profile plots
|
|
18
|
+
- Font size 12 for ticks, 14 for axis labels
|
|
19
|
+
- Clean, professional style matching existing software
|
|
20
|
+
|
|
21
|
+
Author: Amanda S. Barnard
|
|
22
|
+
Date: January 2026
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
import pandas as pd
|
|
27
|
+
import matplotlib.pyplot as plt
|
|
28
|
+
import matplotlib.patches as mpatches
|
|
29
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
30
|
+
import seaborn as sns
|
|
31
|
+
from typing import Optional, Dict, List, Tuple, Union
|
|
32
|
+
|
|
33
|
+
# Configure plotting defaults
|
|
34
|
+
plt.rcParams['font.size'] = 12
|
|
35
|
+
plt.rcParams['axes.labelsize'] = 14
|
|
36
|
+
plt.rcParams['xtick.labelsize'] = 12
|
|
37
|
+
plt.rcParams['ytick.labelsize'] = 12
|
|
38
|
+
plt.rcParams['legend.fontsize'] = 12
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ============================================================================
|
|
42
|
+
# HELPER function
|
|
43
|
+
# ============================================================================
|
|
44
|
+
|
|
45
|
+
def get_top_sample(
    results: pd.DataFrame,
    rank_column: str,
    n: int = 1,
    mode: str = 'max'
) -> Union[int, List[int], None]:
    """
    Safely get top sample(s) from results, handling missing formative data.

    Parameters
    ----------
    results : pd.DataFrame
        Results from DataTypical.fit_transform()
    rank_column : str
        Column to rank by (e.g., 'archetypal_rank', 'archetypal_shapley_rank')
    n : int
        Number of top samples to return (default: 1)
    mode : str
        'max' for highest values, 'min' for lowest (default: 'max')

    Returns
    -------
    sample_idx : int, list, or None
        Top sample index/indices, or None if data not available.
        Returns a single index if n=1, a list if n>1.

    Examples
    --------
    >>> top_idx = get_top_sample(results, 'archetypal_rank')
    >>> top_formative = get_top_sample(results, 'archetypal_shapley_rank')
    >>> if top_formative is not None:
    ...     ax = profile_plot(dt, top_formative, order='global')
    >>> top_5 = get_top_sample(results, 'prototypical_rank', n=5)
    """
    if rank_column not in results.columns:
        # FIX: warning glyph was mojibake ('âš'); use the proper '⚠' as in heatmap()
        print(f"⚠ Column '{rank_column}' not found in results")
        return None

    # Check if column is all NaN (formative data not available, e.g. fast_mode=True)
    if results[rank_column].isna().all():
        print(f"⚠ Column '{rank_column}' has no data (likely fast_mode=True)")

        # Provide a helpful hint based on the column name
        if 'shapley_rank' in rank_column:
            print(f"  Formative data requires: DataTypical(shapley_mode=True, fast_mode=False)")
        if 'stereotypical' in rank_column:
            print(f"  Also requires: stereotype_column='<your_column>'")

        return None

    # Get top sample(s); idxmax/idxmin return the DataFrame index label,
    # nlargest/nsmallest preserve rank order in the returned list
    if mode == 'max':
        if n == 1:
            return results[rank_column].idxmax()
        return results.nlargest(n, rank_column).index.tolist()
    else:  # min
        if n == 1:
            return results[rank_column].idxmin()
        return results.nsmallest(n, rank_column).index.tolist()
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# ============================================================================
|
|
114
|
+
# HERO VISUALIZATION: Dual-Perspective Scatter
|
|
115
|
+
# ============================================================================
|
|
116
|
+
|
|
117
|
+
def significance_plot(
    results: pd.DataFrame,
    significance: str = 'archetypal',
    color_by: Optional[str] = None,
    size_by: Optional[str] = None,
    labels: Optional[pd.Series] = None,
    label_top: int = 0,
    quadrant_lines: bool = True,
    quadrant_threshold: Tuple[float, float] = (0.5, 0.5),
    figsize: Tuple[int, int] = (6, 5),
    cmap: str = 'viridis',
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Create dual-perspective significance scatter plot (hero visualization).

    Plots actual significance ({significance}_rank, x-axis) against formative
    significance ({significance}_shapley_rank, y-axis): samples that ARE
    archetypal/prototypical/stereotypical versus samples that CREATE the
    structure.

    Parameters
    ----------
    results : pd.DataFrame
        Results from DataTypical.fit_transform() with shapley_mode=True
    significance : str
        Which significance to plot: 'archetypal', 'prototypical', or 'stereotypical'.
        Automatically uses {significance}_rank (x) and {significance}_shapley_rank (y).
    color_by : str, optional
        Column name to color points by. Type is auto-detected:
        - Binary (2 unique values): purple (low) / green (high) + legend
        - Discrete (3-5 unique): discrete viridis colors + legend
        - Discrete (6-12 unique): discrete colors + distinct markers + legend
        - Continuous numeric (>12 unique): viridis colormap + colorbar
        - Non-numeric: discrete viridis colors + legend
    size_by : str, optional
        Column name to size points by (normalized to the 10-200 range)
    labels : pd.Series, optional
        Labels for points (e.g., compound IDs)
    label_top : int
        Number of top points to label (by actual+formative sum)
    quadrant_lines : bool
        Draw quadrant division lines
    quadrant_threshold : tuple of float
        (x, y) thresholds for quadrant lines
    figsize : tuple of int
        Figure size (width, height) when a new figure is created
    cmap : str
        Colormap name for continuous coloring (default: 'viridis')
    title : str, optional
        Plot title; defaults to '<Significance> Dual-Perspective Analysis'
    ax : plt.Axes, optional
        Existing axes to plot on
    **kwargs
        Additional arguments passed to scatter()

    Returns
    -------
    ax : plt.Axes
        Matplotlib axes object (a placeholder axes if formative data is missing)

    Examples
    --------
    >>> dt = DataTypical(shapley_mode=True)
    >>> results = dt.fit_transform(data)
    >>> ax = significance_plot(results, significance='archetypal')
    >>> ax = significance_plot(results, significance='prototypical',
    ...                        color_by='solubility', label_top=5)
    """
    # Validate significance parameter
    valid_significance = ['archetypal', 'prototypical', 'stereotypical']
    if significance not in valid_significance:
        raise ValueError(f"significance must be one of {valid_significance}, got '{significance}'")

    # Auto-determine column names
    actual_col = f'{significance}_rank'
    formative_col = f'{significance}_shapley_rank'

    if actual_col not in results.columns:
        raise ValueError(f"Column '{actual_col}' not in results")
    if formative_col not in results.columns:
        raise ValueError(f"Column '{formative_col}' not in results")

    # Formative data may be absent when the model was fit with fast_mode=True
    if results[formative_col].isna().all():
        # FIX: warning glyph was mojibake ('âš'); use '⚠' consistently with heatmap()
        print(f"\n⚠ Skipping significance plot:")
        print(f"  Formative data ('{formative_col}') not available (fast_mode=True)")
        print(f"  This plot requires fast_mode=False to compute formative Shapley values")

        # FIX: single exit point (original had two consecutive `return ax` statements);
        # a provided axes is returned untouched, otherwise a placeholder is drawn
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, 'Formative data not available\n(fast_mode=True)',
                    ha='center', va='center', fontsize=14, color='gray')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_xlabel(f'{significance.capitalize()} Rank (Actual)', fontsize=14)
            ax.set_ylabel(f'{significance.capitalize()} Rank (Formative)', fontsize=14)
            if title:
                ax.set_title(title, fontsize=14)
        return ax

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Prepare size values (needed for all plot types)
    if size_by is not None:
        if size_by not in results.columns:
            raise ValueError(f"size_by column '{size_by}' not found in results")
        s_values = results[size_by]
        # Normalize to reasonable range (10-200); constant columns fall back to 50
        s_min, s_max = s_values.min(), s_values.max()
        if s_max > s_min:
            s_normalized = 10 + 190 * (s_values - s_min) / (s_max - s_min)
        else:
            s_normalized = 50
    else:
        s_normalized = 50

    # Handle coloring logic with automatic discrete/continuous detection
    if color_by is not None:
        if color_by not in results.columns:
            raise ValueError(f"color_by column '{color_by}' not found in results")

        c_values = results[color_by]

        # Detect variable type: discrete vs continuous
        is_numeric = pd.api.types.is_numeric_dtype(c_values)
        n_unique = c_values.nunique()

        if not is_numeric:
            # Non-numeric → discrete
            is_discrete, use_markers = True, False
        elif n_unique < 6:
            # Numeric with <6 values → discrete
            is_discrete, use_markers = True, False
        elif n_unique <= 12:
            # Numeric with 6-12 values → discrete with distinct markers
            is_discrete, use_markers = True, True
        else:
            # Numeric with >12 values → continuous
            is_discrete, use_markers = False, False

        if is_discrete:
            # DISCRETE COLORING: use a legend instead of a colorbar.
            # Sort unique values for a stable category → color mapping.
            if is_numeric:
                unique_values = sorted(c_values.dropna().unique())
            else:
                unique_values = sorted(c_values.dropna().unique(), key=str)

            n_categories = len(unique_values)
            if n_categories == 0:
                raise ValueError(f"color_by column '{color_by}' has no valid values")

            # FIX: plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed
            # in 3.9 — use plt.get_cmap instead
            if n_categories == 2:
                # BINARY: purple (viridis[0.0]) for lower, green (viridis[0.66]) for higher
                viridis_cmap = plt.get_cmap('viridis')
                color_map = {
                    unique_values[0]: viridis_cmap(0.0),   # purple for lower value
                    unique_values[1]: viridis_cmap(0.66)   # green for higher value
                }
            else:
                # MULTI-CLASS: discrete viridis palette
                viridis_cmap = plt.get_cmap('viridis', n_categories)
                color_map = {val: viridis_cmap(i) for i, val in enumerate(unique_values)}

            # 6-12 categories additionally get one marker shape per category
            if use_markers:
                marker_list = ['o', 's', '^', 'v', 'D', 'p', '*', 'h', 'X', 'P', '<', '>']
                marker_map = {val: marker_list[i % len(marker_list)]
                              for i, val in enumerate(unique_values)}

            # Plot each category separately so the legend picks up one entry each
            for category in unique_values:
                mask = c_values == category
                marker = marker_map[category] if use_markers else 'o'

                # Scalar size applies to all points; Series must be subset by mask
                if isinstance(s_normalized, (int, float)):
                    s_subset = s_normalized
                else:
                    s_subset = s_normalized[mask]

                ax.scatter(
                    results.loc[mask, actual_col],
                    results.loc[mask, formative_col],
                    c=[color_map[category]],
                    s=s_subset,
                    marker=marker,
                    edgecolors='black',
                    linewidth=0.5,
                    label=str(category),
                    **kwargs
                )

            # Legend outside the axes on the right
            ax.legend(
                title=color_by,
                fontsize=12,
                title_fontsize=12,
                loc='center left',
                bbox_to_anchor=(1.0, 0.5),
                frameon=False
            )

        else:
            # CONTINUOUS COLORING: use a colorbar
            scatter = ax.scatter(
                results[actual_col],
                results[formative_col],
                c=c_values,
                s=s_normalized,
                cmap=cmap,
                edgecolors='black',
                linewidth=0.5,
                **kwargs
            )
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(color_by, fontsize=14)
            cbar.ax.tick_params(labelsize=12)

    else:
        # NO COLORING: default steelblue
        ax.scatter(
            results[actual_col],
            results[formative_col],
            c='steelblue',
            s=s_normalized,
            alpha=0.7,
            edgecolors='black',
            linewidth=0.5,
            **kwargs
        )

    # Add quadrant lines if requested
    if quadrant_lines:
        x_thresh, y_thresh = quadrant_threshold
        ax.axvline(x_thresh, color='gray', linestyle='--', alpha=0.5, linewidth=1)
        ax.axhline(y_thresh, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    # Label top points if requested, ranked by actual + formative sum
    if label_top > 0 and labels is not None:
        combined_rank = results[actual_col] + results[formative_col]
        top_indices = combined_rank.nlargest(label_top).index

        for idx in top_indices:
            ax.annotate(
                labels[idx],
                xy=(results.loc[idx, actual_col], results.loc[idx, formative_col]),
                xytext=(5, 5),
                textcoords='offset points',
                fontsize=10,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7, edgecolor='gray'),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0', color='gray', lw=0.5)
            )

    # Labels and title
    ax.set_xlabel(f'{significance.capitalize()} Rank (Actual)', fontsize=14)
    ax.set_ylabel(f'{significance.capitalize()} Rank (Formative)', fontsize=14)

    if title:
        ax.set_title(title, fontsize=14)
    else:
        ax.set_title(f'{significance.capitalize()} Dual-Perspective Analysis', fontsize=14)

    # Grid for readability
    ax.grid(alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Ensure limits include the full normalized-rank range
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)

    plt.tight_layout()

    return ax
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
# ============================================================================
|
|
442
|
+
# FEATURE ATTRIBUTION HEATMAP
|
|
443
|
+
# ============================================================================
|
|
444
|
+
|
|
445
|
+
def heatmap(
    dt_fitted,
    results: Optional[pd.DataFrame] = None,
    samples: Optional[Union[List[int], np.ndarray]] = None,
    significance: str = 'archetypal',
    order: str = 'actual',
    top_n: Optional[int] = None,
    top_features: Optional[int] = None,
    figsize: Tuple[int, int] = (6, 5),
    cmap: str = 'viridis',
    center: Optional[float] = None,
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Create feature attribution heatmap for Shapley explanations.

    Always shows explanations (why samples ARE significant). Features are
    always ordered by global importance (mean |Shapley| across all samples);
    instances (rows) can be ordered by actual or formative ranks.

    Parameters
    ----------
    dt_fitted : DataTypical
        Fitted DataTypical instance with shapley_mode=True
    results : pd.DataFrame, optional
        Results dataframe from fit_transform(). Required if samples not specified.
    samples : list or array, optional
        Indices of samples to show. If None, uses top_n samples from results.
    significance : str
        Which significance to visualize: 'archetypal', 'prototypical', 'stereotypical'
    order : str
        How to order instances (rows): 'actual' or 'formative'.
        'actual': order by {significance}_rank (samples that ARE significant);
        'formative': order by {significance}_shapley_rank (samples that CREATE structure)
    top_n : int, optional
        Number of top samples to show (if samples not specified).
        Defaults to shapley_top_n from the fitted model, or 20 if not set.
    top_features : int, optional
        Number of top features to show (by global importance)
    figsize : tuple of int
        Figure size (width, height)
    cmap : str
        Colormap name (default: 'viridis')
    center : float, optional
        Value to center colormap at (for diverging colormaps)
    title : str, optional
        Plot title
    ax : plt.Axes, optional
        Existing axes to plot on
    **kwargs
        Additional arguments passed to seaborn.heatmap()

    Returns
    -------
    ax : plt.Axes
        Matplotlib axes object (a placeholder axes if explanations are missing)

    Raises
    ------
    RuntimeError
        If the model was not fitted with shapley_mode=True, or rank_col is missing.
    ValueError
        If neither 'samples' nor 'results' is given, or significance/order invalid.

    Examples
    --------
    >>> dt = DataTypical(shapley_mode=True)
    >>> results = dt.fit_transform(data)
    >>> ax = heatmap(dt, results, top_n=10, significance='archetypal')
    >>> ax = heatmap(dt, results, order='formative', significance='archetypal')
    >>> ax = heatmap(dt, results, significance='prototypical', top_n=20, top_features=15)
    """
    if not dt_fitted.shapley_mode:
        raise RuntimeError("Shapley mode not enabled. Set shapley_mode=True when fitting.")

    # Default top_n to shapley_top_n from the fitted model
    if top_n is None:
        if hasattr(dt_fitted, 'shapley_top_n') and dt_fitted.shapley_top_n is not None:
            # shapley_top_n may be a fraction in (0, 1): convert to an absolute count
            if isinstance(dt_fitted.shapley_top_n, float) and 0 < dt_fitted.shapley_top_n < 1:
                n_samples = len(dt_fitted.train_index_) if dt_fitted.train_index_ is not None else 100
                top_n = max(1, int(dt_fitted.shapley_top_n * n_samples))
            else:
                top_n = int(dt_fitted.shapley_top_n)
        else:
            top_n = 20  # Fallback default

    # Get explanations matrix (always explanations, never formative attributions)
    if significance == 'archetypal':
        Phi = dt_fitted.Phi_archetypal_explanations_
    elif significance == 'prototypical':
        Phi = dt_fitted.Phi_prototypical_explanations_
    elif significance == 'stereotypical':
        Phi = dt_fitted.Phi_stereotypical_explanations_
    else:
        raise ValueError(f"Unknown significance: {significance}")

    # Explanations may be unavailable; print guidance and return a placeholder
    if Phi is None:
        print(f"\n⚠ Skipping {significance} explanations heatmap:")
        print(f"  Explanations data not available")

        if significance == 'stereotypical':
            print(f"  Note: Stereotypical also requires stereotype_column to be set")

        print(f"\n  To enable this plot, refit with:")
        if significance == 'stereotypical':
            print(f"  DataTypical(shapley_mode=True, stereotype_column='<column>')")
        else:
            print(f"  DataTypical(shapley_mode=True)")

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)

        ax.text(0.5, 0.5, f'{significance.capitalize()} explanations\nnot available',
                ha='center', va='center', fontsize=14, color='gray')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

        return ax

    # Select samples
    if samples is None:
        if results is None:
            raise ValueError("Must provide either 'samples' or 'results' dataframe")

        # Determine rank column based on order
        if order == 'actual':
            rank_col = f"{significance}_rank"
        elif order == 'formative':
            rank_col = f"{significance}_shapley_rank"

            # Fall back gracefully when formative data is missing (fast_mode)
            if results[rank_col].isna().all():
                print(f"\n⚠ Warning: order='formative' requested but formative data not available")
                # FIX: typo in message — was "ordes='actual'"
                print(f"  Falling back to order='actual'")
                rank_col = f"{significance}_rank"
                order = 'actual'
        else:
            raise ValueError(f"order must be 'actual' or 'formative', got '{order}'")

        if rank_col not in results.columns:
            raise RuntimeError(f"Cannot find {rank_col} in results")

        # DataFrame indices of top samples, already ordered by rank
        top_samples_df_indices = results.nlargest(top_n, rank_col).index

        # Use get_shapley_explanations so shapley_top_n subsampling / index
        # mapping is handled by the model itself
        Phi_subset_list = []
        sample_labels = []

        for df_idx in top_samples_df_indices:
            try:
                explanations = dt_fitted.get_shapley_explanations(df_idx)
                shapley_values = explanations[significance]

                # All-zero rows mean no explanation was computed for this sample
                # (unless subsampling is disabled, in which case zeros are real)
                if np.any(shapley_values != 0) or not hasattr(dt_fitted, 'shapley_top_n') or dt_fitted.shapley_top_n is None:
                    Phi_subset_list.append(shapley_values)
                    sample_labels.append(str(df_idx))
            except (IndexError, KeyError):
                # This sample doesn't have explanations computed
                pass

        if len(Phi_subset_list) == 0:
            print(f"\n⚠ Error: None of the top {top_n} {significance} instances have explanations")
            print(f"  This can happen when shapley_top_n is too small")
            if ax is None:
                fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, f'No explanations available\nfor top {top_n} instances',
                    ha='center', va='center', fontsize=14, color='gray')
            ax.axis('off')
            return ax

        Phi_subset = np.array(Phi_subset_list)

        # Warn about zero rows under formative ordering: such instances CREATE
        # structure but are not themselves in the top-{significance} set
        if order == 'formative':
            zero_count = (Phi_subset == 0).all(axis=1).sum()
            if zero_count > 0:
                print(f"\n⚠ Warning: {zero_count}/{len(sample_labels)} top formative {significance} instances have zero Shapley values")
                print(f"  This occurs when a formative instance is not in the top {significance} instances")
                print(f"  (determined by shapley_top_n parameter)")
                print(f"  These instances CREATE structure but are not themselves highly {significance}")

    else:
        # Custom sample list provided - positional indices into Phi.
        # CRITICAL: keep the caller's order — top-ranked sample stays at the TOP.
        samples = np.asarray(samples)
        Phi_subset = Phi[samples, :]

        # Label rows with the original DataFrame indices where available
        if dt_fitted.train_index_ is not None:
            sample_labels = [str(dt_fitted.train_index_[s]) for s in samples]
        else:
            sample_labels = [f"Sample {s}" for s in samples]

    # Feature names surviving the model's feature-selection mask
    feature_names = [dt_fitted.feature_columns_[i] for i, keep in enumerate(dt_fitted.keep_mask_) if keep]

    # ALWAYS order features by global importance (mean |Shapley| over ALL samples)
    global_importance = np.abs(Phi).mean(axis=0)
    feature_order = np.argsort(global_importance)[::-1]
    Phi_subset = Phi_subset[:, feature_order]
    feature_names = [feature_names[i] for i in feature_order]

    # Select top features if requested
    if top_features is not None:
        Phi_subset = Phi_subset[:, :top_features]
        feature_names = feature_names[:top_features]

    # Create DataFrame for heatmap
    df_heatmap = pd.DataFrame(
        Phi_subset,
        index=sample_labels,
        columns=feature_names
    )

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Heatmap defaults; caller kwargs override
    heatmap_kwargs = {
        'cmap': cmap,
        'center': center,
        'cbar_kws': {'label': 'Shapley Value', 'shrink': 0.5},
        'linewidths': 0.5,
        'annot': False,
        'fmt': '.3f',
        'square': False
    }
    heatmap_kwargs.update(kwargs)

    sns.heatmap(df_heatmap, ax=ax, **heatmap_kwargs)

    # Solid black frame around the heatmap and colorbar
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_edgecolor('black')
    cbar = ax.collections[0].colorbar
    cbar.outline.set_linewidth(1)
    cbar.outline.set_edgecolor('black')

    # Labels
    ax.set_xlabel('Features (Ordered by Global Importance)', fontsize=12)
    if order == 'actual':
        ylabel = f'Samples (Ordered by {significance.title()} Rank)'
    else:
        ylabel = f'Samples (Ordered by Formative {significance.title()} Rank)'
    ax.set_ylabel(ylabel, fontsize=12)
    ax.tick_params(axis='x', labelrotation=90, labelsize=12)
    ax.tick_params(axis='y', labelrotation=0, labelsize=12)

    if title:
        ax.set_title(title, fontsize=14)
    else:
        ax.set_title(f'{significance.title()} Explanations', fontsize=14)

    plt.tight_layout()

    return ax
|
|
716
|
+
|
|
717
|
+
# ============================================================================
|
|
718
|
+
# INDIVIDUAL SAMPLE PROFILE
|
|
719
|
+
# ============================================================================
|
|
720
|
+
|
|
721
|
+
def profile_plot(
    dt_fitted,
    sample_idx: int,
    significance: str = 'archetypal',
    order: str = 'local',
    figsize: Tuple[int, int] = (12, 5),
    cmap: str = 'viridis',
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Create feature importance profile for a single sample.

    Shows Shapley explanations (y-axis) for each feature (x-axis), with bars colored by
    the normalized feature value for this sample. Features can be ordered locally (by
    this sample's importance) or globally (by average importance across all samples).

    Parameters
    ----------
    dt_fitted : DataTypical
        Fitted DataTypical instance with shapley_mode=True
    sample_idx : int
        Index of sample to profile
    significance : str
        Which significance to visualize: 'archetypal', 'prototypical', 'stereotypical'
    order : str
        Feature ordering method: 'local' or 'global'
        'local': Order by this sample's Shapley values
        'global': Order by average importance across all samples (uses explanations)
    figsize : tuple of int
        Figure size (width, height)
    cmap : str
        Colormap name (default: 'viridis')
    title : str, optional
        Plot title; when omitted a title is built from the significance type and
        the sample's label (first column of ``label_df_`` when available).
    ax : plt.Axes, optional
        Existing axes to plot on
    **kwargs
        Additional arguments passed to bar()

    Returns
    -------
    ax : plt.Axes
        Matplotlib axes object

    Raises
    ------
    RuntimeError
        If the instance was not fitted with shapley_mode=True.
    ValueError
        If ``significance`` or ``order`` is not a recognized option.

    Examples
    --------
    >>> dt = DataTypical(shapley_mode=True)
    >>> dt.fit(data)
    >>> ax = profile_plot(dt, top_idx, significance='archetypal', order='local')
    """
    if not dt_fitted.shapley_mode:
        raise RuntimeError("Shapley mode not enabled. Set shapley_mode=True when fitting.")

    # Validate parameters
    valid_significance = ['archetypal', 'prototypical', 'stereotypical']
    if significance not in valid_significance:
        raise ValueError(f"significance must be one of {valid_significance}, got '{significance}'")

    valid_order = ['local', 'global']
    if order not in valid_order:
        raise ValueError(f"order must be one of {valid_order}, got '{order}'")

    # Get explanations for this sample
    explanations = dt_fitted.get_shapley_explanations(sample_idx)
    shapley_values = explanations[significance]

    # Names of the retained (non-dropped) features, in stored column order
    feature_names = [dt_fitted.feature_columns_[i]
                     for i, keep in enumerate(dt_fitted.keep_mask_) if keep]

    # Determine feature ordering
    if order == 'local':
        # Order by THIS sample's absolute Shapley values
        importance = np.abs(shapley_values)
    else:  # global
        # Order by average absolute Shapley value across ALL samples
        Phi_explanations = {
            'archetypal': dt_fitted.Phi_archetypal_explanations_,
            'prototypical': dt_fitted.Phi_prototypical_explanations_,
            'stereotypical': dt_fitted.Phi_stereotypical_explanations_,
        }[significance]
        importance = np.mean(np.abs(Phi_explanations), axis=0)

    # Sort features by importance (descending by absolute value);
    # keep original signs of the Shapley values for plotting
    sorted_idx = np.argsort(importance)[::-1]
    shapley_sorted = shapley_values[sorted_idx]
    features_sorted = [feature_names[i] for i in sorted_idx]

    # Normalized feature values for bar coloring: min-max over the fit data.
    # NOTE(review): assumes _df_original_fit still holds the fit-time rows and
    # that sample_idx is a valid label of its index — confirm against fit().
    original_data = dt_fitted._df_original_fit
    sample_feature_values = original_data.loc[sample_idx, feature_names].values.astype(np.float64)
    dataset_values = original_data[feature_names].values.astype(np.float64)
    feat_min = dataset_values.min(axis=0)
    feat_range = dataset_values.max(axis=0) - feat_min
    feat_range[feat_range == 0] = 1.0  # constant features map to 0; avoid division by zero
    normalized_sorted = ((sample_feature_values - feat_min) / feat_range)[sorted_idx]

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Bar colors from normalized feature values.
    # FIX: plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # plt.get_cmap is the supported accessor.
    colors = plt.get_cmap(cmap)(normalized_sorted)

    # Create bar plot - use signed Shapley values (can be negative)
    x_pos = np.arange(len(features_sorted))
    ax.bar(x_pos, shapley_sorted, color=colors, edgecolor='black', linewidth=0.5, **kwargs)

    # Add zero reference line
    ax.axhline(0, color='black', linewidth=0.8, linestyle='-', alpha=0.5)

    # Colorbar keyed to the [0, 1] normalized feature-value scale
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label('Normalized Feature Value', fontsize=14)
    cbar.ax.tick_params(labelsize=12)

    # Labels
    ax.set_xticks(x_pos)
    ax.set_xticklabels(features_sorted, rotation=90, ha='center', fontsize=12)

    if order == 'local':
        ax.set_xlabel('Features (Ordered by Local Importance)', fontsize=14)
    else:  # global
        ax.set_xlabel('Features (Ordered by Global Importance)', fontsize=14)

    ax.set_ylabel('Shapley Value', fontsize=14)

    if title:
        ax.set_title(title, fontsize=14)
    else:
        # Prefer a human-readable label from label_df_ when one exists
        sample_label = f"Sample {sample_idx}"
        if hasattr(dt_fitted, 'label_df_') and dt_fitted.label_df_ is not None:
            if len(dt_fitted.label_df_.columns) > 0:
                first_label_col = dt_fitted.label_df_.columns[0]
                if sample_idx in dt_fitted.label_df_.index:
                    sample_label = dt_fitted.label_df_.loc[sample_idx, first_label_col]

        # FIX: sample_label was computed but never used — the title previously
        # interpolated the raw index instead of the resolved label.
        sig_name = significance.capitalize()
        ax.set_title(f'{sig_name} Explanations: {sample_label}', fontsize=14)

    # Grid for readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    plt.tight_layout()

    return ax
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
# ============================================================================
|
|
904
|
+
# Export
|
|
905
|
+
# ============================================================================
|
|
906
|
+
|
|
907
|
+
# Public API of the visualization module.
__all__ = ['significance_plot', 'heatmap', 'profile_plot', 'get_top_sample']
|