ds-agent-cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/ds-agent.js +451 -0
- package/ds_agent/__init__.py +8 -0
- package/package.json +28 -0
- package/requirements.txt +126 -0
- package/setup.py +35 -0
- package/src/__init__.py +7 -0
- package/src/_compress_tool_result.py +118 -0
- package/src/api/__init__.py +4 -0
- package/src/api/app.py +1626 -0
- package/src/cache/__init__.py +5 -0
- package/src/cache/cache_manager.py +561 -0
- package/src/cli.py +2886 -0
- package/src/dynamic_prompts.py +281 -0
- package/src/orchestrator.py +4799 -0
- package/src/progress_manager.py +139 -0
- package/src/reasoning/__init__.py +332 -0
- package/src/reasoning/business_summary.py +431 -0
- package/src/reasoning/data_understanding.py +356 -0
- package/src/reasoning/model_explanation.py +383 -0
- package/src/reasoning/reasoning_trace.py +239 -0
- package/src/registry/__init__.py +3 -0
- package/src/registry/tools_registry.py +3 -0
- package/src/session_memory.py +448 -0
- package/src/session_store.py +370 -0
- package/src/storage/__init__.py +19 -0
- package/src/storage/artifact_store.py +620 -0
- package/src/storage/helpers.py +116 -0
- package/src/storage/huggingface_storage.py +694 -0
- package/src/storage/r2_storage.py +0 -0
- package/src/storage/user_files_service.py +288 -0
- package/src/tools/__init__.py +335 -0
- package/src/tools/advanced_analysis.py +823 -0
- package/src/tools/advanced_feature_engineering.py +708 -0
- package/src/tools/advanced_insights.py +578 -0
- package/src/tools/advanced_preprocessing.py +549 -0
- package/src/tools/advanced_training.py +906 -0
- package/src/tools/agent_tool_mapping.py +326 -0
- package/src/tools/auto_pipeline.py +420 -0
- package/src/tools/autogluon_training.py +1480 -0
- package/src/tools/business_intelligence.py +860 -0
- package/src/tools/cloud_data_sources.py +581 -0
- package/src/tools/code_interpreter.py +390 -0
- package/src/tools/computer_vision.py +614 -0
- package/src/tools/data_cleaning.py +614 -0
- package/src/tools/data_profiling.py +593 -0
- package/src/tools/data_type_conversion.py +268 -0
- package/src/tools/data_wrangling.py +433 -0
- package/src/tools/eda_reports.py +284 -0
- package/src/tools/enhanced_feature_engineering.py +241 -0
- package/src/tools/feature_engineering.py +302 -0
- package/src/tools/matplotlib_visualizations.py +1327 -0
- package/src/tools/model_training.py +520 -0
- package/src/tools/nlp_text_analytics.py +761 -0
- package/src/tools/plotly_visualizations.py +497 -0
- package/src/tools/production_mlops.py +852 -0
- package/src/tools/time_series.py +507 -0
- package/src/tools/tools_registry.py +2133 -0
- package/src/tools/visualization_engine.py +559 -0
- package/src/utils/__init__.py +42 -0
- package/src/utils/error_recovery.py +313 -0
- package/src/utils/parallel_executor.py +402 -0
- package/src/utils/polars_helpers.py +248 -0
- package/src/utils/schema_extraction.py +132 -0
- package/src/utils/semantic_layer.py +392 -0
- package/src/utils/token_budget.py +411 -0
- package/src/utils/validation.py +377 -0
- package/src/workflow_state.py +154 -0
|
@@ -0,0 +1,559 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comprehensive Visualization Engine (Matplotlib + Seaborn)
|
|
3
|
+
Automatically generate all relevant plots for data analysis and model evaluation.
|
|
4
|
+
|
|
5
|
+
All functions now return matplotlib Figure objects for Gradio compatibility.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import matplotlib
|
|
9
|
+
matplotlib.use('Agg') # Non-interactive backend for Gradio
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import seaborn as sns
|
|
13
|
+
import polars as pl
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
import sys
|
|
19
|
+
import os
|
|
20
|
+
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
|
|
21
|
+
|
|
22
|
+
# Add parent directory to path
|
|
23
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
24
|
+
|
|
25
|
+
from ds_agent.utils.polars_helpers import load_dataframe
|
|
26
|
+
from ds_agent.utils.validation import validate_file_exists
|
|
27
|
+
|
|
28
|
+
# Import matplotlib visualization functions
|
|
29
|
+
try:
|
|
30
|
+
from .matplotlib_visualizations import (
|
|
31
|
+
create_scatter_plot,
|
|
32
|
+
create_bar_chart,
|
|
33
|
+
create_histogram,
|
|
34
|
+
create_boxplot,
|
|
35
|
+
create_correlation_heatmap,
|
|
36
|
+
create_distribution_plot,
|
|
37
|
+
create_roc_curve,
|
|
38
|
+
create_confusion_matrix,
|
|
39
|
+
create_feature_importance,
|
|
40
|
+
create_residual_plot,
|
|
41
|
+
create_missing_values_heatmap,
|
|
42
|
+
create_missing_values_bar,
|
|
43
|
+
create_outlier_detection_boxplot,
|
|
44
|
+
save_figure,
|
|
45
|
+
close_figure
|
|
46
|
+
)
|
|
47
|
+
except ImportError:
|
|
48
|
+
# Fallback for direct execution
|
|
49
|
+
from matplotlib_visualizations import (
|
|
50
|
+
create_scatter_plot,
|
|
51
|
+
create_bar_chart,
|
|
52
|
+
create_histogram,
|
|
53
|
+
create_boxplot,
|
|
54
|
+
create_correlation_heatmap,
|
|
55
|
+
create_distribution_plot,
|
|
56
|
+
create_roc_curve,
|
|
57
|
+
create_confusion_matrix,
|
|
58
|
+
create_feature_importance,
|
|
59
|
+
create_residual_plot,
|
|
60
|
+
create_missing_values_heatmap,
|
|
61
|
+
create_missing_values_bar,
|
|
62
|
+
create_outlier_detection_boxplot,
|
|
63
|
+
save_figure,
|
|
64
|
+
close_figure
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
# Set global style
|
|
68
|
+
sns.set_style('whitegrid')
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def generate_all_plots(file_path: str,
                       target_col: Optional[str] = None,
                       output_dir: str = "./outputs/plots") -> Dict[str, Any]:
    """
    Generate ALL plots for a dataset automatically.

    Runs every plot category in sequence:

    - Data quality plots
    - EDA plots
    - Distribution plots

    Args:
        file_path: Path to dataset
        target_col: Optional target column
        output_dir: Directory to save plots

    Returns:
        Dictionary with Figure objects and saved file paths
    """
    validate_file_exists(file_path)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    results: Dict[str, Any] = {
        "output_directory": output_dir,
        "plots_generated": [],
        "figure_objects": [],  # Store Figure objects
        "plot_categories": {},
    }

    print(f"🎨 Generating comprehensive visualizations...")

    # Run each category generator in order and fold its output into the
    # aggregate result structure.
    category_runs = (
        ("data_quality", generate_data_quality_plots(file_path, output_dir)),
        ("eda", generate_eda_plots(file_path, target_col, output_dir)),
        ("distributions", generate_distribution_plots(file_path, output_dir)),
    )
    for category_name, category_result in category_runs:
        results["plot_categories"][category_name] = category_result
        results["plots_generated"].extend(category_result.get("plot_paths", []))
        results["figure_objects"].extend(category_result.get("figures", []))

    results["total_plots"] = len(results["plots_generated"])
    print(f"✅ Generated {results['total_plots']} plots in {output_dir}")

    return results
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def generate_data_quality_plots(file_path: str, output_dir: str) -> Dict[str, Any]:
    """Generate plots related to data quality using Matplotlib."""
    df = load_dataframe(file_path).to_pandas()
    plot_paths: List[str] = []
    figure_objects = []

    def _record(fig, filename: str, label: str) -> None:
        # Persist the figure and track both the saved path and the live object.
        target = f"{output_dir}/{filename}"
        save_figure(fig, target)
        plot_paths.append(target)
        figure_objects.append(fig)
        print(label)

    # 1. Missing values bar chart (only worth drawing when something is missing)
    if df.isnull().sum().sum() > 0:
        fig = create_missing_values_bar(
            df=df,
            title="Missing Values by Column",
            figsize=(10, 6)
        )
        if fig is not None:
            _record(fig, "missing_values.png", " ✓ Missing values plot")

    # 2. Data types distribution (pie chart alternative - bar chart)
    dtype_counts = df.dtypes.astype(str).value_counts()
    fig = create_bar_chart(
        categories=dtype_counts.index.tolist(),
        values=dtype_counts.values,
        title="Data Types Distribution",
        xlabel="Data Type",
        ylabel="Count",
        figsize=(8, 6),
        color='steelblue'
    )
    if fig is not None:
        _record(fig, "data_types.png", " ✓ Data types plot")

    # 3. Outlier detection (box plots), capped at 6 numeric columns
    numeric_cols = df.select_dtypes(include=[np.number]).columns[:6]
    if len(numeric_cols) > 0:
        fig = create_boxplot(
            data=df[numeric_cols],
            title="Outlier Detection (Box Plots)",
            figsize=(12, 6)
        )
        if fig is not None:
            _record(fig, "outliers_boxplot.png", " ✓ Outlier detection plot")

    return {"plot_paths": plot_paths, "figures": figure_objects, "n_plots": len(plot_paths)}
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def generate_eda_plots(file_path: str, target_col: Optional[str] = None, output_dir: str = "./outputs/plots/eda") -> Dict[str, Any]:
    """Generate exploratory data analysis plots using Matplotlib.

    Args:
        file_path: Path to the dataset to load.
        target_col: Optional target column used for feature-vs-target scatter plots.
        output_dir: Directory where PNG files are written.

    Returns:
        Dict with "plot_paths" (saved PNGs), "figures" (matplotlib Figure
        objects) and "n_plots".
    """
    df = load_dataframe(file_path).to_pandas()
    # Ensure the directory exists for standalone calls that rely on the
    # function's own default output_dir.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    plots = []
    figures = []

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    if target_col and target_col in numeric_cols:
        numeric_cols.remove(target_col)

    # 1. Correlation heatmap
    if len(numeric_cols) > 1:
        fig = create_correlation_heatmap(
            data=df[numeric_cols[:15]],  # Limit to 15 features
            title="Feature Correlation Matrix",
            figsize=(12, 10),
            annot=True,
            cmap='RdBu_r'
        )
        if fig is not None:
            path = f"{output_dir}/correlation_heatmap.png"
            save_figure(fig, path)
            plots.append(path)
            figures.append(fig)
            print(f" ✓ Correlation heatmap")

    # 2. Feature relationships with target (scatter plots)
    if target_col and target_col in df.columns and len(numeric_cols) > 0:
        top_features = numeric_cols[:4]  # Top 4 features

        # Create multiple scatter plots
        fig, axes = plt.subplots(2, 2, figsize=(14, 12))
        axes = axes.flatten()

        for i, col in enumerate(top_features):
            ax = axes[i]
            ax.scatter(df[col], df[target_col], alpha=0.5, s=30,
                       c='steelblue', edgecolors='black', linewidth=0.5)
            ax.set_xlabel(col, fontsize=11)
            ax.set_ylabel(target_col, fontsize=11)
            ax.set_title(f"{col} vs {target_col}", fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3, linestyle='--')

        # FIX: when fewer than 4 numeric features exist, the leftover axes of
        # the fixed 2x2 grid used to render as empty ruled panels; hide them
        # (same treatment as generate_distribution_plots).
        for j in range(len(top_features), len(axes)):
            axes[j].axis('off')

        fig.suptitle(f"Top Features vs {target_col}", fontsize=14, fontweight='bold', y=0.995)
        plt.tight_layout()

        path = f"{output_dir}/feature_relationships.png"
        save_figure(fig, path)
        plots.append(path)
        figures.append(fig)
        print(f" ✓ Feature relationships plot")

    # 3. Pairplot for top features (sample data for performance)
    if len(numeric_cols) >= 3:
        sample_size = min(1000, len(df))
        sample_df = df[numeric_cols[:3]].sample(sample_size)

        # Create pairplot using seaborn
        pair_grid = sns.pairplot(sample_df, corner=True, diag_kind='kde',
                                 plot_kws={'alpha': 0.6, 's': 20})
        fig = pair_grid.fig
        fig.suptitle("Feature Pairplot (Top 3 Features)", fontsize=14,
                     fontweight='bold', y=1.01)

        path = f"{output_dir}/pairplot.png"
        save_figure(fig, path)
        plots.append(path)
        figures.append(fig)
        print(f" ✓ Pairplot")

    return {"plot_paths": plots, "figures": figures, "n_plots": len(plots)}
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def generate_distribution_plots(file_path: str, output_dir: str) -> Dict[str, Any]:
    """Generate distribution analysis plots using Matplotlib.

    Args:
        file_path: Path to the dataset to load.
        output_dir: Directory where the PNG file is written.

    Returns:
        Dict with "plot_paths", "figures" (matplotlib Figure objects) and
        "n_plots".
    """
    df = load_dataframe(file_path).to_pandas()
    plots = []
    figures = []

    numeric_cols = df.select_dtypes(include=[np.number]).columns[:6]

    if len(numeric_cols) > 0:
        # Histograms for numeric features in a grid
        n_cols = min(3, len(numeric_cols))
        n_rows = (len(numeric_cols) + n_cols - 1) // n_cols

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, n_rows * 4))
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
        axes = axes.flatten() if n_rows * n_cols > 1 else [axes]

        for i, col in enumerate(numeric_cols):
            ax = axes[i]
            data = df[col].dropna()

            # Create histogram with KDE
            ax.hist(data, bins=30, color='steelblue', edgecolor='black',
                    alpha=0.7, density=True)

            # Add KDE overlay.
            # FIX: was a bare `except:` which also swallowed KeyboardInterrupt
            # and SystemExit; only plotting errors should be suppressed here
            # (KDE can fail on degenerate data, e.g. a constant column).
            try:
                sns.kdeplot(data, ax=ax, color='darkred', linewidth=2)
            except Exception:
                pass  # Skip KDE if it fails

            ax.set_title(col[:25], fontsize=11, fontweight='bold')
            ax.set_xlabel('Value', fontsize=10)
            ax.set_ylabel('Density', fontsize=10)
            ax.grid(True, alpha=0.3, linestyle='--')

        # Hide unused subplots
        for i in range(len(numeric_cols), len(axes)):
            axes[i].axis('off')

        fig.suptitle("Feature Distributions", fontsize=14, fontweight='bold', y=0.995)
        plt.tight_layout()

        path = f"{output_dir}/distributions_histogram.png"
        save_figure(fig, path)
        plots.append(path)
        figures.append(fig)
        print(f" ✓ Distribution histograms")

    return {"plot_paths": plots, "figures": figures, "n_plots": len(plots)}
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def generate_model_performance_plots(y_true, y_pred, y_pred_proba=None,
                                     task_type="regression",
                                     model_name="Model",
                                     output_dir="./outputs/plots") -> Dict[str, Any]:
    """
    Generate model performance plots using Matplotlib.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        y_pred_proba: Predicted probabilities (for classification)
        task_type: 'classification' or 'regression'
        model_name: Name of the model
        output_dir: Output directory

    Returns:
        Dictionary with plot paths and figure objects
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    plots = []
    figures = []

    if task_type == "classification":
        # 1. Confusion Matrix
        cm = confusion_matrix(y_true, y_pred)
        class_names = [f"Class {i}" for i in range(len(cm))]

        fig = create_confusion_matrix(
            cm=cm,
            class_names=class_names,
            title=f"Confusion Matrix - {model_name}",
            show_percentages=True,
            figsize=(10, 8)
        )
        if fig is not None:
            path = f"{output_dir}/confusion_matrix_{model_name}.png"
            save_figure(fig, path)
            plots.append(path)
            figures.append(fig)
            print(f" ✓ Confusion matrix")

        # 2. ROC Curve (if probabilities provided; binary problems only)
        if y_pred_proba is not None and len(np.unique(y_true)) == 2:
            # FIX: coerce to ndarray so a plain Python list of probabilities
            # does not raise AttributeError on .ndim / fail on 2-D indexing.
            y_pred_proba = np.asarray(y_pred_proba)
            y_proba = y_pred_proba[:, 1] if y_pred_proba.ndim > 1 else y_pred_proba
            fpr, tpr, _ = roc_curve(y_true, y_proba)
            roc_auc = auc(fpr, tpr)

            models_data = {model_name: (fpr, tpr, roc_auc)}
            fig = create_roc_curve(
                models_data=models_data,
                title=f"ROC Curve - {model_name}",
                figsize=(10, 8)
            )
            if fig is not None:
                path = f"{output_dir}/roc_curve_{model_name}.png"
                save_figure(fig, path)
                plots.append(path)
                figures.append(fig)
                print(f" ✓ ROC curve")

    else:  # Regression
        # 1. Residual plot (Predicted vs Actual + Residuals)
        fig = create_residual_plot(
            y_true=y_true,
            y_pred=y_pred,
            title=f"Residual Analysis - {model_name}",
            figsize=(10, 6)
        )
        if fig is not None:
            path = f"{output_dir}/residuals_{model_name}.png"
            save_figure(fig, path)
            plots.append(path)
            figures.append(fig)
            print(f" ✓ Residual plot")

        # 2. Residuals distribution
        # FIX: elementwise subtraction fails with TypeError when callers pass
        # plain Python lists; np.asarray makes both inputs array-like first.
        residuals = np.asarray(y_true) - np.asarray(y_pred)
        fig = create_histogram(
            data=residuals,
            title=f"Residuals Distribution - {model_name}",
            xlabel="Residuals",
            ylabel="Frequency",
            bins=30,
            kde=True,
            figsize=(10, 6)
        )
        if fig is not None:
            path = f"{output_dir}/residuals_dist_{model_name}.png"
            save_figure(fig, path)
            plots.append(path)
            figures.append(fig)
            print(f" ✓ Residuals distribution")

    return {"plot_paths": plots, "figures": figures, "n_plots": len(plots)}
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def generate_feature_importance_plot(feature_importances: Dict[str, float],
                                     output_path: str = "./outputs/plots/feature_importance.png",
                                     top_n: int = 20) -> str:
    """
    Generate feature importance plot using Matplotlib.

    Args:
        feature_importances: Dictionary of feature: importance
        output_path: Where to save the plot
        top_n: Number of top features to show

    Returns:
        Path to saved plot
    """
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Split the mapping into parallel name/score sequences
    names = list(feature_importances.keys())
    scores = np.array(list(feature_importances.values()))

    figure = create_feature_importance(
        feature_names=names,
        importances=scores,
        title=f"Top {top_n} Feature Importances",
        top_n=top_n,
        figsize=(10, max(8, top_n * 0.4))
    )

    # Guard clause: nothing to save if the plot helper produced no figure
    if figure is None:
        return None

    save_figure(figure, output_path)
    print(f" ✓ Feature importance plot")
    close_figure(figure)
    return output_path
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def generate_learning_curve(train_sizes, train_scores, val_scores,
                            model_name="Model",
                            output_path="./outputs/plots/learning_curve.png") -> str:
    """Generate learning curve plot using Matplotlib.

    Args:
        train_sizes: Training-set sizes (x axis values).
        train_scores: Training scores; 1-D (precomputed means) or 2-D
            (per-size fold scores, averaged here).
        val_scores: Validation scores, same shape convention as train_scores.
        model_name: Name used in the plot title.
        output_path: Where to save the PNG.

    Returns:
        The output path on success, or None if no figure was produced.
    """
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # FIX: convert both score containers independently. The original only
    # converted when train_scores was a list, so an ndarray train_scores
    # paired with a list val_scores slipped through unconverted (and a tuple
    # train_scores broke on the .ndim access below).
    train_scores = np.asarray(train_scores)
    val_scores = np.asarray(val_scores)

    # Calculate mean/std across folds for 2-D input; otherwise treat the
    # inputs as precomputed means with zero spread.
    if train_scores.ndim > 1:
        train_scores_mean = np.mean(train_scores, axis=1)
        train_scores_std = np.std(train_scores, axis=1)
        val_scores_mean = np.mean(val_scores, axis=1)
        val_scores_std = np.std(val_scores, axis=1)
    else:
        train_scores_mean = train_scores
        train_scores_std = np.zeros_like(train_scores)
        val_scores_mean = val_scores
        val_scores_std = np.zeros_like(val_scores)

    # Create plot (local import avoids a name clash with this function)
    from .matplotlib_visualizations import create_learning_curve as mlp_learning_curve

    fig = mlp_learning_curve(
        train_sizes=train_sizes,
        train_scores_mean=train_scores_mean,
        train_scores_std=train_scores_std,
        val_scores_mean=val_scores_mean,
        val_scores_std=val_scores_std,
        title=f"Learning Curve - {model_name}",
        figsize=(10, 6)
    )

    if fig is not None:
        save_figure(fig, output_path)
        close_figure(fig)
        return output_path

    return None
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def create_plot_gallery_html(plot_paths: List[str], output_path: str = "./outputs/plots/gallery.html") -> str:
    """
    Create an HTML gallery page showing all plots (now as PNG images).

    Args:
        plot_paths: List of paths to plot files (now PNG instead of HTML)
        output_path: Where to save the gallery

    Returns:
        Path to gallery HTML
    """
    # BUG FIX: the previous version ran str.format() over a template whose
    # CSS contains literal braces ("body { ... }"); str.format treats those
    # as replacement fields and raises KeyError at runtime. Keep the static
    # markup verbatim and append the single dynamic line with an f-string.
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <title>Data Analysis Plot Gallery</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
            background-color: #f5f5f5;
        }
        h1 {
            color: #333;
            text-align: center;
        }
        .plot-container {
            background: white;
            margin: 20px 0;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        .plot-image {
            width: 100%;
            max-width: 1200px;
            height: auto;
            display: block;
            margin: 0 auto;
        }
        .plot-title {
            font-size: 18px;
            font-weight: bold;
            margin-bottom: 10px;
            color: #555;
        }
    </style>
</head>
<body>
    <h1>📊 Data Analysis Visualization Gallery</h1>
"""
    html_content += f'    <p style="text-align: center; color: #666;">Total Plots: {len(plot_paths)}</p>\n'

    for i, plot_path in enumerate(plot_paths, 1):
        plot_name = Path(plot_path).stem.replace('_', ' ').title()
        # Link relative to the gallery file so the output directory can be
        # moved as a whole without breaking the images.
        rel_path = os.path.relpath(plot_path, os.path.dirname(output_path))

        html_content += f"""
    <div class="plot-container">
        <div class="plot-title">{i}. {plot_name}</div>
        <img src="{rel_path}" alt="{plot_name}" class="plot-image">
    </div>
"""

    html_content += """
</body>
</html>
"""

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # BUG FIX: write UTF-8 explicitly; the page contains emoji that the
    # platform-default encoding (e.g. cp1252 on Windows) cannot represent.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"✅ Created plot gallery: {output_path}")
    return output_path
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Utils module initialization."""
|
|
2
|
+
|
|
3
|
+
from .polars_helpers import (
|
|
4
|
+
load_dataframe,
|
|
5
|
+
save_dataframe,
|
|
6
|
+
get_numeric_columns,
|
|
7
|
+
get_categorical_columns,
|
|
8
|
+
get_datetime_columns,
|
|
9
|
+
detect_id_columns,
|
|
10
|
+
get_column_info,
|
|
11
|
+
calculate_memory_usage,
|
|
12
|
+
split_features_target,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from .validation import (
|
|
16
|
+
ValidationError,
|
|
17
|
+
validate_file_exists,
|
|
18
|
+
validate_file_format,
|
|
19
|
+
validate_dataframe,
|
|
20
|
+
validate_column_exists,
|
|
21
|
+
validate_columns_exist,
|
|
22
|
+
validate_target_column,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"load_dataframe",
|
|
27
|
+
"save_dataframe",
|
|
28
|
+
"get_numeric_columns",
|
|
29
|
+
"get_categorical_columns",
|
|
30
|
+
"get_datetime_columns",
|
|
31
|
+
"detect_id_columns",
|
|
32
|
+
"get_column_info",
|
|
33
|
+
"calculate_memory_usage",
|
|
34
|
+
"split_features_target",
|
|
35
|
+
"ValidationError",
|
|
36
|
+
"validate_file_exists",
|
|
37
|
+
"validate_file_format",
|
|
38
|
+
"validate_dataframe",
|
|
39
|
+
"validate_column_exists",
|
|
40
|
+
"validate_columns_exist",
|
|
41
|
+
"validate_target_column",
|
|
42
|
+
]
|