ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,1327 @@
1
+ """
2
+ Matplotlib + Seaborn Visualization Engine
3
+ Production-quality visualizations that work reliably with Gradio UI.
4
+
5
+ All functions return matplotlib Figure objects (not file paths).
6
+ Designed for publication-quality plots with professional styling.
7
+ """
8
+
9
+ import matplotlib
10
+ matplotlib.use('Agg') # Use non-interactive backend for Gradio compatibility
11
+
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ import numpy as np
15
+ import pandas as pd
16
+ from typing import Dict, Any, List, Optional, Tuple, Union
17
+ from pathlib import Path
18
+ import warnings
19
+
20
# Global plotting defaults, applied once when this module is imported.
# NOTE(review): filterwarnings('ignore') silences *every* warning in the whole
# process, not just those raised here; consider scoping it to this module.
warnings.filterwarnings('ignore')

# Seaborn's whitegrid theme goes first; the rcParams below override it.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
})
32
+
33
+
34
+ # ============================================================================
35
+ # BASIC PLOTS
36
+ # ============================================================================
37
+
38
def create_scatter_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[np.ndarray, pd.Series, list],
    hue: Optional[Union[np.ndarray, pd.Series, list]] = None,
    size: Optional[Union[np.ndarray, pd.Series, list]] = None,
    title: str = "Scatter Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    alpha: float = 0.6,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a professional scatter plot with optional color coding and size variation.

    Args:
        x: X-axis data
        y: Y-axis data
        hue: Optional categorical data for color coding; one legend entry is
            drawn per unique value.
        size: Optional numeric marker sizes, one per point (same length as x)
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        alpha: Point transparency (0-1)
        save_path: Optional path to save a PNG file (saved at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> fig = create_scatter_plot(df['feature1'], df['target'],
        ...                           hue=df['category'], title='Feature vs Target')
        >>> # Display in Gradio: gr.Plot(value=fig)
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        x = np.asarray(x)
        y = np.asarray(y)
        # Convert once so each hue group below can simply be masked.
        sizes = None if size is None else np.asarray(size)
        point_style = dict(alpha=alpha, edgecolors='black', linewidth=0.5)

        if hue is not None:
            hue = np.asarray(hue)
            unique_hues = np.unique(hue)
            colors = sns.color_palette('Set2', n_colors=len(unique_hues))

            for color, hue_val in zip(colors, unique_hues):
                mask = hue == hue_val
                ax.scatter(x[mask], y[mask],
                           c=[color],
                           s=50 if sizes is None else sizes[mask],
                           label=str(hue_val),
                           **point_style)
            ax.legend(title='Category', loc='best', framealpha=0.9)
        else:
            ax.scatter(x, y,
                       c='steelblue',
                       s=50 if sizes is None else sizes,
                       **point_style)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out *this* figure, not whatever pyplot considers current.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved scatter plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating scatter plot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot (memory leak).
        if fig is not None:
            plt.close(fig)
        return None
120
+
121
+
122
def create_line_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[Dict[str, np.ndarray], np.ndarray, pd.Series, list],
    title: str = "Line Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    markers: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a line plot (supports multiple lines via dict).

    Args:
        x: X-axis data shared by every line
        y: Y-axis data, or a dict {'label1': y1, 'label2': y2} for one
            labelled line per entry
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        markers: Show circle markers on the lines
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        x = np.asarray(x)
        line_style = dict(marker='o' if markers else None, linewidth=2, markersize=6)

        if isinstance(y, dict):
            colors = sns.color_palette('husl', n_colors=len(y))
            for color, (label, y_data) in zip(colors, y.items()):
                ax.plot(x, np.asarray(y_data), label=label, color=color, **line_style)
            # An empty dict has no labelled artists; skip the empty legend.
            if y:
                ax.legend(loc='best', framealpha=0.9)
        else:
            ax.plot(x, np.asarray(y), color='steelblue', **line_style)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved line plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating line plot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
186
+
187
+
188
def create_bar_chart(
    categories: Union[list, np.ndarray],
    values: Union[np.ndarray, pd.Series, list],
    title: str = "Bar Chart",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    color: str = 'steelblue',
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a bar chart (vertical or horizontal).

    Args:
        categories: Category names
        values: Values for each category
        title: Plot title
        xlabel: Label for the category axis (drawn on y when horizontal)
        ylabel: Label for the value axis (drawn on x when horizontal)
        figsize: Figure size
        horizontal: If True, create horizontal bars
        color: Bar color
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        categories = list(categories)
        values = np.asarray(values)
        bar_style = dict(color=color, edgecolor='black', linewidth=0.7)

        if horizontal:
            ax.barh(categories, values, **bar_style)
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.bar(categories, values, **bar_style)
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            # Only vertical bars put category names on the x axis, so only
            # they need rotating when there are many of them. (Horizontal bars
            # carry numeric ticks on x, which must not be rotated.)
            if len(categories) > 10:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        # Grid lines run along the value axis.
        ax.grid(True, alpha=0.3, linestyle='--', axis='x' if horizontal else 'y')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved bar chart to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating bar chart: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
248
+
249
+
250
def create_histogram(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Histogram",
    xlabel: str = "Value",
    ylabel: str = "Frequency",
    bins: int = 30,
    figsize: Tuple[int, int] = (10, 6),
    kde: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a histogram with optional KDE overlay.

    Args:
        data: Numeric data to plot; None/NaN entries are dropped.
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label. When kde=True the bars are density-normalised,
            so the default "Frequency" label is shown as "Density" instead.
        bins: Number of bins
        figsize: Figure size
        kde: Normalise the histogram to a density and overlay a kernel
            density estimate on the same axis.
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if there is no valid data or
        plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        # dtype=float maps None to NaN, so object-dtype input can be cleaned
        # (np.isnan raises TypeError on object arrays).
        values = np.asarray(data, dtype=float)
        values = values[~np.isnan(values)]

        # Validate before creating the figure so the early return can't leak one.
        if values.size == 0:
            print(" ✗ No valid data for histogram")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        ax.hist(values, bins=bins, color='steelblue',
                edgecolor='black', alpha=0.7, density=kde)

        if kde:
            # The bars are densities here, so the KDE shares their y scale.
            # A twin axis would autoscale independently and visually misalign
            # the curve with the bars.
            sns.kdeplot(values, ax=ax, color='darkred', linewidth=2, label='KDE')
            ax.legend(loc='upper right')
            if ylabel == "Frequency":
                ylabel = "Density"

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved histogram to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating histogram: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
312
+
313
+
314
def create_boxplot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Box Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create notched box plots (with means) for multiple columns/categories.

    Args:
        data: Dictionary of {column_name: values} or a wide DataFrame (one box
            per column). Missing values (NaN/None) are dropped per box.
        title: Plot title
        xlabel: Label for the category axis
        ylabel: Label for the value axis
        figsize: Figure size
        horizontal: If True, create horizontal boxplots
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if the input is not a DataFrame/dict
        or plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        # Validate and clean the input before creating a figure, so a bad
        # input cannot leave an orphaned figure behind.
        if isinstance(data, pd.DataFrame):
            data_to_plot = [data[col].dropna() for col in data.columns]
            labels = list(data.columns)
        elif isinstance(data, dict):
            data_to_plot = []
            for values in data.values():
                # dtype=float maps None to NaN so np.isnan works on object input.
                arr = np.asarray(values, dtype=float)
                data_to_plot.append(arr[~np.isnan(arr)])
            labels = list(data.keys())
        else:
            raise ValueError("Data must be DataFrame or dict")

        fig, ax = plt.subplots(figsize=figsize)

        bp = ax.boxplot(data_to_plot,
                        vert=not horizontal,
                        patch_artist=True,
                        notch=True,
                        showmeans=True)

        # Label the category axis explicitly: boxplot's `labels=` keyword was
        # renamed to `tick_labels` in Matplotlib 3.9, while setting the tick
        # labels afterwards works on both sides of that rename.
        if horizontal:
            ax.set_yticklabels(labels)
        else:
            ax.set_xticklabels(labels)

        # Styling
        for patch in bp['boxes']:
            patch.set_facecolor('lightblue')
            patch.set_alpha(0.7)
        for whisker in bp['whiskers']:
            whisker.set(linewidth=1.5, color='gray')
        for cap in bp['caps']:
            cap.set(linewidth=1.5, color='gray')
        for median in bp['medians']:
            median.set(linewidth=2, color='darkred')
        for mean in bp['means']:
            mean.set(marker='D', markerfacecolor='green', markersize=6)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        if horizontal:
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            if len(labels) > 8:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Grid lines run along the value axis.
        ax.grid(True, alpha=0.3, linestyle='--', axis='x' if horizontal else 'y')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved boxplot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating boxplot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
397
+
398
+
399
+ # ============================================================================
400
+ # STATISTICAL PLOTS
401
+ # ============================================================================
402
+
403
def create_correlation_heatmap(
    data: Union[pd.DataFrame, np.ndarray],
    columns: Optional[List[str]] = None,
    title: str = "Correlation Heatmap",
    figsize: Tuple[int, int] = (12, 10),
    annot: bool = True,
    cmap: str = 'RdBu_r',
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a lower-triangle correlation heatmap with annotations.

    Args:
        data: Either a DataFrame, whose numeric (and boolean) columns are
            correlated, or an already computed square correlation matrix.
        columns: Row/column names used when data is an np.ndarray
        title: Plot title
        figsize: Figure size
        annot: Show correlation values (2 decimals) in each cell
        cmap: Colormap (diverging, centered at 0, fixed to [-1, 1])
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> fig = create_correlation_heatmap(df[numeric_cols])
    """
    fig = None
    try:
        if isinstance(data, pd.DataFrame):
            # Correlation is only defined for numeric data; pandas >= 2 raises
            # on mixed-dtype frames instead of silently dropping text columns.
            corr_matrix = data.select_dtypes(include=['number', 'bool']).corr()
        else:
            corr_matrix = pd.DataFrame(data, columns=columns, index=columns)

        # Hide the upper triangle (and diagonal): it mirrors the lower one.
        mask = np.triu(np.ones(corr_matrix.shape, dtype=bool))

        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(corr_matrix,
                    mask=mask,
                    annot=annot,
                    fmt='.2f',
                    cmap=cmap,
                    center=0,
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'shrink': 0.8, 'label': 'Correlation'},
                    ax=ax,
                    vmin=-1,
                    vmax=1)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved correlation heatmap to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating correlation heatmap: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
467
+
468
+
469
def create_distribution_plot(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Distribution Plot",
    xlabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_rug: bool = False,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a distribution plot: histogram + KDE, mean/median lines and a
    summary-statistics box.

    Args:
        data: Numeric data to plot; None/NaN entries are dropped.
        title: Plot title
        xlabel: X-axis label
        figsize: Figure size
        show_rug: Show rug plot (data points on x-axis)
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if there is no valid data or
        plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        # dtype=float maps None to NaN so object-dtype input can be cleaned.
        values = np.asarray(data, dtype=float)
        values = values[~np.isnan(values)]

        # Validate before creating the figure so the early return can't leak one.
        if values.size == 0:
            print(" ✗ No valid data for distribution plot")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        sns.histplot(values, kde=True, ax=ax, color='steelblue',
                     edgecolor='black', alpha=0.6, bins=30)

        if show_rug:
            sns.rugplot(values, ax=ax, color='darkred', alpha=0.5, height=0.05)

        mean_val = np.mean(values)
        median_val = np.median(values)
        std_val = np.std(values)  # population std (ddof=0)

        stats_text = f'Mean: {mean_val:.2f}\nMedian: {median_val:.2f}\nStd: {std_val:.2f}'
        ax.text(0.98, 0.98, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                fontsize=10)

        ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label='Mean')
        ax.axvline(median_val, color='green', linestyle='--', linewidth=2, label='Median')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel('Frequency / Density', fontsize=12)
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved distribution plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating distribution plot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
541
+
542
+
543
def create_violin_plot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Violin Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create violin plots showing distribution for multiple categories.

    Args:
        data: One of
            - dict {category: values};
            - long-format DataFrame with 'Category' and 'Value' columns;
            - wide DataFrame with one column per category (the same shape
              create_boxplot accepts), which is melted to long format.
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if the input is unsupported or
        plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        if isinstance(data, dict):
            # Seaborn wants long format: one row per observation.
            frames = [
                pd.DataFrame({'Category': [key] * len(values), 'Value': values})
                for key, values in data.items()
            ]
            plot_df = pd.concat(frames, ignore_index=True)
        elif isinstance(data, pd.DataFrame):
            if {'Category', 'Value'}.issubset(data.columns):
                plot_df = data
            else:
                # Wide frame: each column is a category.
                plot_df = data.melt(var_name='Category', value_name='Value')
        else:
            raise ValueError("Data must be DataFrame or dict")

        fig, ax = plt.subplots(figsize=figsize)

        sns.violinplot(data=plot_df, x='Category', y='Value', ax=ax,
                       palette='Set2', inner='box')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)

        if len(plot_df['Category'].unique()) > 8:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.grid(True, alpha=0.3, linestyle='--', axis='y')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved violin plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating violin plot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
603
+
604
+
605
def create_pairplot(
    data: pd.DataFrame,
    hue: Optional[str] = None,
    title: str = "Pair Plot",
    figsize: Tuple[int, int] = (12, 12),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a lower-triangle pairplot (scatterplot matrix, KDE diagonal).

    Args:
        data: DataFrame with features to plot
        hue: Column name for color coding; ignored if not a column of data
        title: Plot title
        figsize: Overall figure size (width, height) in inches
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        grid_kwargs = {'diag_kind': 'kde', 'corner': True}
        # A palette only has an effect when a hue column is mapped.
        if hue and hue in data.columns:
            grid_kwargs.update(hue=hue, palette='Set2')
        pair_grid = sns.pairplot(data, **grid_kwargs)

        # Newer seaborn exposes the figure as `.figure` and deprecates `.fig`;
        # fall back to `.fig` for older releases.
        fig = pair_grid.figure if hasattr(pair_grid, 'figure') else pair_grid.fig
        # pairplot sizes itself per subplot and has no figsize argument, so
        # apply the requested overall size explicitly.
        fig.set_size_inches(*figsize)
        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.01)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved pairplot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating pairplot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
647
+
648
+
649
+ # ============================================================================
650
+ # MACHINE LEARNING PLOTS
651
+ # ============================================================================
652
+
653
def create_roc_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "ROC Curve Comparison",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create ROC curves for multiple models on the same plot.

    Args:
        models_data: Dict of {model_name: (fpr, tpr, auc_score)}; the AUC is
            only used for the legend text (3 decimals).
        title: Plot title
        figsize: Figure size
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> from sklearn.metrics import roc_curve, auc
        >>> fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        >>> auc_score = auc(fpr, tpr)
        >>> models = {'Random Forest': (fpr, tpr, auc_score)}
        >>> fig = create_roc_curve(models)
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        colors = sns.color_palette('husl', n_colors=len(models_data))
        for color, (model_name, (fpr, tpr, auc_score)) in zip(colors, models_data.items()):
            ax.plot(fpr, tpr,
                    linewidth=2.5,
                    label=f'{model_name} (AUC = {auc_score:.3f})',
                    color=color)

        # Diagonal reference line: a classifier that guesses at random.
        ax.plot([0, 1], [0, 1],
                linestyle='--',
                linewidth=2,
                color='gray',
                label='Random Classifier (AUC = 0.500)')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='lower right', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved ROC curve to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating ROC curve: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
714
+
715
+
716
def create_confusion_matrix(
    cm: np.ndarray,
    class_names: Optional[List[str]] = None,
    title: str = "Confusion Matrix",
    figsize: Tuple[int, int] = (10, 8),
    show_percentages: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a confusion matrix heatmap with annotations.

    Args:
        cm: Square confusion matrix, rows = actual, columns = predicted (as
            returned by sklearn.metrics.confusion_matrix); lists are accepted.
        class_names: Names of classes (defaults to 'Class 0', 'Class 1', ...)
        title: Plot title
        figsize: Figure size
        show_percentages: Also show each cell as a percentage of its row
            (i.e. of the actual class). Rows with no samples show 0.0%.
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> from sklearn.metrics import confusion_matrix
        >>> cm = confusion_matrix(y_true, y_pred)
        >>> fig = create_confusion_matrix(cm, class_names=['Class 0', 'Class 1'])
    """
    fig = None
    try:
        cm = np.asarray(cm)

        if class_names is None:
            class_names = [f'Class {i}' for i in range(len(cm))]

        # Row-normalise to percentages. A class with no actual samples has a
        # zero row total; report 0% there instead of dividing by zero ("nan%").
        row_totals = cm.sum(axis=1, keepdims=True).astype(float)
        cm_percent = np.divide(cm * 100.0, row_totals,
                               out=np.zeros(cm.shape, dtype=float),
                               where=row_totals != 0)

        if show_percentages:
            annotations = np.array([
                [f'{count}\n({percent:.1f}%)'
                 for count, percent in zip(row_counts, row_percents)]
                for row_counts, row_percents in zip(cm, cm_percent)
            ])
        else:
            annotations = cm

        fig, ax = plt.subplots(figsize=figsize)
        # Colours follow raw counts; fmt='' because annotations are preformatted.
        sns.heatmap(cm,
                    annot=annotations,
                    fmt='',
                    cmap='Blues',
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'label': 'Count'},
                    xticklabels=class_names,
                    yticklabels=class_names,
                    ax=ax)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_ylabel('Actual', fontsize=12)
        ax.set_xlabel('Predicted', fontsize=12)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved confusion matrix to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating confusion matrix: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
786
+
787
+
788
def create_precision_recall_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "Precision-Recall Curve",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create precision-recall curves for multiple models.

    Args:
        models_data: Dict of {model_name: (precision, recall, avg_precision)};
            note the order: precision first, recall second. avg_precision is
            only used for the legend text (3 decimals).
        title: Plot title
        figsize: Figure size
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        colors = sns.color_palette('husl', n_colors=len(models_data))
        for color, (model_name, (precision, recall, avg_precision)) in zip(colors, models_data.items()):
            # Recall on x, precision on y (the conventional PR orientation).
            ax.plot(recall, precision,
                    linewidth=2.5,
                    label=f'{model_name} (AP = {avg_precision:.3f})',
                    color=color)

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('Recall', fontsize=12)
        ax.set_ylabel('Precision', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='best', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved precision-recall curve to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating precision-recall curve: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
835
+
836
+
837
def create_feature_importance(
    feature_names: List[str],
    importances: np.ndarray,
    title: str = "Feature Importance",
    top_n: int = 20,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of the top feature importances.

    Args:
        feature_names: List of feature names, aligned with importances
        importances: Importance values (array or list). Features are ranked by
            raw value, highest first; NaN values rank last.
        title: Plot title
        top_n: Number of top features to show
        figsize: (width, minimum height); the height grows by ~0.4 in per bar
            when more room is needed.
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if plotting failed (the error is
        printed and any partially built figure is closed).

    Example:
        >>> importances = model.feature_importances_
        >>> fig = create_feature_importance(feature_names, importances, top_n=15)
    """
    fig = None
    try:
        importances = np.asarray(importances, dtype=float)
        # Sort descending by negating rather than reversing an ascending sort:
        # argsort puts NaN last, so reversing would rank NaN features as the
        # most important. The stable sort keeps ties in input order.
        indices = np.argsort(-importances, kind='stable')[:top_n]
        sorted_features = [feature_names[i] for i in indices]
        sorted_importances = importances[indices]

        # Size for the bars actually drawn, never below the requested height.
        height = max(figsize[1], len(sorted_features) * 0.4)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Negative importances (e.g. from permutation importance) show in red.
        colors = ['green' if value >= 0 else 'red' for value in sorted_importances]

        y_pos = np.arange(len(sorted_features))
        ax.barh(y_pos, sorted_importances, color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(sorted_features)
        ax.invert_yaxis()  # most important feature at the top
        ax.set_xlabel('Importance Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved feature importance to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating feature importance plot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
897
+
898
+
899
def create_residual_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str = "Residual Plot",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create regression diagnostics side by side: Predicted vs Actual (with a
    perfect-prediction diagonal) and Residuals vs Predicted.

    Args:
        y_true: True target values (array, list or Series)
        y_pred: Predicted values; an (n, 1) column is flattened to (n,)
        title: Overall figure title
        figsize: Size of ONE panel; the figure is twice as wide.
        save_path: Optional save path (PNG at 300 dpi)

    Returns:
        matplotlib Figure object, or None if the inputs differ in length or
        plotting failed (the error is printed and any figure is closed).
    """
    fig = None
    try:
        # Flatten to 1-D ndarrays: an (n, 1) prediction column would otherwise
        # broadcast against an (n,) target into an (n, n) residual matrix, and
        # two pandas Series would align on their indexes instead of position.
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_pred = np.asarray(y_pred, dtype=float).ravel()
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have the same length "
                f"(got {y_true.size} and {y_pred.size})"
            )

        residuals = y_true - y_pred

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(figsize[0] * 2, figsize[1]))

        # Panel 1: Predicted vs Actual
        ax1.scatter(y_true, y_pred, alpha=0.5, s=50, edgecolors='black', linewidth=0.5)

        # Perfect-prediction diagonal spanning both series (NaNs ignored).
        both = np.concatenate([y_true, y_pred])
        min_val, max_val = np.nanmin(both), np.nanmax(both)
        ax1.plot([min_val, max_val], [min_val, max_val],
                 'r--', linewidth=2, label='Perfect Prediction')

        ax1.set_xlabel('Actual Values', fontsize=12)
        ax1.set_ylabel('Predicted Values', fontsize=12)
        ax1.set_title('Predicted vs Actual', fontsize=13, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3, linestyle='--')

        # Panel 2: Residuals vs Predicted
        ax2.scatter(y_pred, residuals, alpha=0.5, s=50,
                    color='steelblue', edgecolors='black', linewidth=0.5)
        ax2.axhline(y=0, color='red', linestyle='--', linewidth=2)

        ax2.set_xlabel('Predicted Values', fontsize=12)
        ax2.set_ylabel('Residuals', fontsize=12)
        ax2.set_title('Residuals vs Predicted', fontsize=13, fontweight='bold')
        ax2.grid(True, alpha=0.3, linestyle='--')

        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved residual plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating residual plot: {str(e)}")
        # Don't leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None
961
+
962
+
963
def create_learning_curve(
    train_sizes: np.ndarray,
    train_scores_mean: np.ndarray,
    train_scores_std: np.ndarray,
    val_scores_mean: np.ndarray,
    val_scores_std: np.ndarray,
    title: str = "Learning Curve",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a learning curve showing training and validation scores.

    Each curve is drawn as a line with markers, and a shaded band covers
    ``mean - std`` to ``mean + std``.

    Args:
        train_sizes: Training set sizes (array-like)
        train_scores_mean: Mean training scores (array-like)
        train_scores_std: Std of training scores (array-like)
        val_scores_mean: Mean validation scores (array-like)
        val_scores_std: Std of validation scores (array-like)
        title: Plot title
        figsize: Figure size
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if plotting or saving failed (the
        error is printed and any partially built figure is closed).
    """
    fig = None
    try:
        # Coerce to arrays so plain lists (e.g. loaded from JSON) support
        # the element-wise mean +/- std arithmetic below.
        train_sizes = np.asarray(train_sizes)
        train_scores_mean = np.asarray(train_scores_mean, dtype=float)
        train_scores_std = np.asarray(train_scores_std, dtype=float)
        val_scores_mean = np.asarray(val_scores_mean, dtype=float)
        val_scores_std = np.asarray(val_scores_std, dtype=float)

        fig, ax = plt.subplots(figsize=figsize)

        # Plot training scores
        ax.plot(train_sizes, train_scores_mean, 'o-', color='blue',
                linewidth=2, markersize=8, label='Training Score')
        ax.fill_between(train_sizes,
                        train_scores_mean - train_scores_std,
                        train_scores_mean + train_scores_std,
                        alpha=0.2, color='blue')

        # Plot validation scores
        ax.plot(train_sizes, val_scores_mean, 'o-', color='orange',
                linewidth=2, markersize=8, label='Validation Score')
        ax.fill_between(train_sizes,
                        val_scores_mean - val_scores_std,
                        val_scores_mean + val_scores_std,
                        alpha=0.2, color='orange')

        ax.set_xlabel('Training Set Size', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='best', fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved learning curve to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating learning curve: {str(e)}")
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
1024
+
1025
+
1026
+ # ============================================================================
1027
+ # DATA QUALITY PLOTS
1028
+ # ============================================================================
1029
+
1030
def create_missing_values_heatmap(
    df: pd.DataFrame,
    title: str = "Missing Values Heatmap",
    figsize: Tuple[int, int] = (12, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a heatmap showing the missing-value pattern of a DataFrame.

    Rows of the heatmap are the DataFrame's columns and the x axis is the
    sample (row) position. A missing cell is drawn at the red end of the
    reversed RdYlGn colormap and a present cell at the green end.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if ``df`` is empty or plotting or
        saving failed (the error is printed and any partially built figure
        is closed).
    """
    fig = None
    try:
        if df.empty:
            print(" ℹ No data to plot")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create binary matrix (1 = missing, 0 = present)
        missing_matrix = df.isnull().astype(int)

        # Pin the colour scale to [0, 1]. If autoscaled, a matrix that is all
        # 0 or all 1 has vmin == vmax, and every cell maps to the same end of
        # the colormap, so fully-missing data would render as "present".
        sns.heatmap(missing_matrix.T,
                    cbar=False,
                    cmap='RdYlGn_r',
                    vmin=0,
                    vmax=1,
                    ax=ax,
                    yticklabels=df.columns)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Sample Index', fontsize=12)
        ax.set_ylabel('Features', fontsize=12)
        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved missing values heatmap to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating missing values heatmap: {str(e)}")
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
1075
+
1076
+
1077
def create_missing_values_bar(
    df: pd.DataFrame,
    title: str = "Missing Values by Column",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of the percentage of missing values per column.

    Only columns with at least one missing value are shown, sorted by
    descending percentage and labelled with their percentage. The figure
    height grows with the number of bars, with a minimum of 6.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size; only the width is used, the height is derived
            from the number of bars
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if ``df`` is empty, has no missing
        values, or plotting or saving failed (the error is printed and any
        partially built figure is closed).
    """
    fig = None
    try:
        # Guard the empty case explicitly. Otherwise the 0/0 below yields NaN
        # everywhere and the function wrongly reports "No missing values found".
        if df.empty:
            print(" ℹ No data to analyze")
            return None

        # Calculate missing percentages
        missing_pct = (df.isnull().sum() / len(df) * 100).sort_values(ascending=False)
        missing_pct = missing_pct[missing_pct > 0]  # Only columns with missing values

        if len(missing_pct) == 0:
            print(" ℹ No missing values found")
            return None

        height = max(6, len(missing_pct) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Darker red = larger share of missing values.
        colors = plt.cm.Reds(missing_pct.to_numpy() / 100)
        positions = np.arange(len(missing_pct))
        ax.barh(positions, missing_pct.values,
                color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(positions)
        ax.set_yticklabels(missing_pct.index)
        ax.set_xlabel('Missing Values (%)', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Add percentage labels
        for i, v in enumerate(missing_pct.values):
            ax.text(v + 1, i, f'{v:.1f}%', va='center', fontsize=10)

        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved missing values bar chart to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating missing values bar chart: {str(e)}")
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
1133
+
1134
+
1135
def create_outlier_detection_boxplot(
    df: pd.DataFrame,
    columns: Optional[List[str]] = None,
    title: str = "Outlier Detection",
    figsize: Tuple[int, int] = (12, 6),
    save_path: Optional[str] = None,
    max_columns: int = 10
) -> Optional[plt.Figure]:
    """
    Create box plots for outlier detection across multiple columns.

    Thin wrapper that selects columns and delegates to ``create_boxplot``.

    Args:
        df: DataFrame with numeric columns
        columns: Columns to plot. If None, the first ``max_columns`` numeric
            columns are used. An explicit list is used as given, with no cap.
        title: Plot title
        figsize: Figure size
        save_path: Optional save path
        max_columns: Cap on the number of auto-selected numeric columns
            (default 10); ignored when ``columns`` is given

    Returns:
        Whatever ``create_boxplot`` returns, or None if there are no columns
        to plot or an error occurred (the error is printed).
    """
    try:
        if columns is None:
            columns = df.select_dtypes(include=[np.number]).columns.tolist()[:max_columns]
        else:
            # Normalise to a list: df[tuple] is read as a single (MultiIndex)
            # key, and a pandas Index cannot be used in a truth test.
            columns = list(columns)

        if len(columns) == 0:
            print(" ℹ No numeric columns to plot")
            return None

        return create_boxplot(df[columns], title=title, figsize=figsize, save_path=save_path)

    except Exception as e:
        print(f" ✗ Error creating outlier detection plot: {str(e)}")
        return None
1164
+
1165
+
1166
def create_skewness_plot(
    df: pd.DataFrame,
    title: str = "Feature Skewness Distribution",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of the skewness of each numeric feature.

    Bars are sorted by descending skewness and coloured by magnitude:
    green for |skew| < 0.5, orange for 0.5 <= |skew| < 1, red for |skew| >= 1.
    Dashed guides mark +/-0.5. Columns whose skewness is undefined (pandas
    returns NaN, e.g. when there are too few non-null values) are omitted.

    Args:
        df: DataFrame with numeric columns
        title: Plot title
        figsize: Figure size; only the width is used, the height is derived
            from the number of bars
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if there is nothing to plot or
        plotting or saving failed (the error is printed and any partially
        built figure is closed).
    """
    from matplotlib.patches import Patch

    fig = None
    try:
        # Calculate skewness for numeric columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) == 0:
            print(" ℹ No numeric columns to analyze")
            return None

        # Drop NaN skewness values. Otherwise they would fall into the "High"
        # (red) colour bucket, because abs(nan) < x is False, and be drawn as
        # empty bars.
        skewness = df[numeric_cols].skew().dropna().sort_values(ascending=False)
        if len(skewness) == 0:
            print(" ℹ Skewness is undefined for all numeric columns")
            return None

        height = max(6, len(skewness) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Color by skewness level
        colors = ['green' if abs(x) < 0.5 else 'orange' if abs(x) < 1 else 'red'
                  for x in skewness.values]

        positions = np.arange(len(skewness))
        ax.barh(positions, skewness.values,
                color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(positions)
        ax.set_yticklabels(skewness.index)
        ax.set_xlabel('Skewness', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.axvline(x=0, color='black', linestyle='-', linewidth=1)
        ax.axvline(x=-0.5, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax.axvline(x=0.5, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Legend for the three colour buckets
        legend_elements = [
            Patch(facecolor='green', label='Low (|skew| < 0.5)'),
            Patch(facecolor='orange', label='Moderate (0.5 ≤ |skew| < 1)'),
            Patch(facecolor='red', label='High (|skew| ≥ 1)')
        ]
        ax.legend(handles=legend_elements, loc='best')

        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved skewness plot to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating skewness plot: {str(e)}")
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
1232
+
1233
+
1234
+ # ============================================================================
1235
+ # UTILITY FUNCTIONS
1236
+ # ============================================================================
1237
+
1238
def save_figure(fig: Optional[plt.Figure], path: str, dpi: int = 300) -> None:
    """
    Save a matplotlib figure to file on a best-effort basis.

    Missing parent directories are created, and the figure is written with a
    white face colour and a tight bounding box. Errors are printed, not
    raised. A ``None`` figure, which is what the plot builders in this module
    return on failure, is reported and skipped without touching the
    filesystem.

    Args:
        fig: Matplotlib Figure object, or None
        path: Output file path (supports .png, .jpg, .pdf, .svg)
        dpi: Resolution (dots per inch)
    """
    if fig is None:
        # Previously this created the directory and then failed with an
        # AttributeError on None.savefig.
        print(f" ✗ Error saving figure: no figure to save to {path}")
        return
    try:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=dpi, bbox_inches='tight', facecolor='white')
        print(f" ✓ Saved figure to {path}")
    except Exception as e:
        print(f" ✗ Error saving figure: {str(e)}")
1253
+
1254
+
1255
def close_figure(fig: plt.Figure) -> None:
    """
    Release a matplotlib figure through ``plt.close`` to free its memory.

    Passing ``None``, which is what the plot builders here return on failure,
    is a harmless no-op.

    Args:
        fig: Matplotlib Figure object, or None
    """
    if fig is None:
        return
    plt.close(fig)
1264
+
1265
+
1266
def create_subplots_grid(
    plot_data: List[Dict[str, Any]],
    rows: int,
    cols: int,
    figsize: Tuple[int, int] = (15, 12),
    title: str = "Plot Grid",
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a grid of simple subplots from a list of plot specifications.

    Specs fill the grid in row-major order. Supported ``'type'`` values are
    ``'scatter'`` (keys ``'x'``, ``'y'``; the default when ``'type'`` is
    absent), ``'hist'`` (key ``'data'``, 30 bins) and ``'line'`` (keys
    ``'x'``, ``'y'``). A spec with any other type gets an empty, titled
    panel. Unused grid cells are hidden. Specs beyond ``rows * cols`` are not
    drawn, and a notice is printed.

    Args:
        plot_data: List of dicts with plot specifications; optional
            ``'title'`` key, defaulting to ``'Subplot <n>'``
        rows: Number of rows
        cols: Number of columns
        figsize: Figure size
        title: Overall title
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if plotting or saving failed (the
        error is printed and any partially built figure is closed).

    Example:
        >>> plots = [
        ...     {'type': 'scatter', 'x': x1, 'y': y1, 'title': 'Plot 1'},
        ...     {'type': 'hist', 'data': data1, 'title': 'Plot 2'}
        ... ]
        >>> fig = create_subplots_grid(plots, 2, 2)
    """
    fig = None
    try:
        # squeeze=False always returns a 2-D array of Axes, so ravel() works
        # uniformly for 1x1, 1xN, Nx1 and NxM grids.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.ravel()

        overflow = len(plot_data) - len(axes)
        if overflow > 0:
            print(f" ℹ {overflow} plot spec(s) do not fit in a {rows}x{cols} grid and were not drawn")

        for i, (ax, plot_spec) in enumerate(zip(axes, plot_data)):
            plot_type = plot_spec.get('type', 'scatter')

            if plot_type == 'scatter':
                ax.scatter(plot_spec['x'], plot_spec['y'], alpha=0.6)
            elif plot_type == 'hist':
                ax.hist(plot_spec['data'], bins=30, edgecolor='black')
            elif plot_type == 'line':
                ax.plot(plot_spec['x'], plot_spec['y'])

            ax.set_title(plot_spec.get('title', f'Subplot {i+1}'), fontweight='bold')
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for ax in axes[len(plot_data):]:
            ax.axis('off')

        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved subplot grid to {save_path}")

        return fig

    except Exception as e:
        print(f" ✗ Error creating subplot grid: {str(e)}")
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None