ds-agent-cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/ds-agent.js +451 -0
- package/ds_agent/__init__.py +8 -0
- package/package.json +28 -0
- package/requirements.txt +126 -0
- package/setup.py +35 -0
- package/src/__init__.py +7 -0
- package/src/_compress_tool_result.py +118 -0
- package/src/api/__init__.py +4 -0
- package/src/api/app.py +1626 -0
- package/src/cache/__init__.py +5 -0
- package/src/cache/cache_manager.py +561 -0
- package/src/cli.py +2886 -0
- package/src/dynamic_prompts.py +281 -0
- package/src/orchestrator.py +4799 -0
- package/src/progress_manager.py +139 -0
- package/src/reasoning/__init__.py +332 -0
- package/src/reasoning/business_summary.py +431 -0
- package/src/reasoning/data_understanding.py +356 -0
- package/src/reasoning/model_explanation.py +383 -0
- package/src/reasoning/reasoning_trace.py +239 -0
- package/src/registry/__init__.py +3 -0
- package/src/registry/tools_registry.py +3 -0
- package/src/session_memory.py +448 -0
- package/src/session_store.py +370 -0
- package/src/storage/__init__.py +19 -0
- package/src/storage/artifact_store.py +620 -0
- package/src/storage/helpers.py +116 -0
- package/src/storage/huggingface_storage.py +694 -0
- package/src/storage/r2_storage.py +0 -0
- package/src/storage/user_files_service.py +288 -0
- package/src/tools/__init__.py +335 -0
- package/src/tools/advanced_analysis.py +823 -0
- package/src/tools/advanced_feature_engineering.py +708 -0
- package/src/tools/advanced_insights.py +578 -0
- package/src/tools/advanced_preprocessing.py +549 -0
- package/src/tools/advanced_training.py +906 -0
- package/src/tools/agent_tool_mapping.py +326 -0
- package/src/tools/auto_pipeline.py +420 -0
- package/src/tools/autogluon_training.py +1480 -0
- package/src/tools/business_intelligence.py +860 -0
- package/src/tools/cloud_data_sources.py +581 -0
- package/src/tools/code_interpreter.py +390 -0
- package/src/tools/computer_vision.py +614 -0
- package/src/tools/data_cleaning.py +614 -0
- package/src/tools/data_profiling.py +593 -0
- package/src/tools/data_type_conversion.py +268 -0
- package/src/tools/data_wrangling.py +433 -0
- package/src/tools/eda_reports.py +284 -0
- package/src/tools/enhanced_feature_engineering.py +241 -0
- package/src/tools/feature_engineering.py +302 -0
- package/src/tools/matplotlib_visualizations.py +1327 -0
- package/src/tools/model_training.py +520 -0
- package/src/tools/nlp_text_analytics.py +761 -0
- package/src/tools/plotly_visualizations.py +497 -0
- package/src/tools/production_mlops.py +852 -0
- package/src/tools/time_series.py +507 -0
- package/src/tools/tools_registry.py +2133 -0
- package/src/tools/visualization_engine.py +559 -0
- package/src/utils/__init__.py +42 -0
- package/src/utils/error_recovery.py +313 -0
- package/src/utils/parallel_executor.py +402 -0
- package/src/utils/polars_helpers.py +248 -0
- package/src/utils/schema_extraction.py +132 -0
- package/src/utils/semantic_layer.py +392 -0
- package/src/utils/token_budget.py +411 -0
- package/src/utils/validation.py +377 -0
- package/src/workflow_state.py +154 -0
package/src/cli.py
ADDED
|
@@ -0,0 +1,2886 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command Line Interface for Data Science Copilot
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
from rich.panel import Panel
|
|
9
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
|
|
10
|
+
from rich.live import Live
|
|
11
|
+
from rich import print as rprint
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
import json
|
|
14
|
+
import sys
|
|
15
|
+
import os
|
|
16
|
+
import io
|
|
17
|
+
import contextlib
|
|
18
|
+
import importlib.util
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
# Add src to path
|
|
22
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
|
|
23
|
+
|
|
24
|
+
# Epilog rendered at the bottom of `ds-agent --help`.
_EPILOG = (
    "Quick start: run [bold]ds-agent quickstart[/bold]\n"
    "Examples: [bold]ds-agent analyze data.csv --target price[/bold] | "
    "[bold]ds-agent chat data.csv[/bold]"
)

# Root Typer application: shell completion disabled, rich-formatted help,
# and no implicit help-on-empty (the root callback renders a home screen instead).
app = typer.Typer(
    name="ds-agent",
    help="AI-powered data science CLI for profiling, modeling, forecasting, and chat workflows.",
    add_completion=False,
    no_args_is_help=False,
    rich_markup_mode="rich",
    epilog=_EPILOG,
)

# Sub-application grouping session commands under `ds-agent sessions ...`.
sessions_app = typer.Typer(help="Manage saved chat/analysis sessions")
app.add_typer(sessions_app, name="sessions")

# Shared Rich console used by every command for styled terminal output.
console = Console()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _print_cli_home():
    """Render a first-run friendly home screen for users running the CLI without arguments."""
    banner = r"""[bold cyan]
 ____ ____ _ ____ _____ _ _ _____
| _ \/ ___| / \ / ___| ____| \ | |_ _|
| | | \___ \ / _ \| | _| _| | \| | | |
| |_| |___) | / ___ \ |_| | |___| |\ | | |
|____/|____/ /_/ \_\____|_____|_| \_| |_|
[/bold cyan]"""
    console.print(banner)
    console.print("[dim]The open data science agent CLI[/dim]\n")

    # Command reference shown on the home screen: (invocation, description).
    entries = [
        ("ds-agent quickstart", "Show examples for all major workflows"),
        ("ds-agent analyze <file>", "Run end-to-end automated analysis"),
        ("ds-agent chat [file]", "Start interactive session with memory"),
        ("ds-agent report <file>", "Generate full HTML profiling report"),
        ("ds-agent train <file> --target <col>", "Train baseline ML models"),
        ("ds-agent tune <file> --target <col>", "Hyperparameter optimization"),
        ("ds-agent plot <file>", "Generate charts and visualizations"),
        ("ds-agent forecast <file> --time <col> --target <col>", "Time-series forecasting"),
        ("ds-agent sessions list", "List and manage saved sessions"),
        ("ds-agent --help", "Full command reference"),
    ]

    table = Table(show_header=False, box=None, pad_edge=False)
    table.add_column("Command", style="bold white", width=52, no_wrap=True)
    table.add_column("What it does", style="white")
    for invocation, description in entries:
        table.add_row(invocation, description)

    console.print(table)
    console.print("\n[dim]Tip: run [bold]ds-agent quickstart[/bold] for copy-paste commands.[/dim]")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@app.callback(invoke_without_command=True)
def root_callback(
    ctx: typer.Context,
    version: bool = typer.Option(False, "--version", "-v", help="Show CLI version and exit.")
):
    """Root CLI callback for startup UX and global flags."""
    if version:
        # Resolve the installed package version; fall back to a pinned
        # default when the package metadata cannot be imported.
        try:
            from ds_agent import __version__
        except Exception:
            __version__ = "0.1.0"
        console.print(f"ds-agent v{__version__}")
        raise typer.Exit(0)

    # Bare `ds-agent` with no subcommand: show the home screen instead of an error.
    if ctx.invoked_subcommand is None:
        _print_cli_home()
        raise typer.Exit(0)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@app.command()
def quickstart():
    """Show beginner-friendly command guide with copy-paste examples."""
    # (workflow label, copy-paste command) pairs rendered in the guide table.
    examples = [
        ("Analyze", 'ds-agent analyze data.csv --target price --task "predict price"'),
        ("Chat", "ds-agent chat data.csv"),
        ("Profile Report", "ds-agent report data.csv --engine ydata --output ./outputs/reports"),
        ("Model Training", "ds-agent train data.csv --target label --task-type classification"),
        ("Hyperparameter Tuning", "ds-agent tune data.csv --target label --model xgboost --trials 50"),
        ("Visualizations", "ds-agent plot data.csv --type heatmap"),
        ("Compare Datasets", "ds-agent compare train.csv test.csv --label1 train --label2 test"),
        ("Forecast", "ds-agent forecast sales.csv --time date --target revenue --horizon 30"),
        ("NLP", "ds-agent nlp reviews.csv --text review --task sentiment"),
        ("Session Management", "ds-agent sessions list | ds-agent sessions resume <session_id>"),
    ]

    guide = Table(title="⚡ Quickstart Guide")
    guide.add_column("Workflow", style="cyan", width=24, no_wrap=True)
    guide.add_column("Command", style="white", no_wrap=True)
    for workflow, command in examples:
        guide.add_row(workflow, command)

    console.print()
    console.print(guide)
    console.print("\n[dim]Use [bold]ds-agent <command> --help[/bold] for detailed options and examples.[/dim]")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _load_module_from_src(module_name: str, relative_path: str):
|
|
117
|
+
"""Load a module directly from src to avoid triggering package-wide __init__ side effects."""
|
|
118
|
+
module_path = Path(__file__).resolve().parent / relative_path
|
|
119
|
+
spec = importlib.util.spec_from_file_location(module_name, str(module_path))
|
|
120
|
+
if spec is None or spec.loader is None:
|
|
121
|
+
raise ImportError(f"Unable to load module at {module_path}")
|
|
122
|
+
module = importlib.util.module_from_spec(spec)
|
|
123
|
+
spec.loader.exec_module(module)
|
|
124
|
+
return module
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@app.command()
def analyze(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    task: str = typer.Option(
        "Complete data science workflow: profile, clean, engineer features, and train models",
        "--task", "-t",
        help="Description of the analysis task"
    ),
    target: str = typer.Option(None, "--target", "-y", help="Target column name for prediction"),
    output: str = typer.Option("./outputs", "--output", "-o", help="Output directory"),
    no_cache: bool = typer.Option(False, "--no-cache", help="Disable caching"),
    reasoning: str = typer.Option("medium", "--reasoning", "-r", help="Reasoning effort (low/medium/high)")
):
    """
    Analyze a dataset and perform complete data science workflow.

    Validates the input file, initializes the copilot orchestrator, runs the
    analysis workflow, renders the results, and persists a JSON report under
    ``<output>/reports/``.

    Example:
        python cli.py analyze data.csv --target Survived --task "Predict survival"
    """
    console.print(Panel.fit(
        "🤖 Data Science Copilot - AI-Powered Analysis",
        style="bold blue"
    ))

    # Validate file exists before doing any heavy initialization.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    # Initialize copilot (import is lazy so `--help` stays fast).
    try:
        from ds_agent.orchestrator import DataScienceCopilot

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_init = progress.add_task("Initializing Data Science Copilot...", total=None)
            copilot = DataScienceCopilot(reasoning_effort=reasoning)
            progress.update(task_init, completed=True)

    except Exception as e:
        console.print(f"[red]✗ Error initializing copilot: {e}[/red]")
        console.print("[yellow]Make sure GROQ_API_KEY is set in .env file[/yellow]")
        raise typer.Exit(1)

    # Echo run configuration.
    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    console.print(f"🎯 [bold]Task:[/bold] {task}")
    if target:
        console.print(f"🎲 [bold]Target:[/bold] {target}")
    console.print()

    # Run the analysis workflow.
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_analyze = progress.add_task("Running analysis workflow...", total=None)

            result = copilot.analyze(
                file_path=file_path,
                task_description=task,
                target_col=target,
                use_cache=not no_cache
            )

            progress.update(task_analyze, completed=True)

    except Exception as e:
        console.print(f"\n[red]✗ Analysis failed: {e}[/red]")
        raise typer.Exit(1)

    # Display results by status.
    if result["status"] == "success":
        console.print("\n[green]✓ Analysis Complete![/green]\n")

        # Summary panel.
        console.print(Panel(
            result["summary"],
            title="📋 Analysis Summary",
            border_style="green"
        ))

        # Per-tool success/failure listing from the workflow history.
        console.print("\n[bold]🔧 Tools Executed:[/bold]")
        for step in result["workflow_history"]:
            tool_name = step["tool"]
            success = step["result"].get("success", False)
            icon = "✓" if success else "✗"
            color = "green" if success else "red"
            console.print(f" [{color}]{icon}[/{color}] {tool_name}")

        # Execution statistics.
        stats_table = Table(title="📊 Execution Statistics", show_header=False)
        stats_table.add_column("Metric", style="cyan")
        stats_table.add_column("Value", style="white")

        stats_table.add_row("Iterations", str(result["iterations"]))
        stats_table.add_row("API Calls", str(result["api_calls"]))
        stats_table.add_row("Execution Time", f"{result['execution_time']}s")

        console.print()
        console.print(stats_table)

        # Save full report. UTF-8 is forced so emoji/non-ASCII text in the
        # result never trips the platform default encoding (e.g. cp1252 on
        # Windows); `default=str` keeps saving best-effort when the result
        # contains values that are not JSON-native.
        report_path = Path(output) / "reports" / f"analysis_{Path(file_path).stem}.json"
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, default=str)

        console.print(f"\n💾 Full report saved to: [cyan]{report_path}[/cyan]")

    elif result["status"] == "error":
        console.print(f"\n[red]✗ Error: {result['error']}[/red]")
        raise typer.Exit(1)

    else:
        console.print(f"\n[yellow]⚠ Analysis incomplete: {result.get('message')}[/yellow]")
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@app.command()
def profile(
    file_path: str = typer.Argument(..., help="Path to dataset file")
):
    """
    Quick profile of a dataset (basic statistics and quality checks).

    Prints dataset shape, memory usage, null/duplicate stats, column-type
    counts, and a severity breakdown of detected data quality issues.

    Example:
        python cli.py profile data.csv
    """
    # Imported lazily so `--help` stays fast and heavy deps load on demand.
    from ds_agent.tools.data_profiling import profile_dataset, detect_data_quality_issues

    console.print(f"\n📊 [bold]Profiling:[/bold] {file_path}\n")

    # Profile. NOTE: the local is named `profile_data` (not `profile`) so it
    # does not shadow this command function itself.
    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress:
        task1 = progress.add_task("Analyzing dataset...", total=None)
        profile_data = profile_dataset(file_path)
        progress.update(task1, completed=True)

    # Display basic info.
    info_table = Table(title="Dataset Information", show_header=False)
    info_table.add_column("Property", style="cyan")
    info_table.add_column("Value", style="white")

    info_table.add_row("Rows", str(profile_data["shape"]["rows"]))
    info_table.add_row("Columns", str(profile_data["shape"]["columns"]))
    info_table.add_row("Memory", f"{profile_data['memory_usage']['total_mb']} MB")
    info_table.add_row("Null %", f"{profile_data['overall_stats']['null_percentage']}%")
    info_table.add_row("Duplicates", str(profile_data['overall_stats']['duplicate_rows']))

    console.print()
    console.print(info_table)

    # Column type counts.
    console.print("\n[bold]Column Types:[/bold]")
    console.print(f" Numeric: {len(profile_data['column_types']['numeric'])}")
    console.print(f" Categorical: {len(profile_data['column_types']['categorical'])}")
    console.print(f" Datetime: {len(profile_data['column_types']['datetime'])}")

    # Detect quality issues and show the severity breakdown.
    console.print("\n[bold]Quality Check:[/bold]")
    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress:
        task2 = progress.add_task("Detecting quality issues...", total=None)
        issues = detect_data_quality_issues(file_path)
        progress.update(task2, completed=True)

    console.print(f" 🔴 Critical: {issues['summary']['critical_count']}")
    console.print(f" 🟡 Warnings: {issues['summary']['warning_count']}")
    console.print(f" 🔵 Info: {issues['summary']['info_count']}")
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
@app.command()
def eda(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    target: str = typer.Option(None, "--target", "-y", help="Target column name for supervised EDA"),
    output: str = typer.Option("./outputs/eda/", "--output", "-o", help="Output directory for EDA artifacts"),
    full: bool = typer.Option(False, "--full", help="Run deeper EDA analysis")
):
    """
    Run exploratory data analysis with quality checks and optional root-cause insights.

    Writes an HTML report to ``<output>/eda.html`` and prints a summary table;
    when ``--target`` is given, root-cause correlations with the target are
    analyzed as well (``--full`` tightens the threshold).

    Example:
        ds-agent eda sales.csv
        ds-agent eda sales.csv --target revenue --full
    """
    from ds_agent.tools.advanced_analysis import perform_eda_analysis
    from ds_agent.tools.data_profiling import detect_data_quality_issues
    from ds_agent.tools.advanced_insights import analyze_root_cause

    console.print(Panel.fit(
        "📈 Exploratory Data Analysis",
        style="bold blue"
    ))

    # Validate file exists.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Build the report path with pathlib (portable) instead of string rstrip/concat.
    output_html = str(output_dir / "eda.html")

    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    if target:
        console.print(f"🎯 [bold]Target:[/bold] {target}")
    console.print(f"🧪 [bold]Depth:[/bold] {'Full' if full else 'Standard'}")
    console.print()

    try:
        # Step 1: Perform EDA analysis.
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_eda = progress.add_task("Running comprehensive EDA...", total=None)
            eda_result = perform_eda_analysis(
                file_path,
                target_col=target,
                output_html=output_html
            )
            progress.update(task_eda, completed=True)

        # Step 2: Detect quality issues.
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_quality = progress.add_task("Detecting data quality issues...", total=None)
            quality_result = detect_data_quality_issues(file_path)
            progress.update(task_quality, completed=True)

        # Step 3: Root cause analysis (only when target is provided).
        root_cause_result = None
        if target:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console
            ) as progress:
                task_root = progress.add_task("Analyzing root-cause factors...", total=None)
                # `--full` tightens the drop threshold; otherwise use the
                # helper's default (single call instead of duplicated branches).
                extra_kwargs = {"threshold_drop": 0.10} if full else {}
                root_cause_result = analyze_root_cause(file_path, target_col=target, **extra_kwargs)
                progress.update(task_root, completed=True)

    except Exception as e:
        console.print(f"\n[red]✗ EDA failed: {e}[/red]")
        raise typer.Exit(1)

    # Prepare summary fields.
    dataset_shape = eda_result.get("dataset_shape", {})
    column_types = eda_result.get("column_types", {})

    # Issues ranked by severity: critical first, then warnings, then info.
    ranked_issues = (
        quality_result.get("critical", [])
        + quality_result.get("warning", [])
        + quality_result.get("info", [])
    )
    top_issues = ranked_issues[:3]

    top_issue_lines = []
    for issue in top_issues:
        msg = issue.get("message", "Issue detected")
        top_issue_lines.append(f"- {msg}")
    if not top_issue_lines:
        top_issue_lines.append("- No major quality issues detected")

    # Strongest absolute correlations with the target, when available.
    top_corr_lines = []
    if target and isinstance(root_cause_result, dict):
        corr_map = root_cause_result.get("correlations", {})
        if isinstance(corr_map, dict) and corr_map:
            top_corrs = sorted(corr_map.items(), key=lambda x: abs(x[1]), reverse=True)[:3]
            for feature, corr in top_corrs:
                top_corr_lines.append(f"- {feature}: {corr:.3f}")
    if target and not top_corr_lines:
        top_corr_lines.append("- No strong target correlations detected")

    # Step 4: Rich summary table.
    summary_table = Table(title="📊 EDA Summary", show_header=False)
    summary_table.add_column("Metric", style="cyan", width=28)
    summary_table.add_column("Value", style="white")

    summary_table.add_row("Rows", str(dataset_shape.get("rows", "-")))
    summary_table.add_row("Columns", str(dataset_shape.get("columns", "-")))
    summary_table.add_row("Numeric Columns", str(column_types.get("numeric", "-")))
    summary_table.add_row("Categorical Columns", str(column_types.get("categorical", "-")))
    summary_table.add_row("Top 3 Quality Issues", "\n".join(top_issue_lines))
    if target:
        summary_table.add_row("Top 3 Correlations with Target", "\n".join(top_corr_lines))

    console.print()
    console.print(summary_table)

    # Step 5: Report path.
    console.print(f"\n💾 HTML report saved to: [cyan]{output_html}[/cyan]")

    quality_summary = quality_result.get("summary", {})
    final_lines = [
        f"Rows: {dataset_shape.get('rows', '-')}",
        f"Columns: {dataset_shape.get('columns', '-')}",
        f"Issues found: {quality_summary.get('total_issues', 0)}",
        f"Mode: {'Full' if full else 'Standard'}"
    ]
    if target:
        final_lines.append(f"Target analyzed: {target}")

    console.print(Panel(
        "\n".join(final_lines),
        title="✅ EDA Complete",
        border_style="green"
    ))
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@app.command()
def report(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    engine: str = typer.Option("ydata", "--engine", help="Report engine: ydata or sweetviz"),
    target: str = typer.Option(None, "--target", "-y", help="Optional target column for supervised analysis"),
    compare: str = typer.Option(None, "--compare", help="Optional second dataset path for Sweetviz comparison"),
    output: str = typer.Option("./outputs/reports/", "--output", "-o", help="Output directory for HTML reports"),
    minimal: bool = typer.Option(False, "--minimal", help="Use minimal mode for large files")
):
    """
    Generate a full HTML profiling report using ydata-profiling or Sweetviz.

    Example:
        ds-agent report sales.csv
        ds-agent report sales.csv --engine sweetviz --target revenue
        ds-agent report sales.csv --compare test.csv --engine sweetviz
    """
    # Load report helpers straight from src so the package-wide tools
    # __init__ side effects are not triggered.
    eda_reports_module = _load_module_from_src("ds_agent_cli_eda_reports", "tools/eda_reports.py")
    generate_ydata_profiling_report = eda_reports_module.generate_ydata_profiling_report
    generate_sweetviz_report = eda_reports_module.generate_sweetviz_report
    import webbrowser

    console.print(Panel.fit(
        "📄 HTML Profiling Report",
        style="bold blue"
    ))

    # ---- Fail fast on bad inputs before any heavy work ----
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    chosen_engine = engine.strip().lower()
    if chosen_engine not in {"ydata", "sweetviz"}:
        console.print("[red]✗ Error: --engine must be either 'ydata' or 'sweetviz'[/red]")
        raise typer.Exit(1)

    if compare and not Path(compare).exists():
        console.print(f"[red]✗ Error: Compare file not found: {compare}[/red]")
        raise typer.Exit(1)

    # ---- Resolve output location ----
    reports_dir = Path(output)
    reports_dir.mkdir(parents=True, exist_ok=True)

    html_name = "ydata_profile.html" if chosen_engine == "ydata" else "sweetviz_report.html"
    destination = str(reports_dir / html_name)
    title_text = f"{Path(file_path).stem} - Profiling Report"

    # ---- Echo run configuration ----
    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    console.print(f"⚙️ [bold]Engine:[/bold] {chosen_engine}")
    if target:
        console.print(f"🎯 [bold]Target:[/bold] {target}")
    if compare:
        console.print(f"🔁 [bold]Compare:[/bold] {compare}")
    console.print("⏳ [yellow]Generating report may take 1-2 minutes...[/yellow]\n")

    # ---- Generate the report with the selected engine ----
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            job = progress.add_task("Generating HTML report...", total=None)

            if chosen_engine == "ydata":
                result = generate_ydata_profiling_report(
                    file_path,
                    destination,
                    minimal,
                    title_text
                )
            else:
                result = generate_sweetviz_report(
                    file_path,
                    target,
                    compare,
                    destination
                )

            progress.update(job, completed=True)

    except Exception as e:
        console.print(f"[red]✗ Report generation failed: {e}[/red]")
        raise typer.Exit(1)

    if not result.get("success", False):
        console.print(f"[red]✗ Report generation failed: {result.get('error', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    final_path = result.get("report_path", destination)

    console.print(Panel(
        f"✅ Report generated successfully!\n\nPath: {final_path}",
        title="Report Ready",
        border_style="green"
    ))

    # Best effort: pop the finished report open in the default browser.
    try:
        webbrowser.open(str(Path(final_path).resolve().as_uri()))
    except Exception:
        pass
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@app.command()
|
|
555
|
+
def compare(
|
|
556
|
+
file1: str = typer.Argument(..., help="First dataset path (CSV or Parquet)"),
|
|
557
|
+
file2: str = typer.Argument(..., help="Second dataset path (CSV or Parquet)"),
|
|
558
|
+
label1: str = typer.Option("Dataset A", "--label1", help="Display label for first dataset"),
|
|
559
|
+
label2: str = typer.Option("Dataset B", "--label2", help="Display label for second dataset"),
|
|
560
|
+
output: str = typer.Option("./outputs/compare/", "--output", "-o", help="Output directory for comparison reports")
|
|
561
|
+
):
|
|
562
|
+
"""
|
|
563
|
+
Compare two datasets side-by-side for schema and distribution differences.
|
|
564
|
+
|
|
565
|
+
Example:
|
|
566
|
+
ds-agent compare train.csv test.csv
|
|
567
|
+
ds-agent compare raw.csv cleaned.csv --label1 "Before" --label2 "After"
|
|
568
|
+
"""
|
|
569
|
+
import math
|
|
570
|
+
from datetime import datetime
|
|
571
|
+
import polars as pl
|
|
572
|
+
from ds_agent.utils.polars_helpers import load_dataframe
|
|
573
|
+
eda_reports_module = _load_module_from_src("ds_agent_cli_eda_reports_compare", "tools/eda_reports.py")
|
|
574
|
+
generate_sweetviz_report = eda_reports_module.generate_sweetviz_report
|
|
575
|
+
|
|
576
|
+
if not Path(file1).exists():
|
|
577
|
+
console.print(f"[red]✗ Error: File not found: {file1}[/red]")
|
|
578
|
+
raise typer.Exit(1)
|
|
579
|
+
if not Path(file2).exists():
|
|
580
|
+
console.print(f"[red]✗ Error: File not found: {file2}[/red]")
|
|
581
|
+
raise typer.Exit(1)
|
|
582
|
+
|
|
583
|
+
output_dir = Path(output)
|
|
584
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
585
|
+
|
|
586
|
+
console.print(Panel.fit(
|
|
587
|
+
f"🔍 Dataset Compare\n{label1} vs {label2}",
|
|
588
|
+
style="bold blue"
|
|
589
|
+
))
|
|
590
|
+
|
|
591
|
+
def _fmt_num(value, digits: int = 4) -> str:
|
|
592
|
+
if value is None:
|
|
593
|
+
return "-"
|
|
594
|
+
try:
|
|
595
|
+
value = float(value)
|
|
596
|
+
except Exception:
|
|
597
|
+
return "-"
|
|
598
|
+
if math.isnan(value) or math.isinf(value):
|
|
599
|
+
return "-"
|
|
600
|
+
return f"{value:.{digits}f}"
|
|
601
|
+
|
|
602
|
+
def _col_stats(df: pl.DataFrame, col_name: str) -> dict:
|
|
603
|
+
series = df[col_name]
|
|
604
|
+
row_count = len(df)
|
|
605
|
+
null_count = series.null_count()
|
|
606
|
+
null_pct = (null_count / row_count * 100.0) if row_count > 0 else 0.0
|
|
607
|
+
|
|
608
|
+
is_numeric = series.dtype in pl.NUMERIC_DTYPES
|
|
609
|
+
mean_val = series.mean() if is_numeric else None
|
|
610
|
+
median_val = series.median() if is_numeric else None
|
|
611
|
+
std_val = series.std() if is_numeric else None
|
|
612
|
+
|
|
613
|
+
return {
|
|
614
|
+
"dtype": str(series.dtype),
|
|
615
|
+
"is_numeric": is_numeric,
|
|
616
|
+
"mean": float(mean_val) if mean_val is not None else None,
|
|
617
|
+
"median": float(median_val) if median_val is not None else None,
|
|
618
|
+
"std": float(std_val) if std_val is not None else None,
|
|
619
|
+
"null_pct": round(null_pct, 2),
|
|
620
|
+
"unique": int(series.n_unique())
|
|
621
|
+
}
|
|
622
|
+
|
|
623
|
+
with Progress(
|
|
624
|
+
SpinnerColumn(),
|
|
625
|
+
TextColumn("[progress.description]{task.description}"),
|
|
626
|
+
console=console
|
|
627
|
+
) as progress:
|
|
628
|
+
task = progress.add_task("Loading datasets with Polars...", total=None)
|
|
629
|
+
try:
|
|
630
|
+
df_a = load_dataframe(file1)
|
|
631
|
+
df_b = load_dataframe(file2)
|
|
632
|
+
except Exception as e:
|
|
633
|
+
console.print(f"[red]✗ Failed to load datasets: {e}[/red]")
|
|
634
|
+
raise typer.Exit(1)
|
|
635
|
+
progress.update(task, completed=True)
|
|
636
|
+
|
|
637
|
+
cols_a = set(df_a.columns)
|
|
638
|
+
cols_b = set(df_b.columns)
|
|
639
|
+
|
|
640
|
+
shared_cols = sorted(cols_a.intersection(cols_b))
|
|
641
|
+
only_a = sorted(cols_a - cols_b)
|
|
642
|
+
only_b = sorted(cols_b - cols_a)
|
|
643
|
+
|
|
644
|
+
comparison_rows = []
|
|
645
|
+
numeric_shift_rows = []
|
|
646
|
+
|
|
647
|
+
with Progress(
|
|
648
|
+
SpinnerColumn(),
|
|
649
|
+
TextColumn("[progress.description]{task.description}"),
|
|
650
|
+
console=console
|
|
651
|
+
) as progress:
|
|
652
|
+
task = progress.add_task("Computing side-by-side comparison...", total=None)
|
|
653
|
+
for col in shared_cols:
|
|
654
|
+
stats_a = _col_stats(df_a, col)
|
|
655
|
+
stats_b = _col_stats(df_b, col)
|
|
656
|
+
|
|
657
|
+
mean_diff = None
|
|
658
|
+
std_diff = None
|
|
659
|
+
if stats_a["is_numeric"] and stats_b["is_numeric"]:
|
|
660
|
+
if stats_a["mean"] is not None and stats_b["mean"] is not None:
|
|
661
|
+
mean_diff = stats_a["mean"] - stats_b["mean"]
|
|
662
|
+
if stats_a["std"] is not None and stats_b["std"] is not None:
|
|
663
|
+
std_diff = stats_a["std"] - stats_b["std"]
|
|
664
|
+
|
|
665
|
+
comparison_rows.append({
|
|
666
|
+
"column": col,
|
|
667
|
+
"a": stats_a,
|
|
668
|
+
"b": stats_b,
|
|
669
|
+
"mean_diff": mean_diff,
|
|
670
|
+
"std_diff": std_diff
|
|
671
|
+
})
|
|
672
|
+
|
|
673
|
+
if mean_diff is not None or std_diff is not None:
|
|
674
|
+
numeric_shift_rows.append({
|
|
675
|
+
"column": col,
|
|
676
|
+
"mean_diff": mean_diff,
|
|
677
|
+
"std_diff": std_diff
|
|
678
|
+
})
|
|
679
|
+
progress.update(task, completed=True)
|
|
680
|
+
|
|
681
|
+
stats_panel = Panel(
|
|
682
|
+
"\n".join([
|
|
683
|
+
f"Shared columns: {len(shared_cols)}",
|
|
684
|
+
f"Columns only in {label1}: {len(only_a)}",
|
|
685
|
+
f"Columns only in {label2}: {len(only_b)}",
|
|
686
|
+
f"Rows in {label1}: {len(df_a)}",
|
|
687
|
+
f"Rows in {label2}: {len(df_b)}"
|
|
688
|
+
]),
|
|
689
|
+
title="Overall Stats",
|
|
690
|
+
border_style="cyan"
|
|
691
|
+
)
|
|
692
|
+
console.print()
|
|
693
|
+
console.print(stats_panel)
|
|
694
|
+
|
|
695
|
+
comparison_table = Table(title=f"📋 Shared Column Comparison ({label1} vs {label2})")
|
|
696
|
+
comparison_table.add_column("Column", style="bold")
|
|
697
|
+
comparison_table.add_column(f"Type ({label1})", style="cyan")
|
|
698
|
+
comparison_table.add_column(f"Type ({label2})", style="cyan")
|
|
699
|
+
comparison_table.add_column(f"Mean ({label1})", justify="right")
|
|
700
|
+
comparison_table.add_column(f"Mean ({label2})", justify="right")
|
|
701
|
+
comparison_table.add_column(f"Median ({label1})", justify="right")
|
|
702
|
+
comparison_table.add_column(f"Median ({label2})", justify="right")
|
|
703
|
+
comparison_table.add_column(f"Null % ({label1})", justify="right")
|
|
704
|
+
comparison_table.add_column(f"Null % ({label2})", justify="right")
|
|
705
|
+
comparison_table.add_column(f"Unique ({label1})", justify="right")
|
|
706
|
+
comparison_table.add_column(f"Unique ({label2})", justify="right")
|
|
707
|
+
comparison_table.add_column("Mean Diff", justify="right")
|
|
708
|
+
comparison_table.add_column("Std Diff", justify="right")
|
|
709
|
+
|
|
710
|
+
for row in comparison_rows:
|
|
711
|
+
a_stats = row["a"]
|
|
712
|
+
b_stats = row["b"]
|
|
713
|
+
comparison_table.add_row(
|
|
714
|
+
row["column"],
|
|
715
|
+
a_stats["dtype"],
|
|
716
|
+
b_stats["dtype"],
|
|
717
|
+
_fmt_num(a_stats["mean"]),
|
|
718
|
+
_fmt_num(b_stats["mean"]),
|
|
719
|
+
_fmt_num(a_stats["median"]),
|
|
720
|
+
_fmt_num(b_stats["median"]),
|
|
721
|
+
_fmt_num(a_stats["null_pct"], 2),
|
|
722
|
+
_fmt_num(b_stats["null_pct"], 2),
|
|
723
|
+
str(a_stats["unique"]),
|
|
724
|
+
str(b_stats["unique"]),
|
|
725
|
+
_fmt_num(row["mean_diff"]),
|
|
726
|
+
_fmt_num(row["std_diff"])
|
|
727
|
+
)
|
|
728
|
+
|
|
729
|
+
console.print()
|
|
730
|
+
console.print(comparison_table)
|
|
731
|
+
|
|
732
|
+
if only_a or only_b:
|
|
733
|
+
missing_table = Table(title="🚨 Columns Not Shared", border_style="red")
|
|
734
|
+
missing_table.add_column("Column", style="red")
|
|
735
|
+
missing_table.add_column("Present In", style="red")
|
|
736
|
+
missing_table.add_column("Missing In", style="red")
|
|
737
|
+
|
|
738
|
+
for col in only_a:
|
|
739
|
+
missing_table.add_row(col, label1, label2, style="red")
|
|
740
|
+
for col in only_b:
|
|
741
|
+
missing_table.add_row(col, label2, label1, style="red")
|
|
742
|
+
|
|
743
|
+
console.print()
|
|
744
|
+
console.print(missing_table)
|
|
745
|
+
|
|
746
|
+
if numeric_shift_rows:
|
|
747
|
+
shift_table = Table(title="📈 Numeric Distribution Shift")
|
|
748
|
+
shift_table.add_column("Column", style="bold")
|
|
749
|
+
shift_table.add_column("Mean Diff", justify="right", style="yellow")
|
|
750
|
+
shift_table.add_column("Std Diff", justify="right", style="yellow")
|
|
751
|
+
|
|
752
|
+
for item in numeric_shift_rows:
|
|
753
|
+
shift_table.add_row(
|
|
754
|
+
item["column"],
|
|
755
|
+
_fmt_num(item["mean_diff"]),
|
|
756
|
+
_fmt_num(item["std_diff"])
|
|
757
|
+
)
|
|
758
|
+
|
|
759
|
+
console.print()
|
|
760
|
+
console.print(shift_table)
|
|
761
|
+
|
|
762
|
+
report_data = {
|
|
763
|
+
"generated_at": datetime.utcnow().isoformat() + "Z",
|
|
764
|
+
"dataset_a": {
|
|
765
|
+
"label": label1,
|
|
766
|
+
"path": file1,
|
|
767
|
+
"rows": len(df_a),
|
|
768
|
+
"columns": len(df_a.columns)
|
|
769
|
+
},
|
|
770
|
+
"dataset_b": {
|
|
771
|
+
"label": label2,
|
|
772
|
+
"path": file2,
|
|
773
|
+
"rows": len(df_b),
|
|
774
|
+
"columns": len(df_b.columns)
|
|
775
|
+
},
|
|
776
|
+
"summary": {
|
|
777
|
+
"shared_columns_count": len(shared_cols),
|
|
778
|
+
"columns_only_in_a_count": len(only_a),
|
|
779
|
+
"columns_only_in_b_count": len(only_b),
|
|
780
|
+
"rows_a": len(df_a),
|
|
781
|
+
"rows_b": len(df_b)
|
|
782
|
+
},
|
|
783
|
+
"columns_only_in_a": only_a,
|
|
784
|
+
"columns_only_in_b": only_b,
|
|
785
|
+
"shared_columns": comparison_rows,
|
|
786
|
+
"numeric_distribution_shift": numeric_shift_rows
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
report_path = output_dir / f"compare_{Path(file1).stem}_vs_{Path(file2).stem}.json"
|
|
790
|
+
with open(report_path, "w") as f:
|
|
791
|
+
json.dump(report_data, f, indent=2)
|
|
792
|
+
|
|
793
|
+
console.print(f"\n💾 Comparison JSON report saved to: [cyan]{report_path}[/cyan]")
|
|
794
|
+
|
|
795
|
+
if output:
|
|
796
|
+
sweetviz_path = output_dir / "sweetviz_compare_report.html"
|
|
797
|
+
with Progress(
|
|
798
|
+
SpinnerColumn(),
|
|
799
|
+
TextColumn("[progress.description]{task.description}"),
|
|
800
|
+
console=console
|
|
801
|
+
) as progress:
|
|
802
|
+
task = progress.add_task("Generating Sweetviz compare report...", total=None)
|
|
803
|
+
sweetviz_result = generate_sweetviz_report(
|
|
804
|
+
file_path=file1,
|
|
805
|
+
target_col=None,
|
|
806
|
+
compare_file_path=file2,
|
|
807
|
+
output_path=str(sweetviz_path),
|
|
808
|
+
title=f"{label1} vs {label2} - Sweetviz Comparison"
|
|
809
|
+
)
|
|
810
|
+
progress.update(task, completed=True)
|
|
811
|
+
|
|
812
|
+
if sweetviz_result.get("success", False):
|
|
813
|
+
console.print(f"🌐 Sweetviz compare report saved to: [cyan]{sweetviz_result.get('report_path', sweetviz_path)}[/cyan]")
|
|
814
|
+
else:
|
|
815
|
+
console.print(f"[yellow]⚠ Sweetviz compare report generation skipped: {sweetviz_result.get('error', 'Unknown error')}[/yellow]")
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
@app.command()
def bi(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    customer_col: str = typer.Option(None, "--customer", "-c", help="Customer ID column name"),
    date_col: str = typer.Option(None, "--date", "-d", help="Date column name"),
    value_col: str = typer.Option(None, "--value", "-v", help="Revenue/value column name"),
    target_col: str = typer.Option(None, "--target", "-y", help="Optional target column for churn analysis"),
    analysis: str = typer.Option("all", "--analysis", help="Analysis type: all, cohort, rfm, kpi, churn"),
    output: str = typer.Option("./outputs/bi/", "--output", "-o", help="Output directory for BI results")
):
    """
    Run business intelligence analytics on a dataset.

    Depending on --analysis, computes KPIs, monthly cohort retention, RFM
    customer segmentation, and/or a churn summary; renders each result as a
    rich table, writes per-analysis and combined JSON files to the output
    directory, and finishes with a "Top 3 Business Insights" panel.

    Raises typer.Exit(1) when the file is missing, --analysis is invalid,
    required column options are absent, the named columns are not in the
    dataset, or the dataset fails to load.

    Example:
        ds-agent bi transactions.csv --customer customer_id --date order_date --value amount
        ds-agent bi customers.csv --customer id --date signup_date --analysis cohort
    """
    from datetime import datetime, timezone
    import math
    import pandas as pd
    from ds_agent.utils.polars_helpers import load_dataframe
    import ds_agent.tools.business_intelligence as bi_tools

    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    # Normalize the analysis selector so "--analysis RFM " etc. still match.
    analysis = analysis.strip().lower()
    valid_analyses = {"all", "cohort", "rfm", "kpi", "churn"}
    if analysis not in valid_analyses:
        console.print("[red]✗ Error: --analysis must be one of all, cohort, rfm, kpi, churn[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task = progress.add_task("Loading dataset...", total=None)
        try:
            df = load_dataframe(file_path)
        except Exception as e:
            console.print(f"[red]✗ Failed to load dataset: {e}[/red]")
            raise typer.Exit(1)
        progress.update(task, completed=True)

    def _format_metric(value) -> str:
        """Render a KPI/metric value for table display ('-' for missing/non-finite)."""
        if value is None:
            return "-"
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(value, bool):
            return "Yes" if value else "No"
        if isinstance(value, int):
            return f"{value:,}"
        if isinstance(value, float):
            if math.isnan(value) or math.isinf(value):
                return "-"
            if abs(value) >= 1000:
                return f"{value:,.2f}"
            return f"{value:.4f}"
        return str(value)

    def _trend_arrow(trend_value) -> str:
        """Map a trend (signed number or descriptive string) to ↑ / ↓ / →."""
        if trend_value is None:
            return "→"
        if isinstance(trend_value, str):
            text = trend_value.strip().lower()
            if any(token in text for token in ["up", "increase", "growing", "higher", "+"]):
                return "↑"
            if any(token in text for token in ["down", "decrease", "decline", "lower", "-"]):
                return "↓"
            return "→"
        try:
            numeric = float(trend_value)
        except Exception:
            return "→"
        if numeric > 0:
            return "↑"
        if numeric < 0:
            return "↓"
        return "→"

    # Collect all option errors first so the user sees every problem in one run.
    validation_errors = []
    required_columns = set()

    def _require_option(value: Optional[str], option_name: str, analysis_name: str):
        """Record an error if the option is unset, else remember the column name."""
        if not value:
            validation_errors.append(f"{option_name} is required for '{analysis_name}' analysis")
        else:
            required_columns.add(value)

    if analysis in {"kpi", "all"}:
        _require_option(date_col, "--date/-d", "kpi")
        _require_option(value_col, "--value/-v", "kpi")

    if analysis in {"cohort", "all"}:
        _require_option(customer_col, "--customer/-c", "cohort")
        _require_option(date_col, "--date/-d", "cohort")

    if analysis in {"rfm", "all"}:
        _require_option(customer_col, "--customer/-c", "rfm")
        _require_option(date_col, "--date/-d", "rfm")
        _require_option(value_col, "--value/-v", "rfm")

    if analysis in {"churn", "all"}:
        _require_option(target_col, "--target/-y", "churn")

    if validation_errors:
        for err in validation_errors:
            console.print(f"[red]✗ {err}[/red]")
        raise typer.Exit(1)

    missing_cols = [col for col in sorted(required_columns) if col not in df.columns]
    if missing_cols:
        console.print(f"[red]✗ Required columns not found in dataset: {', '.join(missing_cols)}[/red]")
        raise typer.Exit(1)

    console.print(Panel.fit(
        f"💼 Business Intelligence Analytics\nAnalysis: {analysis}",
        style="bold blue"
    ))

    results = {
        "meta": {
            "file_path": file_path,
            "analysis": analysis,
            "rows": len(df),
            "columns": len(df.columns),
            "output_dir": str(output_dir),
            # datetime.utcnow() is deprecated (3.12+); use an aware UTC time,
            # then strip tzinfo so isoformat() + "Z" keeps the original shape.
            "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        },
        "kpi": None,
        "cohort": None,
        "rfm": None,
        "churn": None
    }
    actionable_insights = []

    if analysis in {"kpi", "all"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Calculating KPI analytics...", total=None)

            calculate_kpis = getattr(bi_tools, "calculate_kpis", None)
            if callable(calculate_kpis):
                try:
                    kpi_result = calculate_kpis(
                        data=df,
                        customer_id_column=customer_col,
                        date_column=date_col,
                        value_column=value_col
                    )
                except TypeError:
                    # Signature mismatch: retry positionally for older tool versions.
                    try:
                        kpi_result = calculate_kpis(df, customer_col, date_col, value_col)
                    except Exception as e:
                        console.print(f"[yellow]⚠ calculate_kpis failed, using fallback KPI computation: {e}[/yellow]")
                        kpi_result = None
                except Exception as e:
                    # Any other failure degrades to the local fallback computation
                    # below instead of aborting the whole command.
                    console.print(f"[yellow]⚠ calculate_kpis failed, using fallback KPI computation: {e}[/yellow]")
                    kpi_result = None
            else:
                kpi_result = None

            if kpi_result is None:
                # Fallback KPI computation using only the Polars dataframe.
                base_kpis = {
                    "Total Rows": len(df),
                    "Total Revenue": float(df[value_col].sum()) if value_col else None,
                    "Average Value": float(df[value_col].mean()) if value_col else None,
                    "Unique Customers": int(df[customer_col].n_unique()) if customer_col else None,
                }

                # Crude trend estimate: compare the mean value of the first half
                # of the time-ordered data against the second half.
                trend_pct = None
                if date_col and value_col:
                    try:
                        trend_df = df.select([date_col, value_col]).to_pandas()
                        trend_df[date_col] = pd.to_datetime(trend_df[date_col], errors="coerce")
                        trend_df[value_col] = pd.to_numeric(trend_df[value_col], errors="coerce")
                        trend_df = trend_df.dropna(subset=[date_col, value_col]).sort_values(date_col)
                        if len(trend_df) >= 6:
                            mid = len(trend_df) // 2
                            first_avg = trend_df.iloc[:mid][value_col].mean()
                            second_avg = trend_df.iloc[mid:][value_col].mean()
                            # first_avg must be nonzero and non-NaN to divide by it.
                            if first_avg and not pd.isna(first_avg):
                                trend_pct = float(((second_avg - first_avg) / abs(first_avg)) * 100.0)
                    except Exception:
                        trend_pct = None

                kpi_result = {
                    "kpis": [{"name": k, "value": v, "trend": trend_pct if k in {"Total Revenue", "Average Value"} else None} for k, v in base_kpis.items() if v is not None],
                    "fallback": True
                }

            progress.update(task, completed=True)

        kpi_table = Table(title="📊 KPI Summary")
        kpi_table.add_column("KPI", style="cyan")
        kpi_table.add_column("Value", style="white", justify="right")
        kpi_table.add_column("Trend", style="white", justify="center")

        # Normalize the tool result to a list of {name, value, trend} dicts.
        kpi_items = []
        if isinstance(kpi_result, dict) and isinstance(kpi_result.get("kpis"), list):
            for item in kpi_result.get("kpis", []):
                if isinstance(item, dict):
                    kpi_items.append({
                        "name": str(item.get("name", "KPI")),
                        "value": item.get("value"),
                        "trend": item.get("trend")
                    })

        # If the tool returned a flat dict of scalars instead, display those.
        if not kpi_items and isinstance(kpi_result, dict):
            for key, val in kpi_result.items():
                if isinstance(val, (str, int, float, bool)):
                    kpi_items.append({"name": str(key), "value": val, "trend": None})

        for item in kpi_items:
            arrow = _trend_arrow(item.get("trend"))
            arrow_style = "green" if arrow == "↑" else ("red" if arrow == "↓" else "yellow")
            kpi_table.add_row(
                item["name"],
                _format_metric(item.get("value")),
                f"[{arrow_style}]{arrow}[/{arrow_style}]"
            )

        console.print()
        console.print(kpi_table)
        results["kpi"] = kpi_result

        # Derive a revenue-trend insight from the first KPI whose name mentions revenue.
        revenue_item = next((item for item in kpi_items if "revenue" in item["name"].lower()), None)
        if revenue_item:
            arrow = _trend_arrow(revenue_item.get("trend"))
            if arrow == "↓":
                actionable_insights.append("Revenue momentum is declining; review recent pricing, acquisition channels, or discount strategy.")
            elif arrow == "↑":
                actionable_insights.append("Revenue momentum is improving; scale the channels or offers driving the uplift.")

    if analysis in {"cohort", "all"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Running cohort retention analysis...", total=None)
            cohort_result = bi_tools.perform_cohort_analysis(
                data=df,
                customer_id_column=customer_col,
                date_column=date_col,
                value_column=value_col,
                cohort_period="monthly",
                metric="retention"
            )
            progress.update(task, completed=True)

        # cohort_matrix: {period_key: {cohort_label: retention_rate}} — TODO confirm
        # against perform_cohort_analysis; rendering below assumes that shape.
        cohort_matrix = cohort_result.get("cohort_matrix", {})
        # Sort digit-like keys numerically and the rest lexicographically. The
        # (flag, key) tuple keeps int and str sort keys from being compared with
        # each other, which would raise TypeError on mixed key types.
        period_keys = sorted(cohort_matrix.keys(), key=lambda x: (0, int(x)) if str(x).isdigit() else (1, str(x)))
        cohort_labels = sorted({str(label) for inner in cohort_matrix.values() if isinstance(inner, dict) for label in inner.keys()})

        retention_table = Table(title="👥 Cohort Retention Matrix")
        retention_table.add_column("Cohort", style="bold")
        for key in period_keys:
            retention_table.add_column(f"P{key}", justify="right")

        for cohort_label in cohort_labels:
            row_values = [cohort_label]
            for key in period_keys:
                val = None
                inner = cohort_matrix.get(key, {})
                if isinstance(inner, dict):
                    val = inner.get(cohort_label)

                if val is None:
                    cell = "-"
                else:
                    try:
                        rate = float(val)
                    except Exception:
                        cell = str(val)
                    else:
                        # Color-code retention: green >= 70%, yellow >= 40%, red below.
                        pct = rate * 100.0
                        if pct >= 70:
                            cell = f"[green]{pct:.1f}%[/green]"
                        elif pct >= 40:
                            cell = f"[yellow]{pct:.1f}%[/yellow]"
                        else:
                            cell = f"[red]{pct:.1f}%[/red]"

                row_values.append(cell)

            retention_table.add_row(*row_values)

        console.print()
        console.print(retention_table)
        results["cohort"] = cohort_result

        # Surface at most two tool-provided cohort insights in the final panel.
        for insight in cohort_result.get("insights", [])[:2]:
            actionable_insights.append(str(insight))

    if analysis in {"rfm", "all"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Running RFM segmentation...", total=None)

            # Prefer rfm_segmentation when the tools module exposes it,
            # otherwise fall back to perform_rfm_analysis.
            rfm_fn = getattr(bi_tools, "rfm_segmentation", None)
            if callable(rfm_fn):
                try:
                    rfm_result = rfm_fn(
                        data=df,
                        customer_id_column=customer_col,
                        date_column=date_col,
                        value_column=value_col
                    )
                except TypeError:
                    # Signature mismatch: retry positionally.
                    rfm_result = rfm_fn(df, customer_col, date_col, value_col)
            else:
                rfm_result = bi_tools.perform_rfm_analysis(
                    data=df,
                    customer_id_column=customer_col,
                    date_column=date_col,
                    value_column=value_col
                )

            progress.update(task, completed=True)

        segment_summary = rfm_result.get("segment_summary", {}) if isinstance(rfm_result, dict) else {}
        rfm_table = Table(title="🎯 Customer Segments (RFM)")
        rfm_table.add_column("Segment", style="cyan")
        rfm_table.add_column("Customers", justify="right")
        rfm_table.add_column("Share", justify="right")
        rfm_table.add_column("Avg Monetary", justify="right")

        if isinstance(segment_summary, dict) and segment_summary:
            # Largest segments first.
            sorted_segments = sorted(segment_summary.items(), key=lambda x: x[1].get("count", 0), reverse=True)
            for segment, stats in sorted_segments:
                rfm_table.add_row(
                    str(segment),
                    _format_metric(stats.get("count")),
                    f"{float(stats.get('percentage', 0.0)):.1f}%",
                    _format_metric(stats.get("avg_monetary"))
                )
            top_segment = sorted_segments[0][0]
            actionable_insights.append(f"Your largest segment is '{top_segment}'; tailor campaigns specifically for this group to improve conversion.")
        else:
            # No segment data: keep the table shape with a placeholder row.
            rfm_table.add_row("-", "-", "-", "-")

        console.print()
        console.print(rfm_table)
        results["rfm"] = rfm_result

    if analysis in {"churn", "all"} and target_col:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Running churn analysis...", total=None)

            churn_fn = getattr(bi_tools, "churn_prediction", None)
            if callable(churn_fn):
                try:
                    churn_result = churn_fn(data=df, target_column=target_col)
                except TypeError:
                    # Signature mismatch: retry positionally; any failure there
                    # becomes an error payload rather than a crash.
                    try:
                        churn_result = churn_fn(df, target_col)
                    except Exception as e:
                        churn_result = {"error": str(e)}
            else:
                # Fallback: estimate a baseline churn rate directly from the target column.
                target_series = df[target_col].drop_nulls()
                unique_vals = set([str(v).strip().lower() for v in target_series.unique().to_list()])
                positive_tokens = {"1", "true", "yes", "y", "churn", "churned"}
                churn_rate = None

                try:
                    if unique_vals.issubset({"0", "1", "true", "false", "yes", "no", "y", "n", "churn", "churned", "active"}):
                        # Binary-like labels: churn rate = share of positive tokens.
                        mapped = target_series.cast(str).str.to_lowercase()
                        positives = mapped.is_in(list(positive_tokens)).sum()
                        churn_rate = float(positives / len(mapped)) if len(mapped) else None
                    else:
                        # Otherwise try a numeric 0/1 target: its mean is the churn rate.
                        numeric_target = pd.to_numeric(target_series.to_list(), errors="coerce")
                        numeric_target = pd.Series(numeric_target).dropna()
                        if len(numeric_target) and numeric_target.nunique() <= 2:
                            churn_rate = float(numeric_target.mean())
                except Exception:
                    churn_rate = None

                churn_result = {
                    "status": "fallback",
                    "message": "churn_prediction function not found; returning churn baseline summary.",
                    "target_column": target_col,
                    "rows_evaluated": int(len(target_series)),
                    "churn_rate": churn_rate,
                    "unique_target_values": sorted(list(unique_vals))
                }

            progress.update(task, completed=True)

        churn_table = Table(title="📉 Churn Analysis")
        churn_table.add_column("Metric", style="cyan")
        churn_table.add_column("Value", justify="right")

        if isinstance(churn_result, dict):
            churn_table.add_row("Target Column", str(churn_result.get("target_column", target_col)))
            if churn_result.get("churn_rate") is not None:
                churn_rate_pct = float(churn_result.get("churn_rate")) * 100.0
                churn_table.add_row("Estimated Churn Rate", f"{churn_rate_pct:.2f}%")
                # Tiered insight: >=30% high, >=15% moderate, else low.
                if churn_rate_pct >= 30:
                    actionable_insights.append("Estimated churn is high; prioritize retention campaigns for at-risk customers this cycle.")
                elif churn_rate_pct >= 15:
                    actionable_insights.append("Estimated churn is moderate; test proactive customer success outreach on vulnerable segments.")
                else:
                    actionable_insights.append("Estimated churn is low; maintain retention programs and monitor leading indicators weekly.")
            for key in ["status", "message", "rows_evaluated"]:
                if key in churn_result:
                    churn_table.add_row(key.replace("_", " ").title(), _format_metric(churn_result.get(key)))

        console.print()
        console.print(churn_table)
        results["churn"] = churn_result

    # Persist one JSON per completed analysis plus a combined summary.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    for key in ["kpi", "cohort", "rfm", "churn"]:
        if results.get(key) is not None:
            partial_path = output_dir / f"{key}_results_{timestamp}.json"
            with open(partial_path, "w") as f:
                json.dump(results[key], f, indent=2, default=str)

    summary_path = output_dir / f"bi_results_{timestamp}.json"
    with open(summary_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    console.print(f"\n💾 BI results saved to: [cyan]{output_dir}[/cyan]")
    console.print(f"📄 Combined summary: [cyan]{summary_path}[/cyan]")

    # De-duplicate insights while preserving the order they were collected in.
    deduped_insights = []
    for insight in actionable_insights:
        clean = str(insight).strip()
        if clean and clean not in deduped_insights:
            deduped_insights.append(clean)

    # Top up to three insights with generic fallbacks so the panel is never empty.
    fallback_pool = [
        "Monitor weekly KPI trends and quickly investigate any sudden drops.",
        "Segment customers by lifecycle stage and personalize messaging for each segment.",
        "Use cohort retention patterns to prioritize onboarding improvements in the first customer period."
    ]
    for candidate in fallback_pool:
        if len(deduped_insights) >= 3:
            break
        if candidate not in deduped_insights:
            deduped_insights.append(candidate)

    top3 = deduped_insights[:3]
    insights_text = "\n".join([f"{idx}. {insight}" for idx, insight in enumerate(top3, start=1)])
    console.print()
    console.print(Panel(
        insights_text,
        title="💡 Top 3 Business Insights",
        border_style="green"
    ))
|
1283
|
+
@app.command()
def chat(
    file_path: Optional[str] = typer.Argument(None, help="Optional dataset path to auto-load at startup"),
    session_id: str = typer.Option(None, "--session", "-s", help="Resume a specific session ID"),
    model: str = typer.Option("llama-3.3-70b-versatile", "--model", help="LLM model to use")
):
    """
    Launch an interactive chat session for follow-up questions on datasets.

    Runs a REPL that supports built-in commands (help, history, summary,
    load, clear, export, exit/quit); any other input is forwarded to the
    orchestrator as an analysis request against the currently loaded dataset.

    Example:
        ds-agent chat sales.csv
        ds-agent chat --session abc123
    """
    # Deferred imports keep CLI startup fast for other subcommands.
    from datetime import datetime
    from rich.prompt import Prompt
    from rich.markdown import Markdown
    from ds_agent.session_store import SessionStore
    from ds_agent.session_memory import SessionMemory
    from ds_agent.tools.data_profiling import profile_dataset

    console.print(Panel.fit(
        "DS-Agent Chat — type 'help' for commands, 'exit' to quit",
        title="💬 Chat",
        border_style="blue"
    ))

    store = SessionStore()
    # Duck-typed store access: tolerate either save_session/load_session or
    # save/load method names on SessionStore (API appears to vary — TODO confirm).
    save_session_fn = getattr(store, "save_session", None) or getattr(store, "save", None)
    load_session_fn = getattr(store, "load_session", None) or getattr(store, "load", None)

    if session_id:
        # Resume an existing session; fall back to a fresh one with the same id.
        session_memory = load_session_fn(session_id) if callable(load_session_fn) else None
        if session_memory is None:
            console.print(f"[yellow]⚠ Session '{session_id}' not found. Creating a new one.[/yellow]")
            session_memory = SessionMemory(session_id=session_id)
    else:
        session_memory = SessionMemory()

    # CLI argument wins over the dataset remembered from a prior session.
    current_dataset = file_path or session_memory.last_dataset
    last_analysis_result = None

    def _print_dataset_summary(profile: dict, path: str):
        # Render key profile stats as a two-column table. All keys are read
        # defensively with .get() so a partial profile never crashes the REPL.
        shape = profile.get("shape", {})
        col_types = profile.get("column_types", {})
        overall = profile.get("overall_stats", {})

        summary_table = Table(title=f"📊 Dataset Summary: {Path(path).name}", show_header=False)
        summary_table.add_column("Metric", style="cyan", width=26)
        summary_table.add_column("Value", style="white")

        summary_table.add_row("File", str(path))
        summary_table.add_row("Rows", str(shape.get("rows", "-")))
        summary_table.add_row("Columns", str(shape.get("columns", "-")))
        summary_table.add_row("Numeric Columns", str(len(col_types.get("numeric", []))))
        summary_table.add_row("Categorical Columns", str(len(col_types.get("categorical", []))))
        summary_table.add_row("Datetime Columns", str(len(col_types.get("datetime", []))))
        summary_table.add_row("Null %", f"{overall.get('null_percentage', 0)}%")
        summary_table.add_row("Duplicate Rows", str(overall.get("duplicate_rows", 0)))

        console.print()
        console.print(summary_table)

    def _load_and_profile_dataset(path: str) -> bool:
        # Profile `path`, then record it as the active dataset in both the
        # enclosing scope and the session memory. Returns False on any failure
        # so callers can simply continue the REPL loop.
        nonlocal current_dataset
        if not Path(path).exists():
            console.print(f"[red]✗ File not found: {path}[/red]")
            return False

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task(f"Profiling dataset: {path}", total=None)
            try:
                profile = profile_dataset(path)
            except Exception as e:
                # Broad catch is intentional: a bad file must not kill the chat loop.
                console.print(f"[red]✗ Failed to profile dataset: {e}[/red]")
                return False
            progress.update(task, completed=True)

        current_dataset = path
        session_memory.update(last_dataset=path)
        # Log the profiling as a workflow step so 'history' shows it.
        session_memory.add_workflow_step(
            "profile_dataset",
            {
                "success": True,
                "arguments": {"file_path": path},
                "result": {
                    "output_path": path,
                    "shape": profile.get("shape", {})
                }
            }
        )
        _print_dataset_summary(profile, path)
        return True

    # Auto-load the startup dataset (argument or remembered one), if any.
    if current_dataset:
        _load_and_profile_dataset(current_dataset)

    try:
        # Model selection is passed to the orchestrator via the environment;
        # presumably DataScienceCopilot reads GROQ_MODEL internally — TODO confirm.
        os.environ["GROQ_MODEL"] = model
        from ds_agent.orchestrator import DataScienceCopilot

        copilot = DataScienceCopilot(
            provider="groq",
            session_id=session_memory.session_id,
            use_session_memory=False
        )
    except Exception as e:
        console.print(f"[red]✗ Failed to initialize orchestrator: {e}[/red]")
        # Persist whatever session state we have before bailing out.
        if callable(save_session_fn):
            save_session_fn(session_memory)
        raise typer.Exit(1)

    help_panel = Panel(
        "\n".join([
            "help Show commands",
            "history Show workflow history",
            "summary Re-print current dataset summary",
            "load <file> Load/switch dataset",
            "clear Clear current session history",
            "export Export current session/results",
            "exit or quit Save session and exit",
        ]),
        title="Available Commands",
        border_style="cyan"
    )

    # Main REPL loop: built-in commands first, otherwise treat input as an
    # analysis request for the orchestrator.
    while True:
        user_message = Prompt.ask(f"[bold cyan][{session_memory.session_id}][/bold cyan] > ").strip()
        if not user_message:
            continue

        lowered = user_message.lower().strip()

        if lowered in {"exit", "quit"}:
            if callable(save_session_fn):
                save_session_fn(session_memory)
            console.print("[green]✓ Session saved. Goodbye![/green]")
            break

        if lowered == "help":
            console.print(help_panel)
            continue

        if lowered == "history":
            history = session_memory.workflow_history
            if not history:
                console.print("[yellow]No workflow history yet.[/yellow]")
                continue

            history_table = Table(title="🧭 Workflow History")
            history_table.add_column("Time", style="cyan")
            history_table.add_column("Tool", style="white")
            history_table.add_column("Status", style="white")

            # Show only the 20 most recent steps to keep output readable.
            for step in history[-20:]:
                ts = str(step.get("timestamp", ""))
                # ISO timestamps are trimmed to HH:MM:SS; anything else is truncated.
                ts_short = ts.split("T")[-1][:8] if "T" in ts else ts[:8]
                tool = str(step.get("tool", "-"))
                result = step.get("result", {})
                success = False
                if isinstance(result, dict):
                    success = bool(result.get("success", False))
                status = "[green]✓[/green]" if success else "[red]✗[/red]"
                history_table.add_row(ts_short, tool, status)

            console.print()
            console.print(history_table)
            continue

        if lowered == "summary":
            if not current_dataset:
                console.print("[yellow]No dataset is loaded. Use 'load <file>'.[/yellow]")
                continue
            # Re-profiles from disk rather than caching the previous profile.
            _load_and_profile_dataset(current_dataset)
            continue

        if lowered == "clear":
            # Keep currently loaded dataset but clear conversational/workflow memory.
            last_dataset = current_dataset
            session_memory.clear()
            if last_dataset:
                session_memory.update(last_dataset=last_dataset)
            console.print("[green]✓ Session history cleared.[/green]")
            continue

        if lowered == "export":
            export_dir = Path("./outputs/chat/")
            export_dir.mkdir(parents=True, exist_ok=True)
            ts = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
            export_path = export_dir / f"chat_session_{session_memory.session_id}_{ts}.json"

            export_payload = {
                "session": session_memory.to_dict(),
                "last_result": last_analysis_result,
                "current_dataset": current_dataset,
                "exported_at": datetime.utcnow().isoformat() + "Z"
            }

            # default=str keeps non-JSON-serializable values (dates, paths, etc.)
            # from aborting the export.
            with open(export_path, "w") as f:
                json.dump(export_payload, f, indent=2, default=str)

            console.print(f"[green]✓ Exported session to:[/green] [cyan]{export_path}[/cyan]")
            continue

        if lowered.startswith("load "):
            # Use the original (non-lowercased) message so file paths keep their case.
            load_path = user_message[5:].strip().strip('"').strip("'")
            if not load_path:
                console.print("[yellow]Usage: load <file_path>[/yellow]")
                continue
            _load_and_profile_dataset(load_path)
            continue

        # Free-form message: resolve pronouns/ambiguity against session memory.
        # Method name varies across SessionMemory versions, so probe both.
        resolver = getattr(session_memory, "resolve_context", None) or getattr(session_memory, "resolve_ambiguity", None)
        resolved_context = resolver(user_message) if callable(resolver) else {}

        effective_dataset = resolved_context.get("file_path") or current_dataset or session_memory.last_dataset
        effective_target = resolved_context.get("target_col") or session_memory.last_target_col

        if not effective_dataset:
            console.print("[yellow]No dataset loaded. Use 'load <file>' or pass file_path at startup.[/yellow]")
            continue

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Thinking...", total=None)
            try:
                analysis_result = copilot.analyze(
                    file_path=effective_dataset,
                    task_description=user_message,
                    target_col=effective_target,
                    use_cache=True,
                    stream=False,
                    max_iterations=20
                )
            except Exception as e:
                # Mark the spinner done before printing so output isn't garbled.
                progress.update(task, completed=True)
                console.print(f"[red]✗ Chat analysis failed: {e}[/red]")
                continue
            progress.update(task, completed=True)

        last_analysis_result = analysis_result
        session_memory.update(last_dataset=effective_dataset)
        if effective_target:
            session_memory.update(last_target_col=effective_target)

        # Mirror the orchestrator's workflow steps into the local session memory
        # so 'history' reflects what the analysis actually ran.
        workflow_steps = analysis_result.get("workflow_history", []) if isinstance(analysis_result, dict) else []
        for step in workflow_steps:
            if not isinstance(step, dict):
                continue
            tool_name = step.get("tool", "unknown_tool")
            # Non-dict results are wrapped so downstream consumers always see a dict.
            tool_result = step.get("result", {}) if isinstance(step.get("result", {}), dict) else {"raw_result": step.get("result")}
            session_memory.add_workflow_step(
                tool_name,
                {
                    "success": tool_result.get("success", True),
                    "arguments": {
                        "file_path": effective_dataset,
                        "target_col": effective_target
                    },
                    "result": tool_result
                }
            )

        # Prefer a human-readable summary; fall back to message, then a stub.
        response_text = ""
        if isinstance(analysis_result, dict):
            response_text = analysis_result.get("summary") or analysis_result.get("message") or "Analysis complete."
        else:
            response_text = str(analysis_result)

        session_memory.add_conversation(user_message, response_text)
        console.print()
        console.print(Markdown(response_text))


@app.command()
def plot(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    query: Optional[str] = typer.Argument(None, help="Optional plot request like 'correlation heatmap' or 'revenue by region'"),
    chart_type: str = typer.Option("auto", "--type", "-t", help="Chart type: auto, bar, line, histogram, scatter, heatmap, box, distribution"),
    x: str = typer.Option(None, "--x", help="X-axis column name"),
    y: str = typer.Option(None, "--y", help="Y-axis column name"),
    target: str = typer.Option(None, "--target", help="Target/color-by column name"),
    output: str = typer.Option("./outputs/plots/", "--output", "-o", help="Output directory for generated plots")
):
    """
    Generate visualization charts from a dataset.

    Mode selection: keywords in the free-text query take precedence over an
    explicit --type; if neither selects a specific chart, every available plot
    is generated. Exits with code 1 on any validation or generation failure.

    Example:
        ds-agent plot sales.csv "top products bar chart" --y revenue
        ds-agent plot sales.csv --type heatmap
        ds-agent plot sales.csv --type histogram --x revenue
    """
    from ds_agent.tools.visualization_engine import generate_all_plots
    from ds_agent.tools.plotly_visualizations import (
        generate_interactive_correlation_heatmap,
        generate_interactive_histogram,
        generate_interactive_scatter,
        generate_interactive_box_plots,
        generate_interactive_time_series,
    )
    from ds_agent.utils.polars_helpers import load_dataframe, get_numeric_columns

    console.print(Panel.fit(
        "📉 Plot Generator",
        style="bold blue"
    ))

    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    valid_types = {"auto", "bar", "line", "histogram", "scatter", "heatmap", "box", "distribution"}
    chart_type = chart_type.strip().lower()
    if chart_type not in valid_types:
        console.print(f"[red]✗ Error: Invalid chart type '{chart_type}'.[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load once for safe column inference when user does not pass x/y.
    try:
        df = load_dataframe(file_path)
        numeric_cols = get_numeric_columns(df)
        all_cols = list(df.columns)
    except Exception as e:
        console.print(f"[red]✗ Could not load dataset for plotting: {e}[/red]")
        raise typer.Exit(1)

    # Infer axes: x defaults to the first column, y to the first numeric column
    # that differs from x (falling back to the only numeric column if needed).
    if x is None and all_cols:
        x = all_cols[0]
    if y is None and numeric_cols:
        y = numeric_cols[0] if x not in numeric_cols else (numeric_cols[1] if len(numeric_cols) > 1 else numeric_cols[0])

    query_text = (query or "").lower()

    # Query keywords win over --type; an explicit --type wins over "all".
    selected_mode = "all"
    if "correlation" in query_text or "heatmap" in query_text:
        selected_mode = "heatmap"
    elif "distribution" in query_text or "histogram" in query_text:
        selected_mode = "histogram"
    elif "scatter" in query_text:
        selected_mode = "scatter"
    elif "box" in query_text:
        selected_mode = "box"
    elif "time" in query_text or "trend" in query_text:
        selected_mode = "time_series"
    elif chart_type != "auto":
        # Table-driven mapping replaces the previous elif ladder; unmapped
        # types ("bar") intentionally fall back to generating all plots.
        selected_mode = {
            "heatmap": "heatmap",
            "histogram": "histogram",
            "distribution": "histogram",
            "scatter": "scatter",
            "box": "box",
            "line": "time_series",
        }.get(chart_type, "all")

    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    if query:
        console.print(f"📝 [bold]Query:[/bold] {query}")
    console.print(f"🎯 [bold]Mode:[/bold] {selected_mode}")

    plot_paths = []

    def _collect_or_fail(result: dict, fallback_msg: str) -> None:
        # Shared success check for every single-plot generator: record the
        # output path on success, otherwise raise so the outer handler exits.
        if result.get("status") == "success":
            plot_paths.append(result.get("output_path"))
        else:
            raise ValueError(result.get("message", fallback_msg))

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_plot = progress.add_task("Generating visualizations...", total=None)

            if selected_mode == "heatmap":
                result = generate_interactive_correlation_heatmap(
                    file_path,
                    output_path=str(output_dir / "correlation_heatmap.html")
                )
                _collect_or_fail(result, "Failed to generate heatmap")

            elif selected_mode == "histogram":
                # Prefer an explicitly given existing column; otherwise fall
                # back to the first numeric column.
                hist_col = x if x in all_cols else (y if y in all_cols else (numeric_cols[0] if numeric_cols else None))
                if not hist_col:
                    raise ValueError("No suitable numeric column found for histogram")
                result = generate_interactive_histogram(
                    file_path,
                    column=hist_col,
                    color_col=target,
                    output_path=str(output_dir / "histogram.html")
                )
                _collect_or_fail(result, "Failed to generate histogram")

            elif selected_mode == "scatter":
                if not x or not y:
                    raise ValueError("Scatter plot requires x and y columns (inferred values unavailable)")
                result = generate_interactive_scatter(
                    file_path,
                    x_col=x,
                    y_col=y,
                    color_col=target,
                    output_path=str(output_dir / "scatter.html")
                )
                _collect_or_fail(result, "Failed to generate scatter plot")

            elif selected_mode == "box":
                result = generate_interactive_box_plots(
                    file_path,
                    columns=[col for col in [x, y] if col],
                    group_by=target,
                    output_path=str(output_dir / "box_plots.html")
                )
                _collect_or_fail(result, "Failed to generate box plot")

            elif selected_mode == "time_series":
                if not x:
                    raise ValueError("Time series requires --x as time column")
                value_cols = [col for col in [y, target] if col and col != x]
                if not value_cols:
                    # Fall back to the first numeric column other than the time axis.
                    value_cols = [col for col in numeric_cols if col != x][:1]
                if not value_cols:
                    raise ValueError("Time series requires at least one numeric value column")
                result = generate_interactive_time_series(
                    file_path,
                    time_col=x,
                    value_cols=value_cols,
                    output_path=str(output_dir / "time_series.html")
                )
                _collect_or_fail(result, "Failed to generate time series")

            else:
                # "all" mode delegates to the full visualization engine.
                result = generate_all_plots(
                    file_path,
                    target_col=target,
                    output_dir=str(output_dir)
                )
                plot_paths.extend(result.get("plots_generated", []))

            progress.update(task_plot, completed=True)

    except Exception as e:
        console.print(f"\n[red]✗ Plot generation failed: {e}[/red]")
        raise typer.Exit(1)

    # Drop any falsy entries (missing output_path values) before reporting.
    plot_paths = [p for p in plot_paths if p]
    if not plot_paths:
        console.print("[yellow]⚠ No plots were generated.[/yellow]")
        raise typer.Exit(1)

    console.print("\n[green]✓ Plots generated:[/green]")
    for path in plot_paths:
        console.print(f" - [cyan]{path}[/cyan]")

    summary = Table(title="📊 Plot Summary")
    summary.add_column("#", style="cyan", width=4)
    summary.add_column("Plot File", style="white")
    summary.add_column("Path", style="green")

    if len(plot_paths) > 1:
        # Live view animates the table filling in when multiple plots exist.
        with Live(summary, console=console, refresh_per_second=8):
            for idx, path in enumerate(plot_paths, start=1):
                summary.add_row(str(idx), Path(path).name, path)
    else:
        summary.add_row("1", Path(plot_paths[0]).name, plot_paths[0])

    console.print()
    console.print(summary)


@app.command()
def forecast(
    file_path: str = typer.Argument(..., help="Path to time series dataset file (CSV or Parquet)"),
    time_col: str = typer.Option(None, "--time", "-t", help="Time column name (required)"),
    target_col: str = typer.Option(None, "--target", "-y", help="Target column to forecast (required)"),
    horizon: int = typer.Option(30, "--horizon", "-h", help="Forecast horizon (periods ahead)"),
    method: str = typer.Option("prophet", "--method", "-m", help="Forecasting method: prophet, arima, exponential_smoothing"),
    output: str = typer.Option("./outputs/forecast/", "--output", "-o", help="Output directory for forecast files")
):
    """
    Forecast a time series target column.

    Runs seasonality detection (best-effort), generates the forecast with the
    chosen method, prints a results table plus accuracy metrics when the tool
    returns them, and writes a normalized CSV. Exits with code 1 on any
    validation or forecasting failure.

    Example:
        ds-agent forecast sales.csv --time date --target revenue --horizon 90
        ds-agent forecast stock.csv --time Date --target Close --method arima
    """
    from ds_agent.tools.time_series import forecast_time_series
    import csv

    # detect_seasonality is requested by name; fallback keeps compatibility with current tools module.
    try:
        from ds_agent.tools.time_series import detect_seasonality
    except ImportError:
        from ds_agent.tools.time_series import detect_seasonality_trends as detect_seasonality

    def _pick(row: dict, *keys):
        # BUGFIX: the previous `row.get(a) or row.get(b)` fallback treated
        # legitimate 0 / 0.0 values as missing and silently fell through to
        # the next key. Check for None explicitly instead.
        for key in keys:
            value = row.get(key)
            if value is not None:
                return value
        return None

    console.print(Panel.fit(
        "⏳ Time Series Forecast",
        style="bold blue"
    ))

    # Step 1: Validation
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not time_col:
        console.print("[red]✗ Error: --time (-t) is required[/red]")
        raise typer.Exit(1)

    if not target_col:
        console.print("[red]✗ Error: --target (-y) is required[/red]")
        raise typer.Exit(1)

    if horizon <= 0:
        console.print("[red]✗ Error: --horizon must be a positive integer[/red]")
        raise typer.Exit(1)

    method_normalized = method.strip().lower()
    valid_methods = {"prophet", "arima", "exponential_smoothing"}
    if method_normalized not in valid_methods:
        console.print("[red]✗ Error: --method must be one of prophet, arima, exponential_smoothing[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv = output_dir / f"forecast_{Path(file_path).stem}_{target_col}_{method_normalized}.csv"

    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    console.print(f"🕒 [bold]Time Column:[/bold] {time_col}")
    console.print(f"🎯 [bold]Target Column:[/bold] {target_col}")
    console.print(f"📈 [bold]Method:[/bold] {method_normalized}")
    console.print(f"🔭 [bold]Horizon:[/bold] {horizon}")

    # Step 2: Seasonality detection (best-effort — a failure here only
    # downgrades the output message, it never aborts the forecast).
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_seasonality = progress.add_task("Detecting seasonality...", total=None)
            seasonality_result = detect_seasonality(file_path, time_col, target_col)
            progress.update(task_seasonality, completed=True)
    except Exception as e:
        seasonality_result = {"status": "error", "message": str(e)}

    if seasonality_result.get("status") == "success":
        detected_period = seasonality_result.get("detected_period", "unknown")
        interpretation = seasonality_result.get("interpretation", {})
        seasonality_label = interpretation.get("seasonality", "unknown")
        console.print(
            f"[green]✓ Seasonality detected:[/green] {seasonality_label} (period={detected_period})"
        )
    else:
        console.print(
            f"[yellow]⚠ Seasonality detection unavailable:[/yellow] {seasonality_result.get('message', 'Unknown issue')}"
        )

    # Step 3: Forecast generation
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_forecast = progress.add_task("Generating forecast...", total=None)
            forecast_result = forecast_time_series(
                file_path=file_path,
                time_col=time_col,
                target_col=target_col,
                forecast_horizon=horizon,
                method=method_normalized,
                output_path=str(output_csv)
            )
            progress.update(task_forecast, completed=True)
    except Exception as e:
        console.print(f"[red]✗ Forecast failed: {e}[/red]")
        raise typer.Exit(1)

    if forecast_result.get("status") != "success":
        console.print(f"[red]✗ Forecast failed: {forecast_result.get('message', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    forecast_rows = forecast_result.get("forecast", [])
    if not isinstance(forecast_rows, list) or not forecast_rows:
        console.print("[yellow]⚠ Forecast completed but returned no rows.[/yellow]")
        raise typer.Exit(1)

    # Normalize row fields for display and CSV export. Different backends use
    # different key names (Prophet's ds/yhat vs generic date/predicted).
    normalized_rows = []
    for row in forecast_rows:
        normalized_rows.append({
            "date": _pick(row, "date", "ds", "time"),
            "actual": _pick(row, "actual", "y"),
            "predicted": _pick(row, "predicted", "yhat", "value"),
            "lower_bound": _pick(row, "lower_bound", "yhat_lower", "lower_ci"),
            "upper_bound": _pick(row, "upper_bound", "yhat_upper", "upper_ci"),
        })

    # Step 4: Forecast results table
    table = Table(title="📊 Forecast Results")
    table.add_column("Date", style="cyan")
    table.add_column("Actual", justify="right")
    table.add_column("Predicted", justify="right", style="green")
    table.add_column("Lower Bound", justify="right")
    table.add_column("Upper Bound", justify="right")

    for row in normalized_rows:
        # Missing values render as "-" instead of formatting a None.
        actual_display = "-" if row["actual"] is None else f"{float(row['actual']):.4f}"
        predicted_display = "-" if row["predicted"] is None else f"{float(row['predicted']):.4f}"
        lower_display = "-" if row["lower_bound"] is None else f"{float(row['lower_bound']):.4f}"
        upper_display = "-" if row["upper_bound"] is None else f"{float(row['upper_bound']):.4f}"

        table.add_row(
            str(row["date"]),
            actual_display,
            predicted_display,
            lower_display,
            upper_display,
        )

    console.print()
    console.print(table)

    # Step 5: Save normalized forecast CSV in output directory.
    # NOTE: this overwrites whatever forecast_time_series may have written to
    # the same path with the normalized column layout.
    with open(output_csv, "w", newline="") as csv_file:
        writer = csv.DictWriter(
            csv_file,
            fieldnames=["date", "actual", "predicted", "lower_bound", "upper_bound"]
        )
        writer.writeheader()
        writer.writerows(normalized_rows)

    console.print(f"\n💾 Forecast CSV saved to: [cyan]{output_csv}[/cyan]")

    # Step 6: Print model accuracy metrics if available. Top-level keys are
    # checked first, then a nested "metrics" dict (which takes precedence).
    metrics = {}
    for metric_key in ["mae", "rmse", "mape"]:
        if metric_key in forecast_result:
            metrics[metric_key] = forecast_result.get(metric_key)

    nested_metrics = forecast_result.get("metrics")
    if isinstance(nested_metrics, dict):
        for metric_key in ["mae", "rmse", "mape"]:
            if metric_key in nested_metrics:
                metrics[metric_key] = nested_metrics.get(metric_key)

    if metrics:
        metrics_table = Table(title="📏 Forecast Accuracy")
        metrics_table.add_column("Metric", style="cyan")
        metrics_table.add_column("Value", style="white", justify="right")

        if metrics.get("mae") is not None:
            metrics_table.add_row("MAE", f"{float(metrics['mae']):.4f}")
        if metrics.get("rmse") is not None:
            metrics_table.add_row("RMSE", f"{float(metrics['rmse']):.4f}")
        if metrics.get("mape") is not None:
            metrics_table.add_row("MAPE", f"{float(metrics['mape']):.2f}%")

        console.print()
        console.print(metrics_table)
    else:
        console.print("[yellow]ℹ Accuracy metrics (MAE/RMSE/MAPE) were not returned by this model run.[/yellow]")


@app.command()
def nlp(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    text_col: str = typer.Option(None, "--text", "-t", help="Text column name (required)"),
    task: str = typer.Option("all", "--task", help="NLP task: all, topics, sentiment, entities, classify"),
    n_topics: int = typer.Option(5, "--n-topics", help="Number of topics for topic modeling"),
    output: str = typer.Option("./outputs/nlp/", "--output", "-o", help="Output directory for NLP artifacts")
) -> None:
    """
    Run NLP analysis tasks such as sentiment, topics, and entity extraction.

    Depending on --task, runs sentiment analysis, topic modeling, and/or
    named-entity recognition over one text column, renders rich tables for
    each result, and writes one JSON artifact per task into the output
    directory. The "classify" task is a placeholder only (see Step 4).

    Raises:
        typer.Exit: on a missing file, missing/unknown column, invalid
            --task/--n-topics, or a dataset-load failure (exit code 1).

    Example:
        ds-agent nlp reviews.csv --text review_text --task sentiment
        ds-agent nlp news.csv --text content --task topics --n-topics 8
        ds-agent nlp articles.csv --text body
    """
    # Imported lazily so the CLI starts fast when this command is unused.
    from ds_agent.utils.polars_helpers import load_dataframe
    from ds_agent.tools.nlp_text_analytics import (
        analyze_sentiment_advanced as sentiment_analysis,
        perform_topic_modeling,
        perform_named_entity_recognition as extract_entities,
    )

    console.print(Panel.fit(
        "📝 NLP Text Analytics",
        style="bold blue"
    ))

    # Step 1: Validate file and text column.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not text_col:
        console.print("[red]✗ Error: --text (-t) is required[/red]")
        raise typer.Exit(1)

    valid_tasks = {"all", "topics", "sentiment", "entities", "classify"}
    task = task.strip().lower()
    if task not in valid_tasks:
        console.print("[red]✗ Error: --task must be one of all, topics, sentiment, entities, classify[/red]")
        raise typer.Exit(1)

    if n_topics <= 0:
        console.print("[red]✗ Error: --n-topics must be a positive integer[/red]")
        raise typer.Exit(1)

    try:
        df = load_dataframe(file_path)
    except Exception as e:
        console.print(f"[red]✗ Error loading dataset: {e}[/red]")
        raise typer.Exit(1)

    if text_col not in df.columns:
        console.print(f"[red]✗ Error: Text column '{text_col}' not found in dataset[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 2: Preview first 5 rows of text column.
    preview_table = Table(title="🔎 Text Preview (first 5 rows)")
    preview_table.add_column("Row", style="cyan", width=6)
    preview_table.add_column(text_col, style="white")

    # Flatten newlines and truncate long cells so the table stays readable.
    preview_values = df[text_col].head(5).to_list()
    for idx, value in enumerate(preview_values, start=1):
        text_value = "" if value is None else str(value)
        text_value = text_value.replace("\n", " ").strip()
        if len(text_value) > 140:
            text_value = text_value[:137] + "..."
        preview_table.add_row(str(idx), text_value)

    console.print()
    console.print(preview_table)

    sentiment_result = None
    topics_result = None
    entities_result = None
    saved_files = []  # artifact paths collected for the final summary panel

    # Step 3A: Sentiment analysis
    if task in {"all", "sentiment"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            sentiment_task = progress.add_task("Running sentiment analysis...", total=None)
            sentiment_result = sentiment_analysis(df, text_col)
            progress.update(sentiment_task, completed=True)

        # Prefer a precomputed distribution from the tool; otherwise rebuild
        # one from the per-row labels. Substring matching ("POS"/"NEG"/"NEU")
        # tolerates label variants such as "positive" vs "POS".
        # NOTE(review): labels matching none of the three buckets are dropped,
        # so percentages may not sum to 100 — confirm the tool's label set.
        sentiment_dist = sentiment_result.get("statistics", {}).get("sentiment_distribution")
        if not sentiment_dist:
            sentiments = sentiment_result.get("sentiments", [])
            labels = [str(item.get("label", "NEUTRAL")).upper() for item in sentiments]
            total = len(labels)
            sentiment_dist = {
                "POSITIVE": sum(1 for label in labels if "POS" in label),
                "NEGATIVE": sum(1 for label in labels if "NEG" in label),
                "NEUTRAL": sum(1 for label in labels if "NEU" in label),
            } if total > 0 else {"POSITIVE": 0, "NEGATIVE": 0, "NEUTRAL": 0}

        total_count = sum(sentiment_dist.values()) if isinstance(sentiment_dist, dict) else 0

        sentiment_table = Table(title="😊 Sentiment Distribution")
        sentiment_table.add_column("Label", style="cyan")
        sentiment_table.add_column("Count", style="white", justify="right")
        sentiment_table.add_column("Percentage", style="green", justify="right")

        for label in ["POSITIVE", "NEGATIVE", "NEUTRAL"]:
            count = int(sentiment_dist.get(label, 0)) if isinstance(sentiment_dist, dict) else 0
            pct = (count / total_count * 100) if total_count > 0 else 0.0
            sentiment_table.add_row(label, str(count), f"{pct:.2f}%")

        console.print()
        console.print(sentiment_table)

        sentiment_path = output_dir / "sentiment_results.json"
        with open(sentiment_path, "w") as f:
            json.dump(sentiment_result, f, indent=2, default=str)
        saved_files.append(str(sentiment_path))

    # Step 3B: Topic modeling
    if task in {"all", "topics"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            topics_task = progress.add_task("Running topic modeling...", total=None)
            topics_result = perform_topic_modeling(
                data=df,
                text_column=text_col,
                n_topics=n_topics
            )
            progress.update(topics_task, completed=True)

        topics_table = Table(title="🧠 Topic Modeling Results")
        topics_table.add_column("Topic", style="cyan", width=8)
        topics_table.add_column("Top 5 Keywords", style="white")

        for topic in topics_result.get("topics", []):
            topic_id = topic.get("topic_id", "-")
            words = topic.get("words", [])[:5]
            topics_table.add_row(str(topic_id), ", ".join(words) if words else "-")

        console.print()
        console.print(topics_table)

        topics_path = output_dir / "topics_results.json"
        with open(topics_path, "w") as f:
            json.dump(topics_result, f, indent=2, default=str)
        saved_files.append(str(topics_path))

    # Step 3C: Entity extraction
    if task in {"all", "entities"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            entities_task = progress.add_task("Extracting entities...", total=None)
            entities_result = extract_entities(df, text_col)
            progress.update(entities_task, completed=True)

        entities_table = Table(title="🏷️ Top Entities by Type")
        entities_table.add_column("Entity Type", style="cyan")
        entities_table.add_column("Top 10 Entities", style="white")

        # Two result shapes are tolerated: ranked "top_entities" dicts with
        # counts, or a plain "examples" list — presumably depending on the
        # backend used by the tool (TODO confirm).
        by_type = entities_result.get("by_type", {}) if isinstance(entities_result, dict) else {}
        for entity_type, details in by_type.items():
            top_items = details.get("top_entities") if isinstance(details, dict) else None
            if isinstance(top_items, list) and top_items:
                top_10 = [
                    f"{item.get('text', '')} ({item.get('count', 0)})"
                    for item in top_items[:10]
                ]
            else:
                examples = details.get("examples", []) if isinstance(details, dict) else []
                top_10 = [str(ex) for ex in examples[:10]]

            entities_table.add_row(entity_type, ", ".join(top_10) if top_10 else "-")

        console.print()
        console.print(entities_table)

        entities_path = output_dir / "entities_results.json"
        with open(entities_path, "w") as f:
            json.dump(entities_result, f, indent=2, default=str)
        saved_files.append(str(entities_path))

    # Step 4: Save classify placeholder when classify-only requested.
    if task == "classify":
        classify_result = {
            "status": "not_implemented",
            "message": "Classify task is reserved but no classification routine is currently wired in nlp_text_analytics."
        }
        classify_path = output_dir / "classify_results.json"
        with open(classify_path, "w") as f:
            json.dump(classify_result, f, indent=2)
        saved_files.append(str(classify_path))
        console.print("[yellow]ℹ Classify task placeholder saved (feature not implemented in tools module).[/yellow]")

    # Step 5: Final summary panel
    summary_lines = [
        f"Task: {task}",
        f"Dataset: {file_path}",
        f"Text column: {text_col}",
        f"Rows processed: {len(df)}",
        "Saved files:"
    ]
    summary_lines.extend([f"- {path}" for path in saved_files])

    console.print(Panel(
        "\n".join(summary_lines),
        title="✅ NLP Analysis Complete",
        border_style="green"
    ))
|
|
2195
|
+
|
|
2196
|
+
|
|
2197
|
+
@app.command()
def pipeline(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    target_col: str = typer.Option(None, "--target", "-y", help="Target column name (required)"),
    level: str = typer.Option("basic", "--level", "-l", help="Feature engineering level: basic, intermediate, advanced"),
    task_type: str = typer.Option("auto", "--task-type", help="Task type: auto, classification, regression"),
    output: str = typer.Option("./outputs/data/pipeline_output.csv", "--output", "-o", help="Output file path")
) -> None:
    """
    Run the full automated ML pipeline.

    Validates the inputs, runs auto_ml_pipeline while translating its
    printed "Stage N" markers into progress-bar steps (see _PipelineStream),
    then renders a before/after comparison table, the top selected
    features, and the output file path.

    Raises:
        typer.Exit: on invalid arguments or a pipeline failure (exit code 1).

    Example:
        ds-agent pipeline titanic.csv --target Survived
        ds-agent pipeline sales.csv --target revenue --level advanced --task-type regression
    """
    from ds_agent.tools.auto_pipeline import auto_ml_pipeline

    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not target_col:
        console.print("[red]✗ Error: --target (-y) is required[/red]")
        raise typer.Exit(1)

    level = level.strip().lower()
    valid_levels = {"basic", "intermediate", "advanced"}
    if level not in valid_levels:
        console.print("[red]✗ Error: --level must be one of basic, intermediate, advanced[/red]")
        raise typer.Exit(1)

    task_type = task_type.strip().lower()
    valid_task_types = {"auto", "classification", "regression"}
    if task_type not in valid_task_types:
        console.print("[red]✗ Error: --task-type must be one of auto, classification, regression[/red]")
        raise typer.Exit(1)

    output_path = str(Path(output))
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    overview_lines = [
        "1. Type detection",
        "2. Cleaning",
        "3. Outliers",
        "4. Encoding",
        "5. Feature engineering",
        "6. Feature selection",
        "",
        f"Dataset: {file_path}",
        f"Target: {target_col}",
        f"Feature engineering level: {level}",
        f"Task type: {task_type}",
    ]

    # Step 1: Pipeline Overview
    console.print(Panel(
        "\n".join(overview_lines),
        title="🔄 Pipeline Overview",
        border_style="blue"
    ))

    # Maps "Stage N" markers printed by auto_ml_pipeline to 6 progress steps.
    # NOTE(review): "Stage 4" is deliberately absent from this mapping —
    # presumably that stage is skipped or merged in the tool; confirm against
    # auto_ml_pipeline's printed output.
    stage_to_progress = {
        "Stage 1": 1,
        "Stage 2": 2,
        "Stage 3": 3,
        "Stage 5": 4,
        "Stage 6": 5,
        "Stage 7": 6,
    }

    class _PipelineStream(io.TextIOBase):
        # A write-only text stream used with redirect_stdout: it scrapes the
        # pipeline's printed lines for stage markers and advances the Rich
        # progress bar accordingly, swallowing the raw prints.

        def __init__(self, progress, task_id):
            self.progress = progress
            self.task_id = task_id
            self.buffer = ""           # holds a partial line between writes
            self.max_completed = 0     # progress is monotone: never move backwards

        def write(self, text):
            self.buffer += text
            # Process only complete lines; keep any trailing fragment buffered.
            while "\n" in self.buffer:
                line, self.buffer = self.buffer.split("\n", 1)
                line = line.strip()
                for key, step in stage_to_progress.items():
                    if key in line and step > self.max_completed:
                        self.max_completed = step
                        self.progress.update(self.task_id, completed=step)
                        break
            return len(text)

        def flush(self):
            return None

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            pipeline_task = progress.add_task("Running automated pipeline...", total=6)
            stream = _PipelineStream(progress, pipeline_task)

            # Route the tool's prints into the stage-scraping stream so the
            # console shows a progress bar instead of raw log lines.
            with contextlib.redirect_stdout(stream):
                result = auto_ml_pipeline(
                    file_path,
                    target_col,
                    task_type,
                    output_path,
                    level,
                )

            progress.update(pipeline_task, completed=6)

    except Exception as e:
        console.print(f"[red]✗ Pipeline failed: {e}[/red]")
        raise typer.Exit(1)

    # Shape bookkeeping: "added"/"removed" are clamped at zero so only the
    # net direction of the column-count change is reported.
    original_shape = result.get("original_shape", {})
    final_shape = result.get("final_shape", {})
    original_cols = int(original_shape.get("columns", 0))
    final_cols = int(final_shape.get("columns", 0))
    features_added = max(final_cols - original_cols, 0)
    features_removed = max(original_cols - final_cols, 0)
    transformations_applied = result.get("transformations_applied", [])

    # Step 4: Comparison table
    comparison_table = Table(title="📊 Pipeline Comparison", show_header=False)
    comparison_table.add_column("Metric", style="cyan", width=28)
    comparison_table.add_column("Value", style="white")

    comparison_table.add_row(
        "Original Shape",
        f"{original_shape.get('rows', '-')} rows × {original_shape.get('columns', '-')} cols"
    )
    comparison_table.add_row(
        "Final Shape",
        f"{final_shape.get('rows', '-')} rows × {final_shape.get('columns', '-')} cols"
    )
    comparison_table.add_row("Features Added", str(features_added))
    comparison_table.add_row("Features Removed", str(features_removed))
    comparison_table.add_row(
        "Transformations Applied",
        str(len(transformations_applied))
    )

    transformation_names = [t.get("stage", "Unknown") for t in transformations_applied]
    comparison_table.add_row(
        "Transformation Stages",
        ", ".join(transformation_names) if transformation_names else "-"
    )

    console.print()
    console.print(comparison_table)

    # Step 5: Top selected features
    selected_features = result.get("selected_features", [])
    feature_importance = result.get("feature_importance", {})

    if selected_features:
        feature_table = Table(title="🏆 Top Selected Features")
        feature_table.add_column("Rank", style="cyan", width=6)
        feature_table.add_column("Feature", style="white")
        feature_table.add_column("Importance", style="green", justify="right")

        if isinstance(feature_importance, dict) and feature_importance:
            ranked = []
            for feature in selected_features:
                # Fall back to the stringified key in case importances are
                # keyed differently than the feature list entries.
                score = feature_importance.get(feature)
                if score is None:
                    score = feature_importance.get(str(feature))
                ranked.append((feature, score))

            # Unknown importances sort last via -inf.
            ranked.sort(key=lambda x: float(x[1]) if x[1] is not None else float("-inf"), reverse=True)
            top_features = ranked[:10]
        else:
            top_features = [(feature, None) for feature in selected_features[:10]]

        for idx, (feature, score) in enumerate(top_features, start=1):
            score_text = "-" if score is None else f"{float(score):.6f}"
            feature_table.add_row(str(idx), str(feature), score_text)

        console.print()
        console.print(feature_table)

    # Step 6: Output file path
    output_file = result.get("output_path", output_path)
    console.print(f"\n💾 Pipeline output saved to: [cyan]{output_file}[/cyan]")
|
|
2383
|
+
|
|
2384
|
+
|
|
2385
|
+
@app.command()
def tune(
    file_path: str = typer.Argument(..., help="Path to pre-processed dataset (CSV or Parquet)"),
    target_col: str = typer.Option(None, "--target", "-y", help="Target column name (required)"),
    model_type: str = typer.Option("xgboost", "--model", "-m", help="Model to tune: random_forest, xgboost, lightgbm"),
    trials: int = typer.Option(30, "--trials", "-n", help="Number of Optuna trials"),
    task_type: str = typer.Option("auto", "--task-type", help="Task type: auto, classification, regression"),
    output: str = typer.Option("./outputs/models/", "--output", "-o", help="Output directory or model file path")
) -> None:
    """
    Run Bayesian hyperparameter optimization with Optuna.

    Computes a default-hyperparameter baseline via 5-fold CV, then runs
    hyperparameter_tuning while temporarily monkey-patching Optuna's
    Study.optimize to feed a per-trial callback into a Rich progress bar.
    Finally prints best params, best/baseline scores, and the improvement.

    Raises:
        typer.Exit: on invalid arguments, user interrupt, or tuning failure.

    Example:
        ds-agent tune cleaned.csv --target Survived --model xgboost --trials 50
    """
    from ds_agent.tools.advanced_training import hyperparameter_tuning
    from ds_agent.utils.polars_helpers import load_dataframe
    import ds_agent.tools.advanced_training as advanced_training_module
    import numpy as np
    from sklearn.model_selection import cross_val_score, StratifiedKFold, KFold
    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
    from xgboost import XGBClassifier, XGBRegressor

    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not target_col:
        console.print("[red]✗ Error: --target (-y) is required[/red]")
        raise typer.Exit(1)

    model_type = model_type.strip().lower()
    valid_models = {"random_forest", "xgboost", "lightgbm"}
    if model_type not in valid_models:
        console.print("[red]✗ Error: --model must be one of random_forest, xgboost, lightgbm[/red]")
        raise typer.Exit(1)

    task_type = task_type.strip().lower()
    valid_task_types = {"auto", "classification", "regression"}
    if task_type not in valid_task_types:
        console.print("[red]✗ Error: --task-type must be one of auto, classification, regression[/red]")
        raise typer.Exit(1)

    if trials <= 0:
        console.print("[red]✗ Error: --trials (-n) must be a positive integer[/red]")
        raise typer.Exit(1)

    # Row count feeds the time estimate only; a load failure is non-fatal here.
    try:
        df_for_estimate = load_dataframe(file_path)
        dataset_rows = len(df_for_estimate)
    except Exception:
        dataset_rows = 0

    # Rough heuristic: ~10s/trial, scaled up linearly past 50k rows.
    base_estimated_seconds = int(trials * 10)
    size_factor = max(1.0, (dataset_rows / 50000.0)) if dataset_rows > 0 else 1.0
    estimated_seconds = int(base_estimated_seconds * size_factor)

    # Resolve output path: directory or file path.
    output_path_obj = Path(output)
    if output_path_obj.suffix.lower() in {".pkl", ".joblib"}:
        model_output_path = output_path_obj
        model_output_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        output_path_obj.mkdir(parents=True, exist_ok=True)
        model_output_path = output_path_obj / f"tuned_{model_type}_{Path(file_path).stem}.pkl"

    # Step 1: Upfront warning panel
    console.print(Panel(
        f"⚠️ Hyperparameter tuning is compute-intensive.\n"
        f"Estimated time: {base_estimated_seconds} seconds. Press Ctrl+C to stop early.\n"
        f"Dataset-size adjusted estimate: {estimated_seconds} seconds.",
        title="Tuning Warning",
        border_style="yellow"
    ))

    baseline_score = None

    def _compute_baseline_score() -> Optional[float]:
        # Best-effort 5-fold CV score with default hyperparameters, on
        # numeric columns only (NaNs filled with 0). Returns None on any
        # failure so baseline reporting degrades gracefully.
        try:
            df = load_dataframe(file_path).to_pandas()
            if target_col not in df.columns:
                return None

            X = df.drop(columns=[target_col])
            y = df[target_col]
            X = X.select_dtypes(include=[np.number]).fillna(0)
            if X.shape[1] == 0:
                return None

            # Heuristic: <20 unique target values => classification.
            inferred_task = task_type
            if inferred_task == "auto":
                inferred_task = "classification" if y.nunique() < 20 else "regression"

            if model_type == "random_forest":
                model = RandomForestClassifier(random_state=42) if inferred_task == "classification" else RandomForestRegressor(random_state=42)
            elif model_type == "xgboost":
                # NOTE(review): use_label_encoder is deprecated/removed in
                # recent XGBoost releases — confirm the pinned version.
                model = XGBClassifier(random_state=42, use_label_encoder=False, eval_metric="logloss") if inferred_task == "classification" else XGBRegressor(random_state=42)
            elif model_type == "lightgbm":
                from lightgbm import LGBMClassifier, LGBMRegressor
                model = LGBMClassifier(random_state=42, verbosity=-1) if inferred_task == "classification" else LGBMRegressor(random_state=42, verbosity=-1)
            else:
                return None

            # Accuracy for classification, negated RMSE for regression —
            # both "higher is better" so the improvement math below holds.
            if inferred_task == "classification":
                cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
                scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy", n_jobs=-1)
            else:
                cv = KFold(n_splits=5, shuffle=True, random_state=42)
                scores = cross_val_score(model, X, y, cv=cv, scoring="neg_root_mean_squared_error", n_jobs=-1)

            return float(scores.mean())
        except Exception:
            return None

    baseline_score = _compute_baseline_score()

    # Monkey-patch Optuna's Study.optimize so every trial ticks the progress
    # bar; the original method is restored in the finally block below.
    trial_counter = {"count": 0}
    optuna_module = advanced_training_module.optuna
    original_optimize = optuna_module.study.study.Study.optimize

    def _patched_optimize(self, func, *args, **kwargs):
        def _trial_callback(study, trial):
            trial_counter["count"] += 1
            progress.update(
                progress_task,
                completed=min(trial_counter["count"], trials),
                description=f"Running Optuna trials ({trial_counter['count']} of {trials})..."
            )

        # Preserve any callbacks the caller already passed.
        callbacks = list(kwargs.get("callbacks") or [])
        callbacks.append(_trial_callback)
        kwargs["callbacks"] = callbacks
        # Optuna's own progress bar would fight with the Rich one.
        kwargs["show_progress_bar"] = False
        return original_optimize(self, func, *args, **kwargs)

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.completed:.0f}/{task.total:.0f}"),
            console=console
        ) as progress:
            progress_task = progress.add_task("Running Optuna trials (0 of 0)...", total=trials)

            optuna_module.study.study.Study.optimize = _patched_optimize
            tuning_result = hyperparameter_tuning(
                file_path,
                target_col,
                model_type,
                task_type,
                trials,
                output_path=str(model_output_path)
            )

            progress.update(
                progress_task,
                completed=trials,
                description=f"Running Optuna trials ({trial_counter['count']} of {trials})..."
            )
    except KeyboardInterrupt:
        console.print("[yellow]⚠ Tuning stopped by user (Ctrl+C).[/yellow]")
        raise typer.Exit(1)
    except Exception as e:
        console.print(f"[red]✗ Hyperparameter tuning failed: {e}[/red]")
        raise typer.Exit(1)
    finally:
        # Always undo the monkey-patch, even on interrupt/failure.
        optuna_module.study.study.Study.optimize = original_optimize

    if tuning_result.get("status") != "success":
        console.print(f"[red]✗ Hyperparameter tuning failed: {tuning_result.get('message', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    best_params = tuning_result.get("best_params", {})
    best_score = tuning_result.get("best_cv_score")

    # NOTE(review): this assumes hyperparameter_tuning's best_cv_score uses
    # a metric comparable to the baseline computed above — confirm in
    # advanced_training before trusting the improvement figure.
    improvement_pct = None
    if baseline_score is not None and best_score is not None and baseline_score != 0:
        improvement_pct = ((float(best_score) - float(baseline_score)) / abs(float(baseline_score))) * 100.0

    # Step 4: Results table
    summary_table = Table(title="🎯 Tuning Results", show_header=False)
    summary_table.add_column("Metric", style="cyan", width=28)
    summary_table.add_column("Value", style="white")

    summary_table.add_row("Best Parameters", json.dumps(best_params, indent=2) if best_params else "-")
    summary_table.add_row("Best Score", "-" if best_score is None else f"{float(best_score):.6f}")
    summary_table.add_row("Baseline Score", "-" if baseline_score is None else f"{float(baseline_score):.6f}")
    summary_table.add_row(
        "Improvement Percentage",
        "-" if improvement_pct is None else f"{improvement_pct:+.2f}%"
    )

    console.print()
    console.print(summary_table)

    # Step 5: Saved model path
    saved_model_path = tuning_result.get("model_path") or str(model_output_path)
    console.print(f"\n💾 Tuned model saved to: [cyan]{saved_model_path}[/cyan]")
|
|
2584
|
+
|
|
2585
|
+
|
|
2586
|
+
@app.command()
def clean(
    file_path: str = typer.Argument(..., help="Path to dataset file"),
    output: str = typer.Option(None, "--output", "-o", help="Output file path"),
    strategy: str = typer.Option("auto", "--strategy", "-s", help="Cleaning strategy (auto/median/mean/mode/drop)")
) -> None:
    """
    Clean dataset (handle missing values and outliers).

    Profiles the dataset, finds every column with nulls, and applies the
    requested --strategy to each via clean_missing_values, writing the
    cleaned copy to --output (default: ./outputs/data/cleaned_<name>).

    Raises:
        typer.Exit: if the input file does not exist (exit code 1).

    Example:
        python cli.py clean data.csv --output cleaned_data.csv
    """
    from ds_agent.tools.data_cleaning import clean_missing_values
    from ds_agent.tools.data_profiling import profile_dataset

    # Fail fast on a missing input, consistent with the other commands.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if output is None:
        output = f"./outputs/data/cleaned_{Path(file_path).name}"
    # Make sure the destination directory exists before the tool writes to it,
    # matching how the pipeline/tune commands prepare their output paths.
    Path(output).parent.mkdir(parents=True, exist_ok=True)

    console.print(f"\n🧹 [bold]Cleaning:[/bold] {file_path}\n")

    # Get columns with missing values. BUG FIX: the --strategy option was
    # previously ignored ("auto" was hard-coded per column); it is now
    # honored. The default value "auto" preserves the old behavior.
    profile = profile_dataset(file_path)
    cols_with_nulls = {
        col: strategy
        for col, info in profile["columns"].items()
        if info["null_count"] > 0
    }

    if not cols_with_nulls:
        console.print("[green]✓ No missing values found - dataset is clean![/green]")
        return

    console.print(f"Found {len(cols_with_nulls)} columns with missing values")

    # Clean with a spinner; the tool reports original/final row counts.
    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress:
        task = progress.add_task("Cleaning dataset...", total=None)
        result = clean_missing_values(file_path, cols_with_nulls, output)
        progress.update(task, completed=True)

    console.print(f"\n[green]✓ Cleaned dataset saved to: {output}[/green]")
    console.print(f" Rows: {result['original_rows']} → {result['final_rows']}")
|
|
2628
|
+
|
|
2629
|
+
|
|
2630
|
+
@app.command()
def train(
    file_path: str = typer.Argument(..., help="Path to prepared dataset"),
    target: str = typer.Argument(..., help="Target column name"),
    task_type: str = typer.Option("auto", "--task-type", help="Task type (classification/regression/auto)")
) -> None:
    """
    Train baseline models on prepared dataset.

    Runs train_baseline_models, then prints a per-model performance table
    (accuracy/F1 for classification, R²/RMSE for regression) and the best
    model's name, score, and saved path.

    Raises:
        typer.Exit: if the input file is missing or training reports an
            error (exit code 1).

    Example:
        python cli.py train cleaned_data.csv Survived --task-type classification
    """
    from ds_agent.tools.model_training import train_baseline_models

    # Fail fast on a missing dataset, consistent with the other commands.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    console.print("\n🤖 [bold]Training Models[/bold]\n")
    console.print(f"📊 Dataset: {file_path}")
    console.print(f"🎯 Target: {target}\n")

    # Train with a spinner; the heavy lifting happens in the tools module.
    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress:
        task = progress.add_task("Training baseline models...", total=None)
        result = train_baseline_models(file_path, target, task_type)
        progress.update(task, completed=True)

    if "error" in result:
        console.print(f"[red]✗ Error: {result['message']}[/red]")
        raise typer.Exit(1)

    # Display results
    console.print("\n[green]✓ Training Complete![/green]\n")
    console.print(f"Task Type: {result['task_type']}")
    console.print(f"Features: {result['n_features']}")
    console.print(f"Samples: {result['n_samples']}\n")

    # Model comparison table
    table = Table(title="Model Performance")
    table.add_column("Model", style="cyan")

    # Add metric columns based on task type
    if result["task_type"] == "classification":
        table.add_column("Accuracy", justify="right")
        table.add_column("F1 Score", justify="right")
    else:
        table.add_column("R² Score", justify="right")
        table.add_column("RMSE", justify="right")

    # Models that failed to produce test_metrics are silently omitted
    # from the table (the tool may skip models on error).
    for model_name, model_result in result["models"].items():
        if "test_metrics" in model_result:
            metrics = model_result["test_metrics"]
            if result["task_type"] == "classification":
                table.add_row(
                    model_name,
                    f"{metrics['accuracy']:.4f}",
                    f"{metrics['f1']:.4f}"
                )
            else:
                table.add_row(
                    model_name,
                    f"{metrics['r2']:.4f}",
                    f"{metrics['rmse']:.4f}"
                )

    console.print(table)

    # Best model
    console.print(f"\n🏆 [bold]Best Model:[/bold] {result['best_model']['name']}")
    console.print(f"   Score: {result['best_model']['score']:.4f}")
    console.print(f"   Path: {result['best_model']['model_path']}")
|
|
2698
|
+
|
|
2699
|
+
|
|
2700
|
+
@app.command()
def cache_stats():
    """Show cache statistics."""
    # Imported lazily so the CLI starts fast when this command is unused.
    from ds_agent.orchestrator import DataScienceCopilot

    copilot = DataScienceCopilot()
    stats = copilot.get_cache_stats()

    # Two-column metric/value summary, matching the CLI's other tables.
    stats_table = Table(title="Cache Statistics")
    stats_table.add_column("Metric", style="cyan")
    stats_table.add_column("Value", style="white")

    for label, value in (
        ("Total Entries", str(stats["total_entries"])),
        ("Valid Entries", str(stats["valid_entries"])),
        ("Expired Entries", str(stats["expired_entries"])),
        ("Size", f"{stats['size_mb']} MB"),
    ):
        stats_table.add_row(label, value)

    console.print()
    console.print(stats_table)
|
|
2721
|
+
@app.command()
def clear_cache():
    """Clear all cached results."""
    # Imported lazily so the CLI starts fast when this command is unused.
    from ds_agent.orchestrator import DataScienceCopilot

    DataScienceCopilot().clear_cache()
    console.print("[green]✓ Cache cleared successfully[/green]")
|
2731
|
+
@sessions_app.command("list")
def sessions_list():
    """List saved sessions."""
    from ds_agent.session_store import SessionStore

    store = SessionStore()
    # Duck-typed lookup: support either naming convention on the store.
    list_fn = getattr(store, "list_sessions", None)
    load_fn = getattr(store, "load_session", None) or getattr(store, "load", None)

    if not callable(list_fn):
        console.print("[red]✗ Session store does not provide list_sessions().[/red]")
        raise typer.Exit(1)

    sessions = list_fn(limit=10000)
    if not sessions:
        console.print("[yellow]No saved sessions found.[/yellow]")
        return

    table = Table(title="Saved Sessions")
    for header, opts in (
        ("session_id", {"style": "cyan"}),
        ("created_at", {"style": "white"}),
        ("last_active", {"style": "white"}),
        ("last_dataset", {"style": "magenta"}),
        ("steps_completed", {"justify": "right", "style": "green"}),
    ):
        table.add_column(header, **opts)

    for entry in sessions:
        sid = entry.get("session_id", "-")
        dataset_name = "-"
        step_count = 0

        # Best-effort enrichment: load the full session for extra columns,
        # but never let a corrupt/unloadable session break the listing.
        if callable(load_fn):
            try:
                loaded = load_fn(sid)
                if loaded is not None:
                    dataset_name = getattr(loaded, "last_dataset", None) or "-"
                    step_count = len(getattr(loaded, "workflow_history", []) or [])
            except Exception:
                pass

        table.add_row(
            str(sid),
            str(entry.get("created_at", "-")),
            str(entry.get("last_active", "-")),
            str(dataset_name),
            str(step_count),
        )

    console.print()
    console.print(table)
|
2786
|
+
@sessions_app.command("resume")
def sessions_resume(
    session_id: str = typer.Argument(..., help="Session ID to resume")
):
    """Load a saved session and print its context."""
    from ds_agent.session_store import SessionStore

    store = SessionStore()
    # Duck-typed lookup: support either naming convention on the store.
    load_fn = getattr(store, "load_session", None) or getattr(store, "load", None)

    if not callable(load_fn):
        console.print("[red]✗ Session store does not provide load_session().[/red]")
        raise typer.Exit(1)

    session = load_fn(session_id)
    if session is None:
        console.print(f"[red]✗ Session not found: {session_id}[/red]")
        raise typer.Exit(1)

    # Prefer the session's own summary; fall back to a manual digest
    # of well-known attributes when it is missing or empty.
    summary_fn = getattr(session, "get_context_summary", None)
    context_text = summary_fn() if callable(summary_fn) else ""

    if not context_text:
        history = getattr(session, "workflow_history", []) or []
        context_text = "\n".join([
            f"Session ID: {getattr(session, 'session_id', session_id)}",
            f"Created: {getattr(session, 'created_at', '-')}",
            f"Last Active: {getattr(session, 'last_active', '-')}",
            f"Last Dataset: {getattr(session, 'last_dataset', '-')}",
            f"Last Target: {getattr(session, 'last_target_col', '-')}",
            f"Last Model: {getattr(session, 'last_model', '-')}",
            f"Steps Completed: {len(history)}",
        ])

    console.print(Panel(
        context_text,
        title=f"Session Context: {session_id}",
        border_style="green"
    ))
|
2829
|
+
@sessions_app.command("delete")
def sessions_delete(
    session_id: str = typer.Argument(..., help="Session ID to delete")
):
    """Delete one saved session from SQLite."""
    from ds_agent.session_store import SessionStore

    store = SessionStore()
    # Duck-typed lookup: support either naming convention on the store.
    delete_fn = getattr(store, "delete_session", None) or getattr(store, "delete", None)

    if not callable(delete_fn):
        console.print("[red]✗ Session store does not provide delete_session().[/red]")
        raise typer.Exit(1)

    if delete_fn(session_id):
        console.print(f"[green]✓ Deleted session: {session_id}[/green]")
    else:
        console.print(f"[yellow]Session not found: {session_id}[/yellow]")
|
2850
|
+
@sessions_app.command("clear")
def sessions_clear():
    """Delete all saved sessions."""
    from ds_agent.session_store import SessionStore

    store = SessionStore()
    list_fn = getattr(store, "list_sessions", None)
    delete_fn = getattr(store, "delete_session", None) or getattr(store, "delete", None)

    # Both capabilities are required; report whichever is missing first.
    if not callable(list_fn):
        console.print("[red]✗ Session store does not provide list_sessions().[/red]")
        raise typer.Exit(1)
    if not callable(delete_fn):
        console.print("[red]✗ Session store does not provide delete_session().[/red]")
        raise typer.Exit(1)

    sessions = list_fn(limit=10000)
    if not sessions:
        console.print("[yellow]No sessions to clear.[/yellow]")
        return

    deleted_count = 0
    for entry in sessions:
        sid = entry.get("session_id")
        if not sid:
            # Skip malformed listing entries without a session_id.
            continue
        if delete_fn(sid):
            deleted_count += 1

    console.print(f"[green]✓ Cleared {deleted_count} session(s).[/green]")
|
2880
|
+
def main():
    """Console entry point for ds-agent."""
    # Delegate argument parsing and dispatch to the Typer application.
    app()
|
2885
|
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|