ds-agent-cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/ds-agent.js +451 -0
- package/ds_agent/__init__.py +8 -0
- package/package.json +28 -0
- package/requirements.txt +126 -0
- package/setup.py +35 -0
- package/src/__init__.py +7 -0
- package/src/_compress_tool_result.py +118 -0
- package/src/api/__init__.py +4 -0
- package/src/api/app.py +1626 -0
- package/src/cache/__init__.py +5 -0
- package/src/cache/cache_manager.py +561 -0
- package/src/cli.py +2886 -0
- package/src/dynamic_prompts.py +281 -0
- package/src/orchestrator.py +4799 -0
- package/src/progress_manager.py +139 -0
- package/src/reasoning/__init__.py +332 -0
- package/src/reasoning/business_summary.py +431 -0
- package/src/reasoning/data_understanding.py +356 -0
- package/src/reasoning/model_explanation.py +383 -0
- package/src/reasoning/reasoning_trace.py +239 -0
- package/src/registry/__init__.py +3 -0
- package/src/registry/tools_registry.py +3 -0
- package/src/session_memory.py +448 -0
- package/src/session_store.py +370 -0
- package/src/storage/__init__.py +19 -0
- package/src/storage/artifact_store.py +620 -0
- package/src/storage/helpers.py +116 -0
- package/src/storage/huggingface_storage.py +694 -0
- package/src/storage/r2_storage.py +0 -0
- package/src/storage/user_files_service.py +288 -0
- package/src/tools/__init__.py +335 -0
- package/src/tools/advanced_analysis.py +823 -0
- package/src/tools/advanced_feature_engineering.py +708 -0
- package/src/tools/advanced_insights.py +578 -0
- package/src/tools/advanced_preprocessing.py +549 -0
- package/src/tools/advanced_training.py +906 -0
- package/src/tools/agent_tool_mapping.py +326 -0
- package/src/tools/auto_pipeline.py +420 -0
- package/src/tools/autogluon_training.py +1480 -0
- package/src/tools/business_intelligence.py +860 -0
- package/src/tools/cloud_data_sources.py +581 -0
- package/src/tools/code_interpreter.py +390 -0
- package/src/tools/computer_vision.py +614 -0
- package/src/tools/data_cleaning.py +614 -0
- package/src/tools/data_profiling.py +593 -0
- package/src/tools/data_type_conversion.py +268 -0
- package/src/tools/data_wrangling.py +433 -0
- package/src/tools/eda_reports.py +284 -0
- package/src/tools/enhanced_feature_engineering.py +241 -0
- package/src/tools/feature_engineering.py +302 -0
- package/src/tools/matplotlib_visualizations.py +1327 -0
- package/src/tools/model_training.py +520 -0
- package/src/tools/nlp_text_analytics.py +761 -0
- package/src/tools/plotly_visualizations.py +497 -0
- package/src/tools/production_mlops.py +852 -0
- package/src/tools/time_series.py +507 -0
- package/src/tools/tools_registry.py +2133 -0
- package/src/tools/visualization_engine.py +559 -0
- package/src/utils/__init__.py +42 -0
- package/src/utils/error_recovery.py +313 -0
- package/src/utils/parallel_executor.py +402 -0
- package/src/utils/polars_helpers.py +248 -0
- package/src/utils/schema_extraction.py +132 -0
- package/src/utils/semantic_layer.py +392 -0
- package/src/utils/token_budget.py +411 -0
- package/src/utils/validation.py +377 -0
- package/src/workflow_state.py +154 -0
|
@@ -0,0 +1,1327 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Matplotlib + Seaborn Visualization Engine
|
|
3
|
+
Production-quality visualizations that work reliably with Gradio UI.
|
|
4
|
+
|
|
5
|
+
All functions return matplotlib Figure objects (not file paths).
|
|
6
|
+
Designed for publication-quality plots with professional styling.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import matplotlib
|
|
10
|
+
matplotlib.use('Agg') # Use non-interactive backend for Gradio compatibility
|
|
11
|
+
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
import seaborn as sns
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
import warnings
|
|
19
|
+
|
|
20
|
+
# NOTE(review): with no module/category filter this silences warnings
# process-wide, not only for this module.
warnings.filterwarnings('ignore')

# Set global style: a white-background seaborn grid plus consistent font sizes
# for every figure created after this module is imported.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
})
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ============================================================================
|
|
35
|
+
# BASIC PLOTS
|
|
36
|
+
# ============================================================================
|
|
37
|
+
|
|
38
|
+
def create_scatter_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[np.ndarray, pd.Series, list],
    hue: Optional[Union[np.ndarray, pd.Series, list]] = None,
    size: Optional[Union[np.ndarray, pd.Series, list]] = None,
    title: str = "Scatter Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    alpha: float = 0.6,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a professional scatter plot with optional color coding and size variation.

    Args:
        x: X-axis data
        y: Y-axis data
        hue: Optional categorical data for color coding (one scatter series
            and legend entry per unique value)
        size: Optional numeric data for per-point marker size
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        alpha: Point transparency (0-1)
        save_path: Optional path to save PNG file (300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and any partially built figure is closed so it does
        not stay registered with pyplot.

    Example:
        >>> fig = create_scatter_plot(df['feature1'], df['target'],
        ...                           hue=df['category'], title='Feature vs Target')
        >>> # Display in Gradio: gr.Plot(value=fig)
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # Convert inputs to arrays so boolean masks can be applied.
        x = np.asarray(x)
        y = np.asarray(y)
        # Convert sizes once rather than on every hue iteration.
        sizes = None if size is None else np.asarray(size)

        if hue is not None:
            hue = np.asarray(hue)
            unique_hues = np.unique(hue)
            colors = sns.color_palette('Set2', n_colors=len(unique_hues))

            for color, hue_val in zip(colors, unique_hues):
                mask = hue == hue_val
                ax.scatter(x[mask], y[mask],
                           c=[color],
                           s=50 if sizes is None else sizes[mask],
                           alpha=alpha,
                           label=str(hue_val),
                           edgecolors='black',
                           linewidth=0.5)
            ax.legend(title='Category', loc='best', framealpha=0.9)
        else:
            ax.scatter(x, y,
                       c='steelblue',
                       s=50 if sizes is None else sizes,
                       alpha=alpha,
                       edgecolors='black',
                       linewidth=0.5)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved scatter plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating scatter plot: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def create_line_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[Dict[str, np.ndarray], np.ndarray, pd.Series, list],
    title: str = "Line Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    markers: bool = True,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a line plot (supports multiple lines via dict).

    Args:
        x: X-axis data shared by every line
        y: Y-axis data (dict for multiple lines: {'label1': y1, 'label2': y2})
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        markers: Show circle markers on lines
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and the partially built figure is closed.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        x = np.asarray(x)
        marker_style = 'o' if markers else None

        if isinstance(y, dict):
            # One distinctly colored, labeled line per dict entry.
            colors = sns.color_palette('husl', n_colors=len(y))
            for color, (label, y_data) in zip(colors, y.items()):
                ax.plot(x, np.asarray(y_data),
                        marker=marker_style,
                        label=label,
                        linewidth=2,
                        markersize=6,
                        color=color)
            ax.legend(loc='best', framealpha=0.9)
        else:
            ax.plot(x, np.asarray(y),
                    marker=marker_style,
                    linewidth=2,
                    markersize=6,
                    color='steelblue')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved line plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating line plot: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def create_bar_chart(
    categories: Union[list, np.ndarray],
    values: Union[np.ndarray, pd.Series, list],
    title: str = "Bar Chart",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    color: str = 'steelblue',
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a bar chart (vertical or horizontal).

    Args:
        categories: Category names
        values: Values for each category
        title: Plot title
        xlabel: Label for the category axis (drawn on the y-axis when horizontal)
        ylabel: Label for the value axis (drawn on the x-axis when horizontal)
        figsize: Figure size
        horizontal: If True, create horizontal bars
        color: Bar color
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and the partially built figure is closed.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        categories = list(categories)
        values = np.asarray(values)

        if horizontal:
            ax.barh(categories, values, color=color, edgecolor='black', linewidth=0.7)
            # Axes are swapped: values run along x, categories along y.
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.bar(categories, values, color=color, edgecolor='black', linewidth=0.7)
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            # Rotate labels if many categories. Only the category axis needs
            # this; in horizontal mode the x-axis holds numeric values.
            if len(categories) > 10:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        # Grid lines only along the value axis.
        ax.grid(True, alpha=0.3, linestyle='--', axis='x' if horizontal else 'y')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved bar chart to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating bar chart: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def create_histogram(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Histogram",
    xlabel: str = "Value",
    ylabel: str = "Frequency",
    bins: int = 30,
    figsize: Tuple[int, int] = (10, 6),
    kde: bool = True,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a histogram with optional KDE overlay.

    Args:
        data: Numeric data to plot; None/NaN and infinite values are dropped
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label of the histogram axis
        bins: Number of bins
        figsize: Figure size
        kde: Show kernel density estimate on a secondary y-axis. When True
            the bars are density-normalized, so consider ylabel='Density'.
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if there is no finite data or
        plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        # Coerce to float so None in a list becomes NaN (np.isnan raises on
        # object arrays), and drop NaN/inf, which would break auto-binning.
        values = np.asarray(data, dtype=float)
        values = values[np.isfinite(values)]

        # Check before creating the figure so the early return leaks nothing.
        if values.size == 0:
            print(" ✗ No valid data for histogram")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create histogram (density-normalized when a KDE is overlaid).
        ax.hist(values, bins=bins, color='steelblue',
                edgecolor='black', alpha=0.7, density=kde)

        # Add KDE if requested, on a twin y-axis with its own scale.
        if kde:
            ax2 = ax.twinx()
            sns.kdeplot(values, ax=ax2, color='darkred', linewidth=2, label='KDE')
            ax2.set_ylabel('Density', fontsize=12)
            ax2.legend(loc='upper right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved histogram to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating histogram: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def create_boxplot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Box Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create box plots for multiple columns/categories.

    Boxes are notched, show the mean as a green diamond, and the median as a
    dark red line. Missing values (NaN/None) are dropped per column.

    Args:
        data: Dictionary of {column_name: values} or DataFrame (one box per column)
        title: Plot title
        xlabel: Label for the category axis (drawn on the y-axis when horizontal)
        ylabel: Label for the value axis (drawn on the x-axis when horizontal)
        figsize: Figure size
        horizontal: If True, create horizontal boxplots
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (including when
        data is neither a DataFrame nor a dict). On failure the error is
        printed and any partially built figure is closed.
    """
    fig = None
    try:
        if isinstance(data, pd.DataFrame):
            data_to_plot = [data[col].dropna() for col in data.columns]
            labels = list(data.columns)
        elif isinstance(data, dict):
            data_to_plot = []
            for values in data.values():
                # Float coercion turns None into NaN so it can be filtered;
                # np.isnan would raise on an object array.
                arr = np.asarray(values, dtype=float)
                data_to_plot.append(arr[~np.isnan(arr)])
            labels = list(data.keys())
        else:
            raise ValueError("Data must be DataFrame or dict")

        fig, ax = plt.subplots(figsize=figsize)

        bp = ax.boxplot(data_to_plot,
                        vert=not horizontal,
                        patch_artist=True,
                        notch=True,
                        showmeans=True)

        # Label boxes explicitly instead of boxplot(labels=...): that keyword
        # was renamed to `tick_labels` in Matplotlib 3.9. Boxes sit at the
        # default positions 1..n.
        positions = list(range(1, len(labels) + 1))
        tick_text = [str(label) for label in labels]
        if horizontal:
            ax.set_yticks(positions)
            ax.set_yticklabels(tick_text)
        else:
            ax.set_xticks(positions)
            ax.set_xticklabels(tick_text)

        # Styling
        for patch in bp['boxes']:
            patch.set_facecolor('lightblue')
            patch.set_alpha(0.7)

        for whisker in bp['whiskers']:
            whisker.set(linewidth=1.5, color='gray')

        for cap in bp['caps']:
            cap.set(linewidth=1.5, color='gray')

        for median in bp['medians']:
            median.set(linewidth=2, color='darkred')

        for mean in bp['means']:
            mean.set(marker='D', markerfacecolor='green', markersize=6)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        if horizontal:
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            if len(labels) > 8:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.grid(True, alpha=0.3, linestyle='--', axis='x' if horizontal else 'y')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved boxplot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating boxplot: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
# ============================================================================
|
|
400
|
+
# STATISTICAL PLOTS
|
|
401
|
+
# ============================================================================
|
|
402
|
+
|
|
403
|
+
def create_correlation_heatmap(
    data: Union[pd.DataFrame, np.ndarray],
    columns: Optional[List[str]] = None,
    title: str = "Correlation Heatmap",
    figsize: Tuple[int, int] = (12, 10),
    annot: bool = True,
    cmap: str = 'RdBu_r',
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a lower-triangle correlation heatmap with annotations.

    Args:
        data: DataFrame (its numeric/bool columns are correlated) or an
            already-computed square correlation matrix as an ndarray
        columns: Column names (if data is np.ndarray); used for both axes
        title: Plot title
        figsize: Figure size
        annot: Show correlation values (2 decimals) as annotations
        cmap: Colormap (diverging, centered at 0, fixed range [-1, 1])
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and any partially built figure is closed.

    Example:
        >>> fig = create_correlation_heatmap(df[numeric_cols])
    """
    fig = None
    try:
        if isinstance(data, pd.DataFrame):
            # pandas >= 2.0 raises in DataFrame.corr() when non-numeric columns
            # are present, so correlate only numeric/bool columns.
            corr_matrix = data.select_dtypes(include=['number', 'bool']).corr()
        else:
            # Arrays are treated as a precomputed correlation matrix.
            corr_matrix = pd.DataFrame(data, columns=columns, index=columns)

        # Hide the upper triangle, including the diagonal, since it mirrors the lower one.
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

        fig, ax = plt.subplots(figsize=figsize)

        sns.heatmap(corr_matrix,
                    mask=mask,
                    annot=annot,
                    fmt='.2f',
                    cmap=cmap,
                    center=0,
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'shrink': 0.8, 'label': 'Correlation'},
                    ax=ax,
                    vmin=-1,
                    vmax=1)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved correlation heatmap to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating correlation heatmap: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def create_distribution_plot(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Distribution Plot",
    xlabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_rug: bool = False,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a distribution plot with histogram, KDE, and mean/median markers.

    A text box in the upper-right corner reports mean, median, and
    (population, ddof=0) standard deviation.

    Args:
        data: Numeric data to plot; None/NaN and infinite values are dropped
        title: Plot title
        xlabel: X-axis label
        figsize: Figure size
        show_rug: Show rug plot (data points on x-axis)
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if there is no finite data or
        plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        # Coerce to float so None in a list becomes NaN (np.isnan raises on
        # object arrays), and drop NaN/inf, which would break binning and stats.
        values = np.asarray(data, dtype=float)
        values = values[np.isfinite(values)]

        # Check before creating the figure so the early return leaks nothing.
        if values.size == 0:
            print(" ✗ No valid data for distribution plot")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create distribution plot
        sns.histplot(values, kde=True, ax=ax, color='steelblue',
                     edgecolor='black', alpha=0.6, bins=30)

        if show_rug:
            sns.rugplot(values, ax=ax, color='darkred', alpha=0.5, height=0.05)

        # Add statistics text
        mean_val = np.mean(values)
        median_val = np.median(values)
        std_val = np.std(values)

        stats_text = f'Mean: {mean_val:.2f}\nMedian: {median_val:.2f}\nStd: {std_val:.2f}'
        ax.text(0.98, 0.98, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                fontsize=10)

        # Add vertical lines for mean and median
        ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label='Mean')
        ax.axvline(median_val, color='green', linestyle='--', linewidth=2, label='Median')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel('Frequency / Density', fontsize=12)
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved distribution plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating distribution plot: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def create_violin_plot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Violin Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create violin plots showing distribution for multiple categories.

    Args:
        data: One of
            - dict of {category: values};
            - long-format DataFrame with 'Category' and 'Value' columns;
            - wide DataFrame with one column per category (melted into long
              format automatically).
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and any partially built figure is closed.
    """
    fig = None
    try:
        if isinstance(data, dict):
            # Convert dict to long-format DataFrame for seaborn: one row per
            # observation, tagged with its category key.
            plot_df = pd.concat(
                [pd.DataFrame({'Category': [key] * len(values), 'Value': values})
                 for key, values in data.items()],
                ignore_index=True,
            )
        elif isinstance(data, pd.DataFrame) and not {'Category', 'Value'}.issubset(data.columns):
            # Wide format (one column per category): reshape to long format,
            # since the plot below reads 'Category' and 'Value' columns.
            plot_df = data.melt(var_name='Category', value_name='Value')
        else:
            plot_df = data

        fig, ax = plt.subplots(figsize=figsize)

        # Create violin plot with an inner box summarizing quartiles.
        sns.violinplot(data=plot_df, x='Category', y='Value', ax=ax,
                       palette='Set2', inner='box')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)

        if plot_df['Category'].nunique() > 8:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.grid(True, alpha=0.3, linestyle='--', axis='y')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved violin plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating violin plot: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def create_pairplot(
    data: pd.DataFrame,
    hue: Optional[str] = None,
    title: str = "Pair Plot",
    figsize: Tuple[int, int] = (12, 12),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a lower-triangle pairplot (scatterplot matrix) with KDE diagonals.

    Args:
        data: DataFrame with features to plot
        hue: Column name for color coding; ignored if it is not a column of data
        title: Plot title (figure suptitle)
        figsize: Overall figure size (width, height) in inches
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and any figure that was obtained is closed.
    """
    fig = None
    try:
        # Color by hue only when it names an existing column.
        hue_column = hue if hue and hue in data.columns else None
        pair_grid = sns.pairplot(data, hue=hue_column, palette='Set2',
                                 diag_kind='kde', corner=True)

        # Seaborn returns a PairGrid; `.figure` is the current accessor and
        # `.fig` the older alias.
        fig = pair_grid.figure if hasattr(pair_grid, 'figure') else pair_grid.fig
        # sns.pairplot has no figsize argument, so apply the requested size
        # to the resulting figure (previously figsize was ignored).
        fig.set_size_inches(*figsize)
        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.01)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved pairplot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating pairplot: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
# ============================================================================
|
|
650
|
+
# MACHINE LEARNING PLOTS
|
|
651
|
+
# ============================================================================
|
|
652
|
+
|
|
653
|
+
def create_roc_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "ROC Curve Comparison",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create ROC curves for multiple models on the same plot.

    A dashed diagonal is drawn as the random-classifier reference.

    Args:
        models_data: Dict of {model_name: (fpr, tpr, auc_score)}
        title: Plot title
        figsize: Figure size
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and the partially built figure is closed.

    Example:
        >>> from sklearn.metrics import roc_curve, auc
        >>> fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        >>> auc_score = auc(fpr, tpr)
        >>> models = {'Random Forest': (fpr, tpr, auc_score)}
        >>> fig = create_roc_curve(models)
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        colors = sns.color_palette('husl', n_colors=len(models_data))

        for color, (model_name, (fpr, tpr, auc_score)) in zip(colors, models_data.items()):
            ax.plot(fpr, tpr,
                    linewidth=2.5,
                    label=f'{model_name} (AUC = {auc_score:.3f})',
                    color=color)

        # Add diagonal reference line (random classifier)
        ax.plot([0, 1], [0, 1],
                linestyle='--',
                linewidth=2,
                color='gray',
                label='Random Classifier (AUC = 0.500)')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='lower right', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved ROC curve to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating ROC curve: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
def create_confusion_matrix(
    cm: np.ndarray,
    class_names: Optional[List[str]] = None,
    title: str = "Confusion Matrix",
    figsize: Tuple[int, int] = (10, 8),
    show_percentages: bool = True,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Create a confusion matrix heatmap with annotations.

    Rows are actual classes and columns are predicted classes. Percentages are
    computed per row (share of each actual class).

    Args:
        cm: Square confusion matrix (from sklearn.metrics.confusion_matrix);
            any array-like is accepted
        class_names: Names of classes (optional; defaults to 'Class i')
        title: Plot title
        figsize: Figure size
        show_percentages: Show row percentages in addition to counts. A row
            with no samples is shown as 0.0% instead of 'nan%'.
        save_path: Optional save path (PNG, 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed. On failure the
        error is printed and any partially built figure is closed.

    Example:
        >>> from sklearn.metrics import confusion_matrix
        >>> cm = confusion_matrix(y_true, y_pred)
        >>> fig = create_confusion_matrix(cm, class_names=['Class 0', 'Class 1'])
    """
    fig = None
    try:
        # Accept nested lists as well as ndarrays.
        cm = np.asarray(cm)

        if class_names is None:
            class_names = [f'Class {i}' for i in range(len(cm))]

        # Normalize each row to percentages; guard rows summing to zero,
        # which previously divided by zero and produced NaN.
        counts = cm.astype('float')
        row_totals = counts.sum(axis=1, keepdims=True)
        cm_percent = np.divide(counts, row_totals,
                               out=np.zeros_like(counts),
                               where=row_totals != 0) * 100

        # Create annotations
        if show_percentages:
            annotations = np.array([[f'{count}\n({percent:.1f}%)'
                                     for count, percent in zip(row_counts, row_percents)]
                                    for row_counts, row_percents in zip(cm, cm_percent)])
        else:
            annotations = cm

        fig, ax = plt.subplots(figsize=figsize)

        # Create heatmap; colors encode raw counts.
        sns.heatmap(cm,
                    annot=annotations,
                    fmt='',
                    cmap='Blues',
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'label': 'Count'},
                    xticklabels=class_names,
                    yticklabels=class_names,
                    ax=ax)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_ylabel('Actual', fontsize=12)
        ax.set_xlabel('Predicted', fontsize=12)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved confusion matrix to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating confusion matrix: {str(e)}")
        # The caller never receives this figure, so release it from pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
def create_precision_recall_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "Precision-Recall Curve",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create precision-recall curves for multiple models on one set of axes.

    Each model is drawn as a recall (x) vs. precision (y) line in its own
    color from seaborn's 'husl' palette, with its average precision shown in
    the legend label.

    Args:
        models_data: Dict of {model_name: (precision, recall, avg_precision)}.
            Note the tuple order: precision first, then recall.
        title: Plot title
        figsize: Figure size (width, height) in inches
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        colors = sns.color_palette('husl', n_colors=len(models_data))

        for i, (model_name, (precision, recall, avg_precision)) in enumerate(models_data.items()):
            ax.plot(recall, precision,
                    linewidth=2.5,
                    label=f'{model_name} (AP = {avg_precision:.3f})',
                    color=colors[i])

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])  # small headroom so precision == 1.0 is not clipped
        ax.set_xlabel('Recall', fontsize=12)
        ax.set_ylabel('Precision', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='best', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out *this* figure, not whatever pyplot considers "current".
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved precision-recall curve to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating precision-recall curve: {str(e)}")
        # The caller receives None and cannot close the figure itself, so
        # release it here instead of leaving it registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
def create_feature_importance(
    feature_names: List[str],
    importances: np.ndarray,
    title: str = "Feature Importance",
    top_n: int = 20,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of the top feature importances.

    Features are sorted by importance (descending) and the ``top_n`` largest
    are drawn with the most important at the top. Bars are green for values
    >= 0 and red for negative values (e.g. signed coefficients).

    Args:
        feature_names: List of feature names, positionally aligned with
            ``importances``
        importances: Importance values. Any array-like is accepted (ndarray,
            list, pandas Series, or a single-row 2-D array); it is flattened
            and indexed by position.
        title: Plot title
        top_n: Number of top features to show
        figsize: Figure size; the width is used as-is and the height is the
            minimum height, grown by 0.4 inch per bar when more room is needed
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> importances = model.feature_importances_
        >>> fig = create_feature_importance(feature_names, importances, top_n=15)
    """
    fig = None
    try:
        # Positional, 1-D view of the scores: a pandas Series would otherwise
        # be indexed by label below, and lists don't support fancy indexing.
        importances = np.asarray(importances).ravel()

        # Sort by importance (descending) and keep the top_n entries
        indices = np.argsort(importances)[::-1][:top_n]
        sorted_features = [feature_names[i] for i in indices]
        sorted_importances = importances[indices]

        # Size the figure by the number of bars actually drawn, which may be
        # fewer than top_n when there are few features.
        height = max(figsize[1], len(indices) * 0.4)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Color bars by sign (only matters if any values are negative)
        colors = ['green' if x >= 0 else 'red' for x in sorted_importances]

        # Create horizontal bar chart
        y_pos = np.arange(len(sorted_features))
        ax.barh(y_pos, sorted_importances, color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(sorted_features)
        ax.invert_yaxis()  # Top features at top
        ax.set_xlabel('Importance Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved feature importance to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating feature importance plot: {str(e)}")
        # Release the half-built figure; the caller only gets None.
        if fig is not None:
            plt.close(fig)
        return None
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
def create_residual_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str = "Residual Plot",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a two-panel residual diagnostic for a regression model.

    Left panel: predicted vs. actual values with a y = x "perfect prediction"
    reference line. Right panel: residuals (actual - predicted) vs. predicted
    values with a zero reference line.

    Args:
        y_true: True target values (any array-like; flattened to 1-D)
        y_pred: Predicted values (any array-like, e.g. an (n, 1) column;
            flattened to 1-D)
        title: Overall figure title
        figsize: Size of ONE panel; the figure is twice as wide to hold both
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        # Flatten both inputs positionally. Without this, an (n, 1) prediction
        # column minus an (n,) target broadcasts to an (n, n) matrix, and two
        # pandas Series are aligned on index labels instead of position.
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()
        residuals = y_true - y_pred

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(figsize[0] * 2, figsize[1]))

        # Plot 1: Predicted vs Actual
        ax1.scatter(y_true, y_pred, alpha=0.5, s=50, edgecolors='black', linewidth=0.5)

        # Add perfect prediction line spanning the range of both series
        min_val = min(y_true.min(), y_pred.min())
        max_val = max(y_true.max(), y_pred.max())
        ax1.plot([min_val, max_val], [min_val, max_val],
                 'r--', linewidth=2, label='Perfect Prediction')

        ax1.set_xlabel('Actual Values', fontsize=12)
        ax1.set_ylabel('Predicted Values', fontsize=12)
        ax1.set_title('Predicted vs Actual', fontsize=13, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3, linestyle='--')

        # Plot 2: Residuals vs Predicted
        ax2.scatter(y_pred, residuals, alpha=0.5, s=50,
                    color='steelblue', edgecolors='black', linewidth=0.5)
        ax2.axhline(y=0, color='red', linestyle='--', linewidth=2)

        ax2.set_xlabel('Predicted Values', fontsize=12)
        ax2.set_ylabel('Residuals', fontsize=12)
        ax2.set_title('Residuals vs Predicted', fontsize=13, fontweight='bold')
        ax2.grid(True, alpha=0.3, linestyle='--')

        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved residual plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating residual plot: {str(e)}")
        # Release the half-built figure; the caller only gets None.
        if fig is not None:
            plt.close(fig)
        return None
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def create_learning_curve(
    train_sizes: np.ndarray,
    train_scores_mean: np.ndarray,
    train_scores_std: np.ndarray,
    val_scores_mean: np.ndarray,
    val_scores_std: np.ndarray,
    title: str = "Learning Curve",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a learning curve showing training and validation scores.

    Each series is drawn as a line with markers plus a shaded band of
    mean ± one standard deviation.

    Args:
        train_sizes: Training set sizes (x-axis)
        train_scores_mean: Mean training scores, one per size
        train_scores_std: Std of training scores, one per size
        val_scores_mean: Mean validation scores, one per size
        val_scores_std: Std of validation scores, one per size
        title: Plot title
        figsize: Figure size
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    All array arguments may be any array-like (e.g. plain lists).

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        # The ± std bands need element-wise arithmetic, which plain lists lack.
        train_sizes = np.asarray(train_sizes)
        train_scores_mean = np.asarray(train_scores_mean, dtype=float)
        train_scores_std = np.asarray(train_scores_std, dtype=float)
        val_scores_mean = np.asarray(val_scores_mean, dtype=float)
        val_scores_std = np.asarray(val_scores_std, dtype=float)

        fig, ax = plt.subplots(figsize=figsize)

        # Plot training scores
        ax.plot(train_sizes, train_scores_mean, 'o-', color='blue',
                linewidth=2, markersize=8, label='Training Score')
        ax.fill_between(train_sizes,
                        train_scores_mean - train_scores_std,
                        train_scores_mean + train_scores_std,
                        alpha=0.2, color='blue')

        # Plot validation scores
        ax.plot(train_sizes, val_scores_mean, 'o-', color='orange',
                linewidth=2, markersize=8, label='Validation Score')
        ax.fill_between(train_sizes,
                        val_scores_mean - val_scores_std,
                        val_scores_mean + val_scores_std,
                        alpha=0.2, color='orange')

        ax.set_xlabel('Training Set Size', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='best', fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved learning curve to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating learning curve: {str(e)}")
        # Release the half-built figure; the caller only gets None.
        if fig is not None:
            plt.close(fig)
        return None
|
|
1024
|
+
|
|
1025
|
+
|
|
1026
|
+
# ============================================================================
|
|
1027
|
+
# DATA QUALITY PLOTS
|
|
1028
|
+
# ============================================================================
|
|
1029
|
+
|
|
1030
|
+
def create_missing_values_heatmap(
    df: pd.DataFrame,
    title: str = "Missing Values Heatmap",
    figsize: Tuple[int, int] = (12, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a heatmap showing the pattern of missing values in ``df``.

    The DataFrame is converted to a 0/1 matrix (1 = missing, 0 = present)
    and transposed, so each heatmap row is one column of ``df`` and the
    x-axis is the sample (row) position. Colors come from the 'RdYlGn_r'
    colormap; no colorbar is drawn.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # Binary matrix: 1 = missing, 0 = present
        missing_matrix = df.isnull().astype(int)

        # Transpose so features run down the y-axis
        sns.heatmap(missing_matrix.T,
                    cbar=False,
                    cmap='RdYlGn_r',
                    ax=ax,
                    yticklabels=df.columns)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Sample Index', fontsize=12)
        ax.set_ylabel('Features', fontsize=12)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved missing values heatmap to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating missing values heatmap: {str(e)}")
        # e.g. seaborn rejects an empty frame after the figure already exists;
        # close it so it is not leaked in pyplot's figure registry.
        if fig is not None:
            plt.close(fig)
        return None
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
def create_missing_values_bar(
    df: pd.DataFrame,
    title: str = "Missing Values by Column",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of the percentage of missing values per column.

    Only columns with at least one missing value are shown, sorted so the
    column with the most missing values is at the top. Bars are shaded with
    the 'Reds' colormap in proportion to their percentage and labelled with
    the value.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size; the width is used as-is and the height is the
            larger of 6 and 0.3 inch per bar
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None when no column has missing values
        (including an empty DataFrame) or plotting failed.
    """
    fig = None
    try:
        # Calculate missing percentages (an empty frame yields NaN, which the
        # > 0 filter below drops)
        missing_pct = (df.isnull().sum() / len(df) * 100).sort_values(ascending=False)
        missing_pct = missing_pct[missing_pct > 0]  # Only columns with missing values

        if len(missing_pct) == 0:
            print(" ℹ No missing values found")
            return None

        height = max(6, len(missing_pct) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Create horizontal bar chart, shaded by fraction missing
        colors = plt.cm.Reds(missing_pct.values / 100)
        ax.barh(range(len(missing_pct)), missing_pct.values,
                color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(range(len(missing_pct)))
        ax.set_yticklabels(missing_pct.index)
        ax.invert_yaxis()  # Largest missing percentage at top
        ax.set_xlabel('Missing Values (%)', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Add percentage labels just right of each bar
        for i, v in enumerate(missing_pct.values):
            ax.text(v + 1, i, f'{v:.1f}%', va='center', fontsize=10)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved missing values bar chart to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating missing values bar chart: {str(e)}")
        # Release the half-built figure; the caller only gets None.
        if fig is not None:
            plt.close(fig)
        return None
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
def create_outlier_detection_boxplot(
    df: pd.DataFrame,
    columns: Optional[List[str]] = None,
    title: str = "Outlier Detection",
    figsize: Tuple[int, int] = (12, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create box plots for outlier detection across multiple columns.

    Thin wrapper that selects columns and delegates drawing to
    ``create_boxplot``.

    Args:
        df: DataFrame with numeric columns
        columns: Columns to plot. None selects the first 10 numeric columns.
        title: Plot title
        figsize: Figure size
        save_path: Optional save path, forwarded to ``create_boxplot``

    Returns:
        Whatever ``create_boxplot`` returns (a matplotlib Figure on success),
        or None when there are no columns to plot or an error occurred.
    """
    try:
        if columns is None:
            # Cap the default selection at 10 columns to keep the plot readable
            columns = df.select_dtypes(include=[np.number]).columns.tolist()[:10]

        # len() rather than truthiness: a pandas Index has no unambiguous bool.
        if len(columns) == 0:
            print(" ℹ No numeric columns to analyze")
            return None

        return create_boxplot(df[columns], title=title, figsize=figsize, save_path=save_path)

    except Exception as e:
        print(f" ✗ Error creating outlier detection plot: {str(e)}")
        return None
|
|
1164
|
+
|
|
1165
|
+
|
|
1166
|
+
def create_skewness_plot(
    df: pd.DataFrame,
    title: str = "Feature Skewness Distribution",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of the skewness of each numeric feature.

    Skewness is computed with ``DataFrame.skew``. Columns whose skewness is
    undefined (pandas returns NaN) are omitted. Bars are sorted with the most
    positively skewed feature at the top and colored by magnitude:
    green for |skew| < 0.5, orange for 0.5 <= |skew| < 1, red otherwise.
    Dashed guides mark skew = ±0.5.

    Args:
        df: DataFrame with numeric columns
        title: Plot title
        figsize: Figure size; the width is used as-is and the height is the
            larger of 6 and 0.3 inch per bar
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None when there is nothing to plot or
        plotting failed.
    """
    fig = None
    try:
        # Calculate skewness for numeric columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) == 0:
            print(" ℹ No numeric columns to analyze")
            return None

        # NaN skewness would otherwise show up as an empty, red-colored bar.
        skewness = df[numeric_cols].skew().dropna().sort_values(ascending=False)

        if len(skewness) == 0:
            print(" ℹ No numeric columns with a defined skewness")
            return None

        height = max(6, len(skewness) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Color by skewness level
        colors = ['green' if abs(x) < 0.5 else 'orange' if abs(x) < 1 else 'red'
                  for x in skewness.values]

        ax.barh(range(len(skewness)), skewness.values,
                color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(range(len(skewness)))
        ax.set_yticklabels(skewness.index)
        ax.invert_yaxis()  # Most positively skewed feature at top
        ax.set_xlabel('Skewness', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.axvline(x=0, color='black', linestyle='-', linewidth=1)
        ax.axvline(x=-0.5, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax.axvline(x=0.5, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Add legend explaining the color thresholds
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='green', label='Low (|skew| < 0.5)'),
            Patch(facecolor='orange', label='Moderate (0.5 ≤ |skew| < 1)'),
            Patch(facecolor='red', label='High (|skew| ≥ 1)')
        ]
        ax.legend(handles=legend_elements, loc='best')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved skewness plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating skewness plot: {str(e)}")
        # Release the half-built figure; the caller only gets None.
        if fig is not None:
            plt.close(fig)
        return None
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
# ============================================================================
|
|
1235
|
+
# UTILITY FUNCTIONS
|
|
1236
|
+
# ============================================================================
|
|
1237
|
+
|
|
1238
|
+
def save_figure(fig: plt.Figure, path: str, dpi: int = 300) -> None:
    """
    Write ``fig`` to ``path``, creating any missing parent directories first.

    matplotlib picks the output format from the file extension (.png, .jpg,
    .pdf, .svg, ...). The image is written with a tight bounding box on a
    white background. Any failure is reported on stdout rather than raised,
    so saving is best-effort for the caller.

    Args:
        fig: Figure to write
        path: Destination file path
        dpi: Output resolution in dots per inch
    """
    try:
        destination = Path(path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=dpi, bbox_inches='tight', facecolor='white')
        print(f" ✓ Saved figure to {path}")
    except Exception as err:
        print(f" ✗ Error saving figure: {err!s}")
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
def close_figure(fig: plt.Figure) -> None:
    """
    Release a figure's resources via ``plt.close``.

    ``None`` is accepted and ignored, so the return value of a plotting
    helper that may have failed (and returned None) can be passed in directly.

    Args:
        fig: Figure to close, or None
    """
    if fig is None:
        return
    plt.close(fig)
|
|
1264
|
+
|
|
1265
|
+
|
|
1266
|
+
def create_subplots_grid(
    plot_data: List[Dict[str, Any]],
    rows: int,
    cols: int,
    figsize: Tuple[int, int] = (15, 12),
    title: str = "Plot Grid",
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a rows x cols grid of simple subplots from a list of specs.

    Specs fill the grid in row-major order. Recognised spec keys:
        'type':  'scatter' (default), 'hist' or 'line'
        'x', 'y': data for 'scatter' and 'line'
        'data':  data for 'hist' (drawn with 30 bins)
        'title': subplot title (default 'Subplot <n>')
    Grid cells without a spec are hidden. Specs beyond rows*cols, and specs
    with an unknown 'type', are reported on stdout.

    Args:
        plot_data: List of dicts with plot specifications
        rows: Number of rows
        cols: Number of columns
        figsize: Figure size
        title: Overall title
        save_path: Optional path; when given the figure is also written there
            at 300 dpi.

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> plots = [
        ...     {'type': 'scatter', 'x': x1, 'y': y1, 'title': 'Plot 1'},
        ...     {'type': 'hist', 'data': data1, 'title': 'Plot 2'}
        ... ]
        >>> fig = create_subplots_grid(plots, 2, 2)
    """
    fig = None
    try:
        # squeeze=False always returns a 2-D array of Axes, so a 1x1, 1xN or
        # NxM grid all flatten the same way.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.ravel()

        if len(plot_data) > len(axes):
            print(f" ⚠ {len(plot_data) - len(axes)} plot spec(s) do not fit in the "
                  f"{rows}x{cols} grid and will not be drawn")

        for i, (ax, plot_spec) in enumerate(zip(axes, plot_data)):
            plot_type = plot_spec.get('type', 'scatter')

            if plot_type == 'scatter':
                ax.scatter(plot_spec['x'], plot_spec['y'], alpha=0.6)
            elif plot_type == 'hist':
                ax.hist(plot_spec['data'], bins=30, edgecolor='black')
            elif plot_type == 'line':
                ax.plot(plot_spec['x'], plot_spec['y'])
            else:
                print(f" ⚠ Unknown plot type '{plot_type}' for subplot {i + 1}; left empty")

            ax.set_title(plot_spec.get('title', f'Subplot {i+1}'), fontweight='bold')
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for ax in axes[len(plot_data):]:
            ax.axis('off')

        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved subplot grid to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating subplot grid: {str(e)}")
        # Release the half-built figure; the caller only gets None.
        if fig is not None:
            plt.close(fig)
        return None
|