ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
package/src/cli.py ADDED
@@ -0,0 +1,2886 @@
1
+ """
2
+ Command Line Interface for Data Science Copilot
3
+ """
4
+
5
+ import typer
6
+ from rich.console import Console
7
+ from rich.table import Table
8
+ from rich.panel import Panel
9
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
10
+ from rich.live import Live
11
+ from rich import print as rprint
12
+ from pathlib import Path
13
+ import json
14
+ import sys
15
+ import os
16
+ import io
17
+ import contextlib
18
+ import importlib.util
19
+ from typing import Optional
20
+
21
+ # Add src to path
22
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
23
+
24
# Root Typer application for the `ds-agent` executable. Completion is disabled
# and `no_args_is_help` is off because the root callback renders a custom
# home screen when no subcommand is given (see root_callback below).
app = typer.Typer(
    name="ds-agent",
    help="AI-powered data science CLI for profiling, modeling, forecasting, and chat workflows.",
    add_completion=False,
    no_args_is_help=False,
    rich_markup_mode="rich",
    # Epilog rendered at the bottom of `ds-agent --help` (Rich markup enabled above).
    epilog=(
        "Quick start: run [bold]ds-agent quickstart[/bold]\n"
        "Examples: [bold]ds-agent analyze data.csv --target price[/bold] | "
        "[bold]ds-agent chat data.csv[/bold]"
    )
)

# Sub-application grouping session-management commands under `ds-agent sessions ...`.
sessions_app = typer.Typer(help="Manage saved chat/analysis sessions")
app.add_typer(sessions_app, name="sessions")

# Shared Rich console used by all commands for styled terminal output.
console = Console()
41
+
42
+
43
def _print_cli_home():
    """Render a first-run friendly home screen for users running the CLI without arguments."""
    banner = r"""[bold cyan]
 ____ ____ _ ____ _____ _ _ _____
| _ \/ ___| / \ / ___| ____| \ | |_ _|
| | | \___ \ / _ \| | _| _| | \| | | |
| |_| |___) | / ___ \ |_| | |___| |\ | | |
|____/|____/ /_/ \_\____|_____|_| \_| |_|
 [/bold cyan]"""
    console.print(banner)
    console.print("[dim]The open data science agent CLI[/dim]\n")

    # Command overview: (command, description) pairs rendered as a borderless table.
    overview = [
        ("ds-agent quickstart", "Show examples for all major workflows"),
        ("ds-agent analyze <file>", "Run end-to-end automated analysis"),
        ("ds-agent chat [file]", "Start interactive session with memory"),
        ("ds-agent report <file>", "Generate full HTML profiling report"),
        ("ds-agent train <file> --target <col>", "Train baseline ML models"),
        ("ds-agent tune <file> --target <col>", "Hyperparameter optimization"),
        ("ds-agent plot <file>", "Generate charts and visualizations"),
        ("ds-agent forecast <file> --time <col> --target <col>", "Time-series forecasting"),
        ("ds-agent sessions list", "List and manage saved sessions"),
        ("ds-agent --help", "Full command reference"),
    ]

    table = Table(show_header=False, box=None, pad_edge=False)
    table.add_column("Command", style="bold white", width=52, no_wrap=True)
    table.add_column("What it does", style="white")
    for command, description in overview:
        table.add_row(command, description)

    console.print(table)
    console.print("\n[dim]Tip: run [bold]ds-agent quickstart[/bold] for copy-paste commands.[/dim]")
73
+
74
@app.callback(invoke_without_command=True)
def root_callback(
    ctx: typer.Context,
    version: bool = typer.Option(False, "--version", "-v", help="Show CLI version and exit.")
):
    """Root CLI callback for startup UX and global flags."""
    if version:
        # Fall back to a hard-coded version when the package metadata is unavailable.
        try:
            from ds_agent import __version__ as cli_version
        except Exception:
            cli_version = "0.1.0"
        console.print(f"ds-agent v{cli_version}")
        raise typer.Exit(0)

    # No subcommand: show the friendly home screen instead of Typer's default error.
    if ctx.invoked_subcommand is None:
        _print_cli_home()
        raise typer.Exit(0)
91
+
92
+
93
@app.command()
def quickstart():
    """Show beginner-friendly command guide with copy-paste examples."""
    # Workflow -> example-command pairs, rendered as a two-column table.
    examples = [
        ("Analyze", 'ds-agent analyze data.csv --target price --task "predict price"'),
        ("Chat", "ds-agent chat data.csv"),
        ("Profile Report", "ds-agent report data.csv --engine ydata --output ./outputs/reports"),
        ("Model Training", "ds-agent train data.csv --target label --task-type classification"),
        ("Hyperparameter Tuning", "ds-agent tune data.csv --target label --model xgboost --trials 50"),
        ("Visualizations", "ds-agent plot data.csv --type heatmap"),
        ("Compare Datasets", "ds-agent compare train.csv test.csv --label1 train --label2 test"),
        ("Forecast", "ds-agent forecast sales.csv --time date --target revenue --horizon 30"),
        ("NLP", "ds-agent nlp reviews.csv --text review --task sentiment"),
        ("Session Management", "ds-agent sessions list | ds-agent sessions resume <session_id>"),
    ]

    guide = Table(title="⚡ Quickstart Guide")
    guide.add_column("Workflow", style="cyan", width=24, no_wrap=True)
    guide.add_column("Command", style="white", no_wrap=True)
    for workflow, command in examples:
        guide.add_row(workflow, command)

    console.print()
    console.print(guide)
    console.print("\n[dim]Use [bold]ds-agent <command> --help[/bold] for detailed options and examples.[/dim]")
114
+
115
+
116
+ def _load_module_from_src(module_name: str, relative_path: str):
117
+ """Load a module directly from src to avoid triggering package-wide __init__ side effects."""
118
+ module_path = Path(__file__).resolve().parent / relative_path
119
+ spec = importlib.util.spec_from_file_location(module_name, str(module_path))
120
+ if spec is None or spec.loader is None:
121
+ raise ImportError(f"Unable to load module at {module_path}")
122
+ module = importlib.util.module_from_spec(spec)
123
+ spec.loader.exec_module(module)
124
+ return module
125
+
126
+
127
@app.command()
def analyze(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    task: str = typer.Option(
        "Complete data science workflow: profile, clean, engineer features, and train models",
        "--task", "-t",
        help="Description of the analysis task"
    ),
    target: str = typer.Option(None, "--target", "-y", help="Target column name for prediction"),
    output: str = typer.Option("./outputs", "--output", "-o", help="Output directory"),
    no_cache: bool = typer.Option(False, "--no-cache", help="Disable caching"),
    reasoning: str = typer.Option("medium", "--reasoning", "-r", help="Reasoning effort (low/medium/high)")
):
    """
    Analyze a dataset and perform complete data science workflow.

    Example:
        python cli.py analyze data.csv --target Survived --task "Predict survival"
    """
    console.print(Panel.fit(
        "🤖 Data Science Copilot - AI-Powered Analysis",
        style="bold blue"
    ))

    # Validate file exists before doing any (slow) copilot initialization.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    # Initialize copilot. Import deferred so `--help` stays fast and so a broken
    # orchestrator import produces a clean error instead of crashing the CLI.
    try:
        from ds_agent.orchestrator import DataScienceCopilot

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_init = progress.add_task("Initializing Data Science Copilot...", total=None)
            copilot = DataScienceCopilot(reasoning_effort=reasoning)
            progress.update(task_init, completed=True)

    except Exception as e:
        console.print(f"[red]✗ Error initializing copilot: {e}[/red]")
        console.print("[yellow]Make sure GROQ_API_KEY is set in .env file[/yellow]")
        raise typer.Exit(1)

    # Echo run configuration before starting.
    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    console.print(f"🎯 [bold]Task:[/bold] {task}")
    if target:
        console.print(f"🎲 [bold]Target:[/bold] {target}")
    console.print()

    # Run the analysis workflow.
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_analyze = progress.add_task("Running analysis workflow...", total=None)

            result = copilot.analyze(
                file_path=file_path,
                task_description=task,
                target_col=target,
                use_cache=not no_cache
            )

            progress.update(task_analyze, completed=True)

    except Exception as e:
        console.print(f"\n[red]✗ Analysis failed: {e}[/red]")
        raise typer.Exit(1)

    # Display results
    if result["status"] == "success":
        console.print("\n[green]✓ Analysis Complete![/green]\n")

        # Summary
        console.print(Panel(
            result["summary"],
            title="📋 Analysis Summary",
            border_style="green"
        ))

        # Workflow history: one line per executed tool with a pass/fail marker.
        console.print("\n[bold]🔧 Tools Executed:[/bold]")
        for step in result["workflow_history"]:
            tool_name = step["tool"]
            success = step["result"].get("success", False)
            icon = "✓" if success else "✗"
            color = "green" if success else "red"
            console.print(f" [{color}]{icon}[/{color}] {tool_name}")

        # Stats
        stats_table = Table(title="📊 Execution Statistics", show_header=False)
        stats_table.add_column("Metric", style="cyan")
        stats_table.add_column("Value", style="white")

        stats_table.add_row("Iterations", str(result["iterations"]))
        stats_table.add_row("API Calls", str(result["api_calls"]))
        stats_table.add_row("Execution Time", f"{result['execution_time']}s")

        console.print()
        console.print(stats_table)

        # Save full report. default=str guards against non-JSON-serializable
        # values (e.g. datetimes) in the orchestrator result.
        report_path = Path(output) / "reports" / f"analysis_{Path(file_path).stem}.json"
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, "w") as f:
            json.dump(result, f, indent=2, default=str)

        console.print(f"\n💾 Full report saved to: [cyan]{report_path}[/cyan]")

    elif result["status"] == "error":
        # Use .get() (consistent with the `report` command) so a malformed
        # error result cannot raise a KeyError while reporting the failure.
        console.print(f"\n[red]✗ Error: {result.get('error', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    else:
        console.print(f"\n[yellow]⚠ Analysis incomplete: {result.get('message')}[/yellow]")
248
+
249
+
250
@app.command()
def profile(
    file_path: str = typer.Argument(..., help="Path to dataset file")
):
    """
    Quick profile of a dataset (basic statistics and quality checks).

    Example:
        python cli.py profile data.csv
    """
    from ds_agent.tools.data_profiling import profile_dataset, detect_data_quality_issues

    console.print(f"\n📊 [bold]Profiling:[/bold] {file_path}\n")

    # Profile. Local is named profile_result (not `profile`) so it does not
    # shadow this command function; Progress gets the shared console for
    # consistency with the other commands in this file.
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task1 = progress.add_task("Analyzing dataset...", total=None)
        profile_result = profile_dataset(file_path)
        progress.update(task1, completed=True)

    # Display basic info
    info_table = Table(title="Dataset Information", show_header=False)
    info_table.add_column("Property", style="cyan")
    info_table.add_column("Value", style="white")

    info_table.add_row("Rows", str(profile_result["shape"]["rows"]))
    info_table.add_row("Columns", str(profile_result["shape"]["columns"]))
    info_table.add_row("Memory", f"{profile_result['memory_usage']['total_mb']} MB")
    info_table.add_row("Null %", f"{profile_result['overall_stats']['null_percentage']}%")
    info_table.add_row("Duplicates", str(profile_result['overall_stats']['duplicate_rows']))

    console.print()
    console.print(info_table)

    # Column types
    console.print("\n[bold]Column Types:[/bold]")
    console.print(f" Numeric: {len(profile_result['column_types']['numeric'])}")
    console.print(f" Categorical: {len(profile_result['column_types']['categorical'])}")
    console.print(f" Datetime: {len(profile_result['column_types']['datetime'])}")

    # Detect issues
    console.print("\n[bold]Quality Check:[/bold]")
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task2 = progress.add_task("Detecting quality issues...", total=None)
        issues = detect_data_quality_issues(file_path)
        progress.update(task2, completed=True)

    console.print(f" 🔴 Critical: {issues['summary']['critical_count']}")
    console.print(f" 🟡 Warnings: {issues['summary']['warning_count']}")
    console.print(f" 🔵 Info: {issues['summary']['info_count']}")
300
+
301
+
302
@app.command()
def eda(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    target: str = typer.Option(None, "--target", "-y", help="Target column name for supervised EDA"),
    output: str = typer.Option("./outputs/eda/", "--output", "-o", help="Output directory for EDA artifacts"),
    full: bool = typer.Option(False, "--full", help="Run deeper EDA analysis")
):
    """
    Run exploratory data analysis with quality checks and optional root-cause insights.

    Example:
        ds-agent eda sales.csv
        ds-agent eda sales.csv --target revenue --full
    """
    from ds_agent.tools.advanced_analysis import perform_eda_analysis
    from ds_agent.tools.data_profiling import detect_data_quality_issues
    from ds_agent.tools.advanced_insights import analyze_root_cause

    console.print(Panel.fit(
        "📈 Exploratory Data Analysis",
        style="bold blue"
    ))

    # Validate file exists
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Join with pathlib (instead of manual rstrip on the string form) so the
    # report path is correct regardless of trailing separators in --output.
    output_html = str(output_dir / "eda.html")

    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    if target:
        console.print(f"🎯 [bold]Target:[/bold] {target}")
    console.print(f"🧪 [bold]Depth:[/bold] {'Full' if full else 'Standard'}")
    console.print()

    try:
        # Step 1: Perform EDA analysis
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_eda = progress.add_task("Running comprehensive EDA...", total=None)
            eda_result = perform_eda_analysis(
                file_path,
                target_col=target,
                output_html=output_html
            )
            progress.update(task_eda, completed=True)

        # Step 2: Detect quality issues
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_quality = progress.add_task("Detecting data quality issues...", total=None)
            quality_result = detect_data_quality_issues(file_path)
            progress.update(task_quality, completed=True)

        # Step 3: Root cause analysis (only when target is provided);
        # --full lowers the drop threshold for a deeper scan.
        root_cause_result = None
        if target:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console
            ) as progress:
                task_root = progress.add_task("Analyzing root-cause factors...", total=None)
                if full:
                    root_cause_result = analyze_root_cause(
                        file_path,
                        target_col=target,
                        threshold_drop=0.10
                    )
                else:
                    root_cause_result = analyze_root_cause(file_path, target_col=target)
                progress.update(task_root, completed=True)

    except Exception as e:
        console.print(f"\n[red]✗ EDA failed: {e}[/red]")
        raise typer.Exit(1)

    # Prepare summary fields
    dataset_shape = eda_result.get("dataset_shape", {})
    column_types = eda_result.get("column_types", {})

    # Issues ranked by severity: critical first, then warnings, then info.
    ranked_issues = (
        quality_result.get("critical", [])
        + quality_result.get("warning", [])
        + quality_result.get("info", [])
    )
    top_issues = ranked_issues[:3]

    top_issue_lines = []
    for issue in top_issues:
        msg = issue.get("message", "Issue detected")
        top_issue_lines.append(f"- {msg}")
    if not top_issue_lines:
        top_issue_lines.append("- No major quality issues detected")

    top_corr_lines = []
    if target and isinstance(root_cause_result, dict):
        corr_map = root_cause_result.get("correlations", {})
        if isinstance(corr_map, dict) and corr_map:
            # Keep only numeric correlation values so sorting by abs() and the
            # :.3f formatting below cannot crash on None/str entries.
            numeric_corrs = [
                (feature, corr) for feature, corr in corr_map.items()
                if isinstance(corr, (int, float))
            ]
            top_corrs = sorted(numeric_corrs, key=lambda item: abs(item[1]), reverse=True)[:3]
            for feature, corr in top_corrs:
                top_corr_lines.append(f"- {feature}: {corr:.3f}")
    if target and not top_corr_lines:
        top_corr_lines.append("- No strong target correlations detected")

    # Step 4: Rich summary table
    summary_table = Table(title="📊 EDA Summary", show_header=False)
    summary_table.add_column("Metric", style="cyan", width=28)
    summary_table.add_column("Value", style="white")

    summary_table.add_row("Rows", str(dataset_shape.get("rows", "-")))
    summary_table.add_row("Columns", str(dataset_shape.get("columns", "-")))
    summary_table.add_row("Numeric Columns", str(column_types.get("numeric", "-")))
    summary_table.add_row("Categorical Columns", str(column_types.get("categorical", "-")))
    summary_table.add_row("Top 3 Quality Issues", "\n".join(top_issue_lines))
    if target:
        summary_table.add_row("Top 3 Correlations with Target", "\n".join(top_corr_lines))

    console.print()
    console.print(summary_table)

    # Step 5: Report path
    console.print(f"\n💾 HTML report saved to: [cyan]{output_html}[/cyan]")

    quality_summary = quality_result.get("summary", {})
    final_lines = [
        f"Rows: {dataset_shape.get('rows', '-')}",
        f"Columns: {dataset_shape.get('columns', '-')}",
        f"Issues found: {quality_summary.get('total_issues', 0)}",
        f"Mode: {'Full' if full else 'Standard'}"
    ]
    if target:
        final_lines.append(f"Target analyzed: {target}")

    console.print(Panel(
        "\n".join(final_lines),
        title="✅ EDA Complete",
        border_style="green"
    ))
450
+
451
+
452
@app.command()
def report(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    engine: str = typer.Option("ydata", "--engine", help="Report engine: ydata or sweetviz"),
    target: str = typer.Option(None, "--target", "-y", help="Optional target column for supervised analysis"),
    compare: str = typer.Option(None, "--compare", help="Optional second dataset path for Sweetviz comparison"),
    output: str = typer.Option("./outputs/reports/", "--output", "-o", help="Output directory for HTML reports"),
    minimal: bool = typer.Option(False, "--minimal", help="Use minimal mode for large files")
):
    """
    Generate a full HTML profiling report using ydata-profiling or Sweetviz.

    Example:
        ds-agent report sales.csv
        ds-agent report sales.csv --engine sweetviz --target revenue
        ds-agent report sales.csv --compare test.csv --engine sweetviz
    """
    # Load the report generators straight from src/ so package __init__
    # side effects are avoided (see _load_module_from_src).
    eda_reports_module = _load_module_from_src("ds_agent_cli_eda_reports", "tools/eda_reports.py")
    generate_ydata_profiling_report = eda_reports_module.generate_ydata_profiling_report
    generate_sweetviz_report = eda_reports_module.generate_sweetviz_report
    import webbrowser

    console.print(Panel.fit(
        "📄 HTML Profiling Report",
        style="bold blue"
    ))

    # Guard clauses: validate inputs before any expensive work.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    selected_engine = engine.strip().lower()
    if selected_engine not in ("ydata", "sweetviz"):
        console.print("[red]✗ Error: --engine must be either 'ydata' or 'sweetviz'[/red]")
        raise typer.Exit(1)

    if compare and not Path(compare).exists():
        console.print(f"[red]✗ Error: Compare file not found: {compare}[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    if selected_engine == "ydata":
        report_filename = "ydata_profile.html"
    else:
        report_filename = "sweetviz_report.html"
    output_path = str(output_dir / report_filename)
    report_title = f"{Path(file_path).stem} - Profiling Report"

    # Echo run configuration.
    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    console.print(f"⚙️ [bold]Engine:[/bold] {selected_engine}")
    if target:
        console.print(f"🎯 [bold]Target:[/bold] {target}")
    if compare:
        console.print(f"🔁 [bold]Compare:[/bold] {compare}")
    console.print("⏳ [yellow]Generating report may take 1-2 minutes...[/yellow]\n")

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_report = progress.add_task("Generating HTML report...", total=None)

            if selected_engine == "ydata":
                result = generate_ydata_profiling_report(
                    file_path,
                    output_path,
                    minimal,
                    report_title
                )
            else:
                result = generate_sweetviz_report(
                    file_path,
                    target,
                    compare,
                    output_path
                )

            progress.update(task_report, completed=True)

    except Exception as e:
        console.print(f"[red]✗ Report generation failed: {e}[/red]")
        raise typer.Exit(1)

    if not result.get("success", False):
        console.print(f"[red]✗ Report generation failed: {result.get('error', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    report_path = result.get("report_path", output_path)

    console.print(Panel(
        f"✅ Report generated successfully!\n\nPath: {report_path}",
        title="Report Ready",
        border_style="green"
    ))

    # Best effort: opening a browser may fail in headless environments.
    with contextlib.suppress(Exception):
        webbrowser.open(Path(report_path).resolve().as_uri())
552
+
553
+
554
+ @app.command()
555
+ def compare(
556
+ file1: str = typer.Argument(..., help="First dataset path (CSV or Parquet)"),
557
+ file2: str = typer.Argument(..., help="Second dataset path (CSV or Parquet)"),
558
+ label1: str = typer.Option("Dataset A", "--label1", help="Display label for first dataset"),
559
+ label2: str = typer.Option("Dataset B", "--label2", help="Display label for second dataset"),
560
+ output: str = typer.Option("./outputs/compare/", "--output", "-o", help="Output directory for comparison reports")
561
+ ):
562
+ """
563
+ Compare two datasets side-by-side for schema and distribution differences.
564
+
565
+ Example:
566
+ ds-agent compare train.csv test.csv
567
+ ds-agent compare raw.csv cleaned.csv --label1 "Before" --label2 "After"
568
+ """
569
+ import math
570
+ from datetime import datetime
571
+ import polars as pl
572
+ from ds_agent.utils.polars_helpers import load_dataframe
573
+ eda_reports_module = _load_module_from_src("ds_agent_cli_eda_reports_compare", "tools/eda_reports.py")
574
+ generate_sweetviz_report = eda_reports_module.generate_sweetviz_report
575
+
576
+ if not Path(file1).exists():
577
+ console.print(f"[red]✗ Error: File not found: {file1}[/red]")
578
+ raise typer.Exit(1)
579
+ if not Path(file2).exists():
580
+ console.print(f"[red]✗ Error: File not found: {file2}[/red]")
581
+ raise typer.Exit(1)
582
+
583
+ output_dir = Path(output)
584
+ output_dir.mkdir(parents=True, exist_ok=True)
585
+
586
+ console.print(Panel.fit(
587
+ f"🔍 Dataset Compare\n{label1} vs {label2}",
588
+ style="bold blue"
589
+ ))
590
+
591
+ def _fmt_num(value, digits: int = 4) -> str:
592
+ if value is None:
593
+ return "-"
594
+ try:
595
+ value = float(value)
596
+ except Exception:
597
+ return "-"
598
+ if math.isnan(value) or math.isinf(value):
599
+ return "-"
600
+ return f"{value:.{digits}f}"
601
+
602
+ def _col_stats(df: pl.DataFrame, col_name: str) -> dict:
603
+ series = df[col_name]
604
+ row_count = len(df)
605
+ null_count = series.null_count()
606
+ null_pct = (null_count / row_count * 100.0) if row_count > 0 else 0.0
607
+
608
+ is_numeric = series.dtype in pl.NUMERIC_DTYPES
609
+ mean_val = series.mean() if is_numeric else None
610
+ median_val = series.median() if is_numeric else None
611
+ std_val = series.std() if is_numeric else None
612
+
613
+ return {
614
+ "dtype": str(series.dtype),
615
+ "is_numeric": is_numeric,
616
+ "mean": float(mean_val) if mean_val is not None else None,
617
+ "median": float(median_val) if median_val is not None else None,
618
+ "std": float(std_val) if std_val is not None else None,
619
+ "null_pct": round(null_pct, 2),
620
+ "unique": int(series.n_unique())
621
+ }
622
+
623
+ with Progress(
624
+ SpinnerColumn(),
625
+ TextColumn("[progress.description]{task.description}"),
626
+ console=console
627
+ ) as progress:
628
+ task = progress.add_task("Loading datasets with Polars...", total=None)
629
+ try:
630
+ df_a = load_dataframe(file1)
631
+ df_b = load_dataframe(file2)
632
+ except Exception as e:
633
+ console.print(f"[red]✗ Failed to load datasets: {e}[/red]")
634
+ raise typer.Exit(1)
635
+ progress.update(task, completed=True)
636
+
637
+ cols_a = set(df_a.columns)
638
+ cols_b = set(df_b.columns)
639
+
640
+ shared_cols = sorted(cols_a.intersection(cols_b))
641
+ only_a = sorted(cols_a - cols_b)
642
+ only_b = sorted(cols_b - cols_a)
643
+
644
+ comparison_rows = []
645
+ numeric_shift_rows = []
646
+
647
+ with Progress(
648
+ SpinnerColumn(),
649
+ TextColumn("[progress.description]{task.description}"),
650
+ console=console
651
+ ) as progress:
652
+ task = progress.add_task("Computing side-by-side comparison...", total=None)
653
+ for col in shared_cols:
654
+ stats_a = _col_stats(df_a, col)
655
+ stats_b = _col_stats(df_b, col)
656
+
657
+ mean_diff = None
658
+ std_diff = None
659
+ if stats_a["is_numeric"] and stats_b["is_numeric"]:
660
+ if stats_a["mean"] is not None and stats_b["mean"] is not None:
661
+ mean_diff = stats_a["mean"] - stats_b["mean"]
662
+ if stats_a["std"] is not None and stats_b["std"] is not None:
663
+ std_diff = stats_a["std"] - stats_b["std"]
664
+
665
+ comparison_rows.append({
666
+ "column": col,
667
+ "a": stats_a,
668
+ "b": stats_b,
669
+ "mean_diff": mean_diff,
670
+ "std_diff": std_diff
671
+ })
672
+
673
+ if mean_diff is not None or std_diff is not None:
674
+ numeric_shift_rows.append({
675
+ "column": col,
676
+ "mean_diff": mean_diff,
677
+ "std_diff": std_diff
678
+ })
679
+ progress.update(task, completed=True)
680
+
681
+ stats_panel = Panel(
682
+ "\n".join([
683
+ f"Shared columns: {len(shared_cols)}",
684
+ f"Columns only in {label1}: {len(only_a)}",
685
+ f"Columns only in {label2}: {len(only_b)}",
686
+ f"Rows in {label1}: {len(df_a)}",
687
+ f"Rows in {label2}: {len(df_b)}"
688
+ ]),
689
+ title="Overall Stats",
690
+ border_style="cyan"
691
+ )
692
+ console.print()
693
+ console.print(stats_panel)
694
+
695
+ comparison_table = Table(title=f"📋 Shared Column Comparison ({label1} vs {label2})")
696
+ comparison_table.add_column("Column", style="bold")
697
+ comparison_table.add_column(f"Type ({label1})", style="cyan")
698
+ comparison_table.add_column(f"Type ({label2})", style="cyan")
699
+ comparison_table.add_column(f"Mean ({label1})", justify="right")
700
+ comparison_table.add_column(f"Mean ({label2})", justify="right")
701
+ comparison_table.add_column(f"Median ({label1})", justify="right")
702
+ comparison_table.add_column(f"Median ({label2})", justify="right")
703
+ comparison_table.add_column(f"Null % ({label1})", justify="right")
704
+ comparison_table.add_column(f"Null % ({label2})", justify="right")
705
+ comparison_table.add_column(f"Unique ({label1})", justify="right")
706
+ comparison_table.add_column(f"Unique ({label2})", justify="right")
707
+ comparison_table.add_column("Mean Diff", justify="right")
708
+ comparison_table.add_column("Std Diff", justify="right")
709
+
710
+ for row in comparison_rows:
711
+ a_stats = row["a"]
712
+ b_stats = row["b"]
713
+ comparison_table.add_row(
714
+ row["column"],
715
+ a_stats["dtype"],
716
+ b_stats["dtype"],
717
+ _fmt_num(a_stats["mean"]),
718
+ _fmt_num(b_stats["mean"]),
719
+ _fmt_num(a_stats["median"]),
720
+ _fmt_num(b_stats["median"]),
721
+ _fmt_num(a_stats["null_pct"], 2),
722
+ _fmt_num(b_stats["null_pct"], 2),
723
+ str(a_stats["unique"]),
724
+ str(b_stats["unique"]),
725
+ _fmt_num(row["mean_diff"]),
726
+ _fmt_num(row["std_diff"])
727
+ )
728
+
729
+ console.print()
730
+ console.print(comparison_table)
731
+
732
+ if only_a or only_b:
733
+ missing_table = Table(title="🚨 Columns Not Shared", border_style="red")
734
+ missing_table.add_column("Column", style="red")
735
+ missing_table.add_column("Present In", style="red")
736
+ missing_table.add_column("Missing In", style="red")
737
+
738
+ for col in only_a:
739
+ missing_table.add_row(col, label1, label2, style="red")
740
+ for col in only_b:
741
+ missing_table.add_row(col, label2, label1, style="red")
742
+
743
+ console.print()
744
+ console.print(missing_table)
745
+
746
+ if numeric_shift_rows:
747
+ shift_table = Table(title="📈 Numeric Distribution Shift")
748
+ shift_table.add_column("Column", style="bold")
749
+ shift_table.add_column("Mean Diff", justify="right", style="yellow")
750
+ shift_table.add_column("Std Diff", justify="right", style="yellow")
751
+
752
+ for item in numeric_shift_rows:
753
+ shift_table.add_row(
754
+ item["column"],
755
+ _fmt_num(item["mean_diff"]),
756
+ _fmt_num(item["std_diff"])
757
+ )
758
+
759
+ console.print()
760
+ console.print(shift_table)
761
+
762
+ report_data = {
763
+ "generated_at": datetime.utcnow().isoformat() + "Z",
764
+ "dataset_a": {
765
+ "label": label1,
766
+ "path": file1,
767
+ "rows": len(df_a),
768
+ "columns": len(df_a.columns)
769
+ },
770
+ "dataset_b": {
771
+ "label": label2,
772
+ "path": file2,
773
+ "rows": len(df_b),
774
+ "columns": len(df_b.columns)
775
+ },
776
+ "summary": {
777
+ "shared_columns_count": len(shared_cols),
778
+ "columns_only_in_a_count": len(only_a),
779
+ "columns_only_in_b_count": len(only_b),
780
+ "rows_a": len(df_a),
781
+ "rows_b": len(df_b)
782
+ },
783
+ "columns_only_in_a": only_a,
784
+ "columns_only_in_b": only_b,
785
+ "shared_columns": comparison_rows,
786
+ "numeric_distribution_shift": numeric_shift_rows
787
+ }
788
+
789
+ report_path = output_dir / f"compare_{Path(file1).stem}_vs_{Path(file2).stem}.json"
790
+ with open(report_path, "w") as f:
791
+ json.dump(report_data, f, indent=2)
792
+
793
+ console.print(f"\n💾 Comparison JSON report saved to: [cyan]{report_path}[/cyan]")
794
+
795
+ if output:
796
+ sweetviz_path = output_dir / "sweetviz_compare_report.html"
797
+ with Progress(
798
+ SpinnerColumn(),
799
+ TextColumn("[progress.description]{task.description}"),
800
+ console=console
801
+ ) as progress:
802
+ task = progress.add_task("Generating Sweetviz compare report...", total=None)
803
+ sweetviz_result = generate_sweetviz_report(
804
+ file_path=file1,
805
+ target_col=None,
806
+ compare_file_path=file2,
807
+ output_path=str(sweetviz_path),
808
+ title=f"{label1} vs {label2} - Sweetviz Comparison"
809
+ )
810
+ progress.update(task, completed=True)
811
+
812
+ if sweetviz_result.get("success", False):
813
+ console.print(f"🌐 Sweetviz compare report saved to: [cyan]{sweetviz_result.get('report_path', sweetviz_path)}[/cyan]")
814
+ else:
815
+ console.print(f"[yellow]⚠ Sweetviz compare report generation skipped: {sweetviz_result.get('error', 'Unknown error')}[/yellow]")
816
+
817
+
818
@app.command()
def bi(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    customer_col: str = typer.Option(None, "--customer", "-c", help="Customer ID column name"),
    date_col: str = typer.Option(None, "--date", "-d", help="Date column name"),
    value_col: str = typer.Option(None, "--value", "-v", help="Revenue/value column name"),
    target_col: str = typer.Option(None, "--target", "-y", help="Optional target column for churn analysis"),
    analysis: str = typer.Option("all", "--analysis", help="Analysis type: all, cohort, rfm, kpi, churn"),
    output: str = typer.Option("./outputs/bi/", "--output", "-o", help="Output directory for BI results")
):
    """
    Run business intelligence analytics on a dataset.

    Depending on --analysis, computes a KPI summary, monthly cohort retention,
    RFM segmentation and/or a churn baseline, renders each as a Rich table,
    saves per-analysis and combined JSON reports under the output directory,
    and finishes with a "Top 3 Business Insights" panel.

    Example:
        ds-agent bi transactions.csv --customer customer_id --date order_date --value amount
        ds-agent bi customers.csv --customer id --date signup_date --analysis cohort
    """
    from datetime import datetime
    import math
    import pandas as pd
    from ds_agent.utils.polars_helpers import load_dataframe
    import ds_agent.tools.business_intelligence as bi_tools

    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    analysis = analysis.strip().lower()
    valid_analyses = {"all", "cohort", "rfm", "kpi", "churn"}
    if analysis not in valid_analyses:
        console.print("[red]✗ Error: --analysis must be one of all, cohort, rfm, kpi, churn[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task = progress.add_task("Loading dataset...", total=None)
        try:
            df = load_dataframe(file_path)
        except Exception as e:
            console.print(f"[red]✗ Failed to load dataset: {e}[/red]")
            raise typer.Exit(1)
        progress.update(task, completed=True)

    def _format_metric(value) -> str:
        """Render one metric for table display; '-' for missing/NaN/inf."""
        if value is None:
            return "-"
        # bool is checked before int because bool subclasses int.
        if isinstance(value, bool):
            return "Yes" if value else "No"
        if isinstance(value, int):
            return f"{value:,}"
        if isinstance(value, float):
            if math.isnan(value) or math.isinf(value):
                return "-"
            if abs(value) >= 1000:
                return f"{value:,.2f}"
            return f"{value:.4f}"
        return str(value)

    def _trend_arrow(trend_value) -> str:
        """Map a trend (number or descriptive string) to ↑ / ↓ / →."""
        if trend_value is None:
            return "→"
        if isinstance(trend_value, str):
            text = trend_value.strip().lower()
            if any(token in text for token in ["up", "increase", "growing", "higher", "+"]):
                return "↑"
            if any(token in text for token in ["down", "decrease", "decline", "lower", "-"]):
                return "↓"
            return "→"
        try:
            numeric = float(trend_value)
        except Exception:
            return "→"
        if numeric > 0:
            return "↑"
        if numeric < 0:
            return "↓"
        return "→"

    validation_errors = []
    required_columns = set()

    def _require_option(value: Optional[str], option_name: str, analysis_name: str):
        """Record a missing-option error, or mark the column as required."""
        if not value:
            validation_errors.append(f"{option_name} is required for '{analysis_name}' analysis")
        else:
            required_columns.add(value)

    if analysis in {"kpi", "all"}:
        _require_option(date_col, "--date/-d", "kpi")
        _require_option(value_col, "--value/-v", "kpi")

    if analysis in {"cohort", "all"}:
        _require_option(customer_col, "--customer/-c", "cohort")
        _require_option(date_col, "--date/-d", "cohort")

    if analysis in {"rfm", "all"}:
        _require_option(customer_col, "--customer/-c", "rfm")
        _require_option(date_col, "--date/-d", "rfm")
        _require_option(value_col, "--value/-v", "rfm")

    # Fix: --target is only mandatory for an explicit churn run. For
    # --analysis all the churn section is simply skipped when no target is
    # given (matching the "and target_col" guard below and the option's
    # "Optional target column" help text); previously `all` hard-failed.
    if analysis == "churn":
        _require_option(target_col, "--target/-y", "churn")
    elif analysis == "all" and target_col:
        required_columns.add(target_col)

    if validation_errors:
        for err in validation_errors:
            console.print(f"[red]✗ {err}[/red]")
        raise typer.Exit(1)

    missing_cols = [col for col in sorted(required_columns) if col not in df.columns]
    if missing_cols:
        console.print(f"[red]✗ Required columns not found in dataset: {', '.join(missing_cols)}[/red]")
        raise typer.Exit(1)

    console.print(Panel.fit(
        f"💼 Business Intelligence Analytics\nAnalysis: {analysis}",
        style="bold blue"
    ))

    results = {
        "meta": {
            "file_path": file_path,
            "analysis": analysis,
            "rows": len(df),
            "columns": len(df.columns),
            "output_dir": str(output_dir),
            "generated_at": datetime.utcnow().isoformat() + "Z"
        },
        "kpi": None,
        "cohort": None,
        "rfm": None,
        "churn": None
    }
    actionable_insights = []

    if analysis in {"kpi", "all"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Calculating KPI analytics...", total=None)

            calculate_kpis = getattr(bi_tools, "calculate_kpis", None)
            kpi_result = None
            if callable(calculate_kpis):
                try:
                    kpi_result = calculate_kpis(
                        data=df,
                        customer_id_column=customer_col,
                        date_column=date_col,
                        value_column=value_col
                    )
                except TypeError:
                    # Older signatures may be positional-only.
                    try:
                        kpi_result = calculate_kpis(df, customer_col, date_col, value_col)
                    except Exception as e:
                        console.print(f"[yellow]⚠ calculate_kpis failed, using fallback KPI computation: {e}[/yellow]")
                        kpi_result = None
                except Exception as e:
                    # Fix: a non-TypeError failure of the keyword-style call
                    # previously crashed the whole command instead of
                    # degrading to the fallback like the positional path.
                    console.print(f"[yellow]⚠ calculate_kpis failed, using fallback KPI computation: {e}[/yellow]")
                    kpi_result = None

            if kpi_result is None:
                # Fallback: compute a minimal KPI set directly from the frame.
                base_kpis = {
                    "Total Rows": len(df),
                    "Total Revenue": float(df[value_col].sum()) if value_col else None,
                    "Average Value": float(df[value_col].mean()) if value_col else None,
                    "Unique Customers": int(df[customer_col].n_unique()) if customer_col else None,
                }

                # Crude trend: average of the second half vs the first half
                # of the date-sorted series, as a percentage change.
                trend_pct = None
                if date_col and value_col:
                    try:
                        trend_df = df.select([date_col, value_col]).to_pandas()
                        trend_df[date_col] = pd.to_datetime(trend_df[date_col], errors="coerce")
                        trend_df[value_col] = pd.to_numeric(trend_df[value_col], errors="coerce")
                        trend_df = trend_df.dropna(subset=[date_col, value_col]).sort_values(date_col)
                        if len(trend_df) >= 6:
                            mid = len(trend_df) // 2
                            first_avg = trend_df.iloc[:mid][value_col].mean()
                            second_avg = trend_df.iloc[mid:][value_col].mean()
                            if first_avg and not pd.isna(first_avg):
                                trend_pct = float(((second_avg - first_avg) / abs(first_avg)) * 100.0)
                    except Exception:
                        trend_pct = None

                kpi_result = {
                    "kpis": [{"name": k, "value": v, "trend": trend_pct if k in {"Total Revenue", "Average Value"} else None} for k, v in base_kpis.items() if v is not None],
                    "fallback": True
                }

            progress.update(task, completed=True)

        kpi_table = Table(title="📊 KPI Summary")
        kpi_table.add_column("KPI", style="cyan")
        kpi_table.add_column("Value", style="white", justify="right")
        kpi_table.add_column("Trend", style="white", justify="center")

        # Normalize whatever shape the KPI tool returned into a flat list.
        kpi_items = []
        if isinstance(kpi_result, dict) and isinstance(kpi_result.get("kpis"), list):
            for item in kpi_result.get("kpis", []):
                if isinstance(item, dict):
                    kpi_items.append({
                        "name": str(item.get("name", "KPI")),
                        "value": item.get("value"),
                        "trend": item.get("trend")
                    })

        if not kpi_items and isinstance(kpi_result, dict):
            for key, val in kpi_result.items():
                if isinstance(val, (str, int, float, bool)):
                    kpi_items.append({"name": str(key), "value": val, "trend": None})

        for item in kpi_items:
            arrow = _trend_arrow(item.get("trend"))
            arrow_style = "green" if arrow == "↑" else ("red" if arrow == "↓" else "yellow")
            kpi_table.add_row(
                item["name"],
                _format_metric(item.get("value")),
                f"[{arrow_style}]{arrow}[/{arrow_style}]"
            )

        console.print()
        console.print(kpi_table)
        results["kpi"] = kpi_result

        revenue_item = next((item for item in kpi_items if "revenue" in item["name"].lower()), None)
        if revenue_item:
            arrow = _trend_arrow(revenue_item.get("trend"))
            if arrow == "↓":
                actionable_insights.append("Revenue momentum is declining; review recent pricing, acquisition channels, or discount strategy.")
            elif arrow == "↑":
                actionable_insights.append("Revenue momentum is improving; scale the channels or offers driving the uplift.")

    if analysis in {"cohort", "all"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Running cohort retention analysis...", total=None)
            cohort_result = bi_tools.perform_cohort_analysis(
                data=df,
                customer_id_column=customer_col,
                date_column=date_col,
                value_column=value_col,
                cohort_period="monthly",
                metric="retention"
            )
            progress.update(task, completed=True)

        cohort_matrix = cohort_result.get("cohort_matrix", {})

        def _period_sort_key(key):
            """Sort digit-like period keys numerically, the rest lexically.

            Fix: the previous key returned a bare int OR str, which raises
            TypeError under Python 3 when the matrix mixes both kinds. A
            (rank, value) tuple keeps comparisons homogeneous.
            """
            text = str(key)
            return (0, int(text)) if text.isdigit() else (1, text)

        period_keys = sorted(cohort_matrix.keys(), key=_period_sort_key)
        cohort_labels = sorted({str(label) for inner in cohort_matrix.values() if isinstance(inner, dict) for label in inner.keys()})

        retention_table = Table(title="👥 Cohort Retention Matrix")
        retention_table.add_column("Cohort", style="bold")
        for key in period_keys:
            retention_table.add_column(f"P{key}", justify="right")

        for cohort_label in cohort_labels:
            row_values = [cohort_label]
            for key in period_keys:
                val = None
                inner = cohort_matrix.get(key, {})
                if isinstance(inner, dict):
                    val = inner.get(cohort_label)

                if val is None:
                    cell = "-"
                else:
                    try:
                        rate = float(val)
                    except Exception:
                        cell = str(val)
                    else:
                        # Traffic-light the retention percentage.
                        pct = rate * 100.0
                        if pct >= 70:
                            cell = f"[green]{pct:.1f}%[/green]"
                        elif pct >= 40:
                            cell = f"[yellow]{pct:.1f}%[/yellow]"
                        else:
                            cell = f"[red]{pct:.1f}%[/red]"

                row_values.append(cell)

            retention_table.add_row(*row_values)

        console.print()
        console.print(retention_table)
        results["cohort"] = cohort_result

        for insight in cohort_result.get("insights", [])[:2]:
            actionable_insights.append(str(insight))

    if analysis in {"rfm", "all"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Running RFM segmentation...", total=None)

            # Prefer a dedicated rfm_segmentation tool; otherwise fall back
            # to the generic perform_rfm_analysis.
            rfm_fn = getattr(bi_tools, "rfm_segmentation", None)
            if callable(rfm_fn):
                try:
                    rfm_result = rfm_fn(
                        data=df,
                        customer_id_column=customer_col,
                        date_column=date_col,
                        value_column=value_col
                    )
                except TypeError:
                    rfm_result = rfm_fn(df, customer_col, date_col, value_col)
            else:
                rfm_result = bi_tools.perform_rfm_analysis(
                    data=df,
                    customer_id_column=customer_col,
                    date_column=date_col,
                    value_column=value_col
                )

            progress.update(task, completed=True)

        segment_summary = rfm_result.get("segment_summary", {}) if isinstance(rfm_result, dict) else {}
        rfm_table = Table(title="🎯 Customer Segments (RFM)")
        rfm_table.add_column("Segment", style="cyan")
        rfm_table.add_column("Customers", justify="right")
        rfm_table.add_column("Share", justify="right")
        rfm_table.add_column("Avg Monetary", justify="right")

        if isinstance(segment_summary, dict) and segment_summary:
            sorted_segments = sorted(segment_summary.items(), key=lambda x: x[1].get("count", 0), reverse=True)
            for segment, stats in sorted_segments:
                rfm_table.add_row(
                    str(segment),
                    _format_metric(stats.get("count")),
                    f"{float(stats.get('percentage', 0.0)):.1f}%",
                    _format_metric(stats.get("avg_monetary"))
                )
            top_segment = sorted_segments[0][0]
            actionable_insights.append(f"Your largest segment is '{top_segment}'; tailor campaigns specifically for this group to improve conversion.")
        else:
            rfm_table.add_row("-", "-", "-", "-")

        console.print()
        console.print(rfm_table)
        results["rfm"] = rfm_result

    if analysis in {"churn", "all"} and target_col:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Running churn analysis...", total=None)

            churn_fn = getattr(bi_tools, "churn_prediction", None)
            if callable(churn_fn):
                try:
                    churn_result = churn_fn(data=df, target_column=target_col)
                except TypeError:
                    try:
                        churn_result = churn_fn(df, target_col)
                    except Exception as e:
                        churn_result = {"error": str(e)}
            else:
                # Fallback: estimate a baseline churn rate straight from the
                # target column when no churn_prediction tool exists.
                target_series = df[target_col].drop_nulls()
                unique_vals = set([str(v).strip().lower() for v in target_series.unique().to_list()])
                positive_tokens = {"1", "true", "yes", "y", "churn", "churned"}
                churn_rate = None

                try:
                    if unique_vals.issubset({"0", "1", "true", "false", "yes", "no", "y", "n", "churn", "churned", "active"}):
                        mapped = target_series.cast(str).str.to_lowercase()
                        positives = mapped.is_in(list(positive_tokens)).sum()
                        churn_rate = float(positives / len(mapped)) if len(mapped) else None
                    else:
                        numeric_target = pd.to_numeric(target_series.to_list(), errors="coerce")
                        numeric_target = pd.Series(numeric_target).dropna()
                        # Only a binary numeric target yields a meaningful mean-as-rate.
                        if len(numeric_target) and numeric_target.nunique() <= 2:
                            churn_rate = float(numeric_target.mean())
                except Exception:
                    churn_rate = None

                churn_result = {
                    "status": "fallback",
                    "message": "churn_prediction function not found; returning churn baseline summary.",
                    "target_column": target_col,
                    "rows_evaluated": int(len(target_series)),
                    "churn_rate": churn_rate,
                    "unique_target_values": sorted(list(unique_vals))
                }

            progress.update(task, completed=True)

        churn_table = Table(title="📉 Churn Analysis")
        churn_table.add_column("Metric", style="cyan")
        churn_table.add_column("Value", justify="right")

        if isinstance(churn_result, dict):
            churn_table.add_row("Target Column", str(churn_result.get("target_column", target_col)))
            if churn_result.get("churn_rate") is not None:
                churn_rate_pct = float(churn_result.get("churn_rate")) * 100.0
                churn_table.add_row("Estimated Churn Rate", f"{churn_rate_pct:.2f}%")
                if churn_rate_pct >= 30:
                    actionable_insights.append("Estimated churn is high; prioritize retention campaigns for at-risk customers this cycle.")
                elif churn_rate_pct >= 15:
                    actionable_insights.append("Estimated churn is moderate; test proactive customer success outreach on vulnerable segments.")
                else:
                    actionable_insights.append("Estimated churn is low; maintain retention programs and monitor leading indicators weekly.")
            for key in ["status", "message", "rows_evaluated"]:
                if key in churn_result:
                    churn_table.add_row(key.replace("_", " ").title(), _format_metric(churn_result.get(key)))

        console.print()
        console.print(churn_table)
        results["churn"] = churn_result

    # Persist each analysis individually plus one combined summary.
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    for key in ["kpi", "cohort", "rfm", "churn"]:
        if results.get(key) is not None:
            partial_path = output_dir / f"{key}_results_{timestamp}.json"
            with open(partial_path, "w") as f:
                json.dump(results[key], f, indent=2, default=str)

    summary_path = output_dir / f"bi_results_{timestamp}.json"
    with open(summary_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    console.print(f"\n💾 BI results saved to: [cyan]{output_dir}[/cyan]")
    console.print(f"📄 Combined summary: [cyan]{summary_path}[/cyan]")

    # De-duplicate collected insights, then pad to exactly three with
    # generic fallbacks so the closing panel is always populated.
    deduped_insights = []
    for insight in actionable_insights:
        clean = str(insight).strip()
        if clean and clean not in deduped_insights:
            deduped_insights.append(clean)

    while len(deduped_insights) < 3:
        fallback_pool = [
            "Monitor weekly KPI trends and quickly investigate any sudden drops.",
            "Segment customers by lifecycle stage and personalize messaging for each segment.",
            "Use cohort retention patterns to prioritize onboarding improvements in the first customer period."
        ]
        for candidate in fallback_pool:
            if candidate not in deduped_insights:
                deduped_insights.append(candidate)
            if len(deduped_insights) >= 3:
                break

    top3 = deduped_insights[:3]
    insights_text = "\n".join([f"{idx}. {insight}" for idx, insight in enumerate(top3, start=1)])
    console.print()
    console.print(Panel(
        insights_text,
        title="💡 Top 3 Business Insights",
        border_style="green"
    ))
1282
+
1283
@app.command()
def chat(
    file_path: Optional[str] = typer.Argument(None, help="Optional dataset path to auto-load at startup"),
    session_id: str = typer.Option(None, "--session", "-s", help="Resume a specific session ID"),
    model: str = typer.Option("llama-3.3-70b-versatile", "--model", help="LLM model to use")
):
    """
    Launch an interactive chat session for follow-up questions on datasets.

    Runs a REPL with built-in commands (help / history / summary / load /
    clear / export / exit); any other input is sent to the orchestrator as a
    free-form analysis request against the currently loaded dataset. Session
    state is persisted through SessionStore on exit (including Ctrl-C / EOF).

    Example:
        ds-agent chat sales.csv
        ds-agent chat --session abc123
    """
    from datetime import datetime
    from rich.prompt import Prompt
    from rich.markdown import Markdown
    from ds_agent.session_store import SessionStore
    from ds_agent.session_memory import SessionMemory
    from ds_agent.tools.data_profiling import profile_dataset

    console.print(Panel.fit(
        "DS-Agent Chat — type 'help' for commands, 'exit' to quit",
        title="💬 Chat",
        border_style="blue"
    ))

    store = SessionStore()
    # The store API has shifted between versions; accept either name.
    save_session_fn = getattr(store, "save_session", None) or getattr(store, "save", None)
    load_session_fn = getattr(store, "load_session", None) or getattr(store, "load", None)

    if session_id:
        session_memory = load_session_fn(session_id) if callable(load_session_fn) else None
        if session_memory is None:
            console.print(f"[yellow]⚠ Session '{session_id}' not found. Creating a new one.[/yellow]")
            session_memory = SessionMemory(session_id=session_id)
    else:
        session_memory = SessionMemory()

    current_dataset = file_path or session_memory.last_dataset
    last_analysis_result = None

    def _print_dataset_summary(profile: dict, path: str):
        """Render a compact two-column summary table for a profiled dataset."""
        shape = profile.get("shape", {})
        col_types = profile.get("column_types", {})
        overall = profile.get("overall_stats", {})

        summary_table = Table(title=f"📊 Dataset Summary: {Path(path).name}", show_header=False)
        summary_table.add_column("Metric", style="cyan", width=26)
        summary_table.add_column("Value", style="white")

        summary_table.add_row("File", str(path))
        summary_table.add_row("Rows", str(shape.get("rows", "-")))
        summary_table.add_row("Columns", str(shape.get("columns", "-")))
        summary_table.add_row("Numeric Columns", str(len(col_types.get("numeric", []))))
        summary_table.add_row("Categorical Columns", str(len(col_types.get("categorical", []))))
        summary_table.add_row("Datetime Columns", str(len(col_types.get("datetime", []))))
        summary_table.add_row("Null %", f"{overall.get('null_percentage', 0)}%")
        summary_table.add_row("Duplicate Rows", str(overall.get("duplicate_rows", 0)))

        console.print()
        console.print(summary_table)

    def _load_and_profile_dataset(path: str) -> bool:
        """Profile a dataset, record it in session memory, and print a summary.

        Returns True on success, False (after printing an error) otherwise.
        """
        nonlocal current_dataset
        if not Path(path).exists():
            console.print(f"[red]✗ File not found: {path}[/red]")
            return False

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task(f"Profiling dataset: {path}", total=None)
            try:
                profile = profile_dataset(path)
            except Exception as e:
                console.print(f"[red]✗ Failed to profile dataset: {e}[/red]")
                return False
            progress.update(task, completed=True)

        current_dataset = path
        session_memory.update(last_dataset=path)
        session_memory.add_workflow_step(
            "profile_dataset",
            {
                "success": True,
                "arguments": {"file_path": path},
                "result": {
                    "output_path": path,
                    "shape": profile.get("shape", {})
                }
            }
        )
        _print_dataset_summary(profile, path)
        return True

    if current_dataset:
        _load_and_profile_dataset(current_dataset)

    try:
        os.environ["GROQ_MODEL"] = model
        from ds_agent.orchestrator import DataScienceCopilot

        copilot = DataScienceCopilot(
            provider="groq",
            session_id=session_memory.session_id,
            use_session_memory=False
        )
    except Exception as e:
        console.print(f"[red]✗ Failed to initialize orchestrator: {e}[/red]")
        if callable(save_session_fn):
            save_session_fn(session_memory)
        raise typer.Exit(1)

    help_panel = Panel(
        "\n".join([
            "help Show commands",
            "history Show workflow history",
            "summary Re-print current dataset summary",
            "load <file> Load/switch dataset",
            "clear Clear current session history",
            "export Export current session/results",
            "exit or quit Save session and exit",
        ]),
        title="Available Commands",
        border_style="cyan"
    )

    while True:
        # Fix: Ctrl-C / Ctrl-D previously crashed out of the REPL without
        # saving the session; treat them like an explicit 'exit'.
        try:
            user_message = Prompt.ask(f"[bold cyan][{session_memory.session_id}][/bold cyan] > ").strip()
        except (KeyboardInterrupt, EOFError):
            if callable(save_session_fn):
                save_session_fn(session_memory)
            console.print("[green]✓ Session saved. Goodbye![/green]")
            break
        if not user_message:
            continue

        lowered = user_message.lower()  # already stripped above

        if lowered in {"exit", "quit"}:
            if callable(save_session_fn):
                save_session_fn(session_memory)
            console.print("[green]✓ Session saved. Goodbye![/green]")
            break

        if lowered == "help":
            console.print(help_panel)
            continue

        if lowered == "history":
            history = session_memory.workflow_history
            if not history:
                console.print("[yellow]No workflow history yet.[/yellow]")
                continue

            history_table = Table(title="🧭 Workflow History")
            history_table.add_column("Time", style="cyan")
            history_table.add_column("Tool", style="white")
            history_table.add_column("Status", style="white")

            for step in history[-20:]:
                ts = str(step.get("timestamp", ""))
                # Keep just HH:MM:SS from an ISO timestamp.
                ts_short = ts.split("T")[-1][:8] if "T" in ts else ts[:8]
                tool = str(step.get("tool", "-"))
                result = step.get("result", {})
                success = False
                if isinstance(result, dict):
                    success = bool(result.get("success", False))
                status = "[green]✓[/green]" if success else "[red]✗[/red]"
                history_table.add_row(ts_short, tool, status)

            console.print()
            console.print(history_table)
            continue

        if lowered == "summary":
            if not current_dataset:
                console.print("[yellow]No dataset is loaded. Use 'load <file>'.[/yellow]")
                continue
            _load_and_profile_dataset(current_dataset)
            continue

        if lowered == "clear":
            # Keep currently loaded dataset but clear conversational/workflow memory.
            last_dataset = current_dataset
            session_memory.clear()
            if last_dataset:
                session_memory.update(last_dataset=last_dataset)
            console.print("[green]✓ Session history cleared.[/green]")
            continue

        if lowered == "export":
            export_dir = Path("./outputs/chat/")
            export_dir.mkdir(parents=True, exist_ok=True)
            ts = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
            export_path = export_dir / f"chat_session_{session_memory.session_id}_{ts}.json"

            export_payload = {
                "session": session_memory.to_dict(),
                "last_result": last_analysis_result,
                "current_dataset": current_dataset,
                "exported_at": datetime.utcnow().isoformat() + "Z"
            }

            with open(export_path, "w") as f:
                json.dump(export_payload, f, indent=2, default=str)

            console.print(f"[green]✓ Exported session to:[/green] [cyan]{export_path}[/cyan]")
            continue

        if lowered.startswith("load "):
            load_path = user_message[5:].strip().strip('"').strip("'")
            if not load_path:
                console.print("[yellow]Usage: load <file_path>[/yellow]")
                continue
            _load_and_profile_dataset(load_path)
            continue

        # Free-form message: let session memory resolve references like
        # "that dataset" / "the target" before calling the orchestrator.
        resolver = getattr(session_memory, "resolve_context", None) or getattr(session_memory, "resolve_ambiguity", None)
        resolved_context = resolver(user_message) if callable(resolver) else {}
        # Fix: the duck-typed resolver may return None (or a non-dict);
        # previously that crashed on .get() below.
        if not isinstance(resolved_context, dict):
            resolved_context = {}

        effective_dataset = resolved_context.get("file_path") or current_dataset or session_memory.last_dataset
        effective_target = resolved_context.get("target_col") or session_memory.last_target_col

        if not effective_dataset:
            console.print("[yellow]No dataset loaded. Use 'load <file>' or pass file_path at startup.[/yellow]")
            continue

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Thinking...", total=None)
            try:
                analysis_result = copilot.analyze(
                    file_path=effective_dataset,
                    task_description=user_message,
                    target_col=effective_target,
                    use_cache=True,
                    stream=False,
                    max_iterations=20
                )
            except Exception as e:
                progress.update(task, completed=True)
                console.print(f"[red]✗ Chat analysis failed: {e}[/red]")
                continue
            progress.update(task, completed=True)

        last_analysis_result = analysis_result
        session_memory.update(last_dataset=effective_dataset)
        if effective_target:
            session_memory.update(last_target_col=effective_target)

        # Mirror the orchestrator's workflow steps into session memory so
        # 'history' reflects this turn.
        workflow_steps = analysis_result.get("workflow_history", []) if isinstance(analysis_result, dict) else []
        for step in workflow_steps:
            if not isinstance(step, dict):
                continue
            tool_name = step.get("tool", "unknown_tool")
            tool_result = step.get("result", {}) if isinstance(step.get("result", {}), dict) else {"raw_result": step.get("result")}
            session_memory.add_workflow_step(
                tool_name,
                {
                    "success": tool_result.get("success", True),
                    "arguments": {
                        "file_path": effective_dataset,
                        "target_col": effective_target
                    },
                    "result": tool_result
                }
            )

        if isinstance(analysis_result, dict):
            response_text = analysis_result.get("summary") or analysis_result.get("message") or "Analysis complete."
        else:
            response_text = str(analysis_result)

        session_memory.add_conversation(user_message, response_text)
        console.print()
        console.print(Markdown(response_text))
1562
+
1563
@app.command()
def plot(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    query: Optional[str] = typer.Argument(None, help="Optional plot request like 'correlation heatmap' or 'revenue by region'"),
    chart_type: str = typer.Option("auto", "--type", "-t", help="Chart type: auto, bar, line, histogram, scatter, heatmap, box, distribution"),
    x: str = typer.Option(None, "--x", help="X-axis column name"),
    y: str = typer.Option(None, "--y", help="Y-axis column name"),
    target: str = typer.Option(None, "--target", help="Target/color-by column name"),
    output: str = typer.Option("./outputs/plots/", "--output", "-o", help="Output directory for generated plots")
):
    """
    Generate visualization charts from a dataset.

    Example:
        ds-agent plot sales.csv "top products bar chart" --y revenue
        ds-agent plot sales.csv --type heatmap
        ds-agent plot sales.csv --type histogram --x revenue
    """
    # Project tool imports are deferred to command invocation so that plain
    # `ds-agent --help` does not pay for plotly/polars startup cost.
    from ds_agent.tools.visualization_engine import generate_all_plots
    from ds_agent.tools.plotly_visualizations import (
        generate_interactive_correlation_heatmap,
        generate_interactive_histogram,
        generate_interactive_scatter,
        generate_interactive_box_plots,
        generate_interactive_time_series,
    )
    from ds_agent.utils.polars_helpers import load_dataframe, get_numeric_columns

    console.print(Panel.fit(
        "📉 Plot Generator",
        style="bold blue"
    ))

    # --- Input validation -------------------------------------------------
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    valid_types = {"auto", "bar", "line", "histogram", "scatter", "heatmap", "box", "distribution"}
    chart_type = chart_type.strip().lower()
    if chart_type not in valid_types:
        console.print(f"[red]✗ Error: Invalid chart type '{chart_type}'.[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load once for safe column inference when user does not pass x/y.
    try:
        df = load_dataframe(file_path)
        numeric_cols = get_numeric_columns(df)
        all_cols = list(df.columns)
    except Exception as e:
        console.print(f"[red]✗ Could not load dataset for plotting: {e}[/red]")
        raise typer.Exit(1)

    # Default x to the first column; default y to the first numeric column
    # that differs from x when more than one numeric column exists.
    if x is None and all_cols:
        x = all_cols[0]
    if y is None and numeric_cols:
        y = numeric_cols[0] if x not in numeric_cols else (numeric_cols[1] if len(numeric_cols) > 1 else numeric_cols[0])

    query_text = (query or "").lower()

    # --- Mode selection ---------------------------------------------------
    # Keyword matches in the free-text query are checked FIRST; the explicit
    # --type flag is only consulted when no query keyword matched.
    # NOTE(review): this means a query like "correlation ..." silently
    # overrides an explicit --type scatter — confirm that precedence is
    # intended.
    selected_mode = "all"
    if "correlation" in query_text or "heatmap" in query_text:
        selected_mode = "heatmap"
    elif "distribution" in query_text or "histogram" in query_text:
        selected_mode = "histogram"
    elif "scatter" in query_text:
        selected_mode = "scatter"
    elif "box" in query_text:
        selected_mode = "box"
    elif "time" in query_text or "trend" in query_text:
        selected_mode = "time_series"
    elif chart_type != "auto":
        if chart_type == "heatmap":
            selected_mode = "heatmap"
        elif chart_type in {"histogram", "distribution"}:
            selected_mode = "histogram"
        elif chart_type == "scatter":
            selected_mode = "scatter"
        elif chart_type == "box":
            selected_mode = "box"
        elif chart_type == "line":
            selected_mode = "time_series"
        else:
            # "bar" (and any future accepted type without a dedicated
            # interactive generator) falls through to the full plot suite.
            selected_mode = "all"

    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    if query:
        console.print(f"📝 [bold]Query:[/bold] {query}")
    console.print(f"🎯 [bold]Mode:[/bold] {selected_mode}")

    plot_paths = []

    # --- Plot generation --------------------------------------------------
    # Each branch calls one interactive generator; generators report success
    # via result["status"], which is converted into a ValueError so a single
    # except block handles every failure uniformly.
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_plot = progress.add_task("Generating visualizations...", total=None)

            if selected_mode == "heatmap":
                result = generate_interactive_correlation_heatmap(
                    file_path,
                    output_path=str(output_dir / "correlation_heatmap.html")
                )
                if result.get("status") == "success":
                    plot_paths.append(result.get("output_path"))
                else:
                    raise ValueError(result.get("message", "Failed to generate heatmap"))

            elif selected_mode == "histogram":
                # Prefer x, then y, then the first numeric column.
                hist_col = x if x in all_cols else (y if y in all_cols else (numeric_cols[0] if numeric_cols else None))
                if not hist_col:
                    raise ValueError("No suitable numeric column found for histogram")
                result = generate_interactive_histogram(
                    file_path,
                    column=hist_col,
                    color_col=target,
                    output_path=str(output_dir / "histogram.html")
                )
                if result.get("status") == "success":
                    plot_paths.append(result.get("output_path"))
                else:
                    raise ValueError(result.get("message", "Failed to generate histogram"))

            elif selected_mode == "scatter":
                if not x or not y:
                    raise ValueError("Scatter plot requires x and y columns (inferred values unavailable)")
                result = generate_interactive_scatter(
                    file_path,
                    x_col=x,
                    y_col=y,
                    color_col=target,
                    output_path=str(output_dir / "scatter.html")
                )
                if result.get("status") == "success":
                    plot_paths.append(result.get("output_path"))
                else:
                    raise ValueError(result.get("message", "Failed to generate scatter plot"))

            elif selected_mode == "box":
                result = generate_interactive_box_plots(
                    file_path,
                    columns=[col for col in [x, y] if col],
                    group_by=target,
                    output_path=str(output_dir / "box_plots.html")
                )
                if result.get("status") == "success":
                    plot_paths.append(result.get("output_path"))
                else:
                    raise ValueError(result.get("message", "Failed to generate box plot"))

            elif selected_mode == "time_series":
                if not x:
                    raise ValueError("Time series requires --x as time column")
                # Value columns: explicit y/target first, else the first
                # numeric column that is not the time axis.
                value_cols = [col for col in [y, target] if col and col != x]
                if not value_cols:
                    value_cols = [col for col in numeric_cols if col != x][:1]
                if not value_cols:
                    raise ValueError("Time series requires at least one numeric value column")
                result = generate_interactive_time_series(
                    file_path,
                    time_col=x,
                    value_cols=value_cols,
                    output_path=str(output_dir / "time_series.html")
                )
                if result.get("status") == "success":
                    plot_paths.append(result.get("output_path"))
                else:
                    raise ValueError(result.get("message", "Failed to generate time series"))

            else:
                # "all" mode: delegate to the batch engine, which returns a
                # list of generated file paths rather than a status dict.
                result = generate_all_plots(
                    file_path,
                    target_col=target,
                    output_dir=str(output_dir)
                )
                plot_paths.extend(result.get("plots_generated", []))

            progress.update(task_plot, completed=True)

    except Exception as e:
        console.print(f"\n[red]✗ Plot generation failed: {e}[/red]")
        raise typer.Exit(1)

    # Drop None entries (a generator may report success without a path).
    plot_paths = [p for p in plot_paths if p]
    if not plot_paths:
        console.print("[yellow]⚠ No plots were generated.[/yellow]")
        raise typer.Exit(1)

    console.print("\n[green]✓ Plots generated:[/green]")
    for path in plot_paths:
        console.print(f"  - [cyan]{path}[/cyan]")

    # --- Summary table ----------------------------------------------------
    summary = Table(title="📊 Plot Summary")
    summary.add_column("#", style="cyan", width=4)
    summary.add_column("Plot File", style="white")
    summary.add_column("Path", style="green")

    # Rich Live re-renders the table as rows are appended; only worthwhile
    # when there is more than one row.
    if len(plot_paths) > 1:
        with Live(summary, console=console, refresh_per_second=8):
            for idx, path in enumerate(plot_paths, start=1):
                summary.add_row(str(idx), Path(path).name, path)
    else:
        summary.add_row("1", Path(plot_paths[0]).name, plot_paths[0])

    console.print()
    console.print(summary)
1775
@app.command()
def forecast(
    file_path: str = typer.Argument(..., help="Path to time series dataset file (CSV or Parquet)"),
    time_col: str = typer.Option(None, "--time", "-t", help="Time column name (required)"),
    target_col: str = typer.Option(None, "--target", "-y", help="Target column to forecast (required)"),
    horizon: int = typer.Option(30, "--horizon", "-h", help="Forecast horizon (periods ahead)"),
    method: str = typer.Option("prophet", "--method", "-m", help="Forecasting method: prophet, arima, exponential_smoothing"),
    output: str = typer.Option("./outputs/forecast/", "--output", "-o", help="Output directory for forecast files")
):
    """
    Forecast a time series target column.

    Example:
        ds-agent forecast sales.csv --time date --target revenue --horizon 90
        ds-agent forecast stock.csv --time Date --target Close --method arima
    """
    # Deferred imports keep CLI startup fast.
    from ds_agent.tools.time_series import forecast_time_series
    import csv

    # detect_seasonality is requested by name; fallback keeps compatibility with current tools module.
    try:
        from ds_agent.tools.time_series import detect_seasonality
    except ImportError:
        from ds_agent.tools.time_series import detect_seasonality_trends as detect_seasonality

    def _first_present(mapping, keys):
        # Return the first value among `keys` that is present and not None.
        # Unlike an `or` chain, this preserves legitimate falsy values such
        # as 0 / 0.0, which are valid forecast and actual readings.
        for key in keys:
            value = mapping.get(key)
            if value is not None:
                return value
        return None

    console.print(Panel.fit(
        "⏳ Time Series Forecast",
        style="bold blue"
    ))

    # Step 1: Validation
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not time_col:
        console.print("[red]✗ Error: --time (-t) is required[/red]")
        raise typer.Exit(1)

    if not target_col:
        console.print("[red]✗ Error: --target (-y) is required[/red]")
        raise typer.Exit(1)

    if horizon <= 0:
        console.print("[red]✗ Error: --horizon must be a positive integer[/red]")
        raise typer.Exit(1)

    method_normalized = method.strip().lower()
    valid_methods = {"prophet", "arima", "exponential_smoothing"}
    if method_normalized not in valid_methods:
        console.print("[red]✗ Error: --method must be one of prophet, arima, exponential_smoothing[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv = output_dir / f"forecast_{Path(file_path).stem}_{target_col}_{method_normalized}.csv"

    console.print(f"\n📊 [bold]Dataset:[/bold] {file_path}")
    console.print(f"🕒 [bold]Time Column:[/bold] {time_col}")
    console.print(f"🎯 [bold]Target Column:[/bold] {target_col}")
    console.print(f"📈 [bold]Method:[/bold] {method_normalized}")
    console.print(f"🔭 [bold]Horizon:[/bold] {horizon}")

    # Step 2: Seasonality detection (best-effort: a failure is reported as a
    # warning below instead of aborting the forecast).
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_seasonality = progress.add_task("Detecting seasonality...", total=None)
            seasonality_result = detect_seasonality(file_path, time_col, target_col)
            progress.update(task_seasonality, completed=True)
    except Exception as e:
        seasonality_result = {"status": "error", "message": str(e)}

    if seasonality_result.get("status") == "success":
        detected_period = seasonality_result.get("detected_period", "unknown")
        interpretation = seasonality_result.get("interpretation", {})
        seasonality_label = interpretation.get("seasonality", "unknown")
        console.print(
            f"[green]✓ Seasonality detected:[/green] {seasonality_label} (period={detected_period})"
        )
    else:
        console.print(
            f"[yellow]⚠ Seasonality detection unavailable:[/yellow] {seasonality_result.get('message', 'Unknown issue')}"
        )

    # Step 3: Forecast generation
    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task_forecast = progress.add_task("Generating forecast...", total=None)
            forecast_result = forecast_time_series(
                file_path=file_path,
                time_col=time_col,
                target_col=target_col,
                forecast_horizon=horizon,
                method=method_normalized,
                output_path=str(output_csv)
            )
            progress.update(task_forecast, completed=True)
    except Exception as e:
        console.print(f"[red]✗ Forecast failed: {e}[/red]")
        raise typer.Exit(1)

    if forecast_result.get("status") != "success":
        console.print(f"[red]✗ Forecast failed: {forecast_result.get('message', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    forecast_rows = forecast_result.get("forecast", [])
    if not isinstance(forecast_rows, list) or not forecast_rows:
        console.print("[yellow]⚠ Forecast completed but returned no rows.[/yellow]")
        raise typer.Exit(1)

    # Normalize row fields for display and CSV export. Different backends
    # use different key names (prophet: ds/yhat; generic: date/predicted),
    # so coalesce each field across the known aliases. Coalescing is on
    # "is not None" — NOT a truthiness `or` chain — so a value of exactly
    # 0 / 0.0 is kept rather than being mistaken for "missing".
    normalized_rows = []
    for row in forecast_rows:
        normalized_rows.append({
            "date": _first_present(row, ("date", "ds", "time")),
            "actual": _first_present(row, ("actual", "y")),
            "predicted": _first_present(row, ("predicted", "yhat", "value")),
            "lower_bound": _first_present(row, ("lower_bound", "yhat_lower", "lower_ci")),
            "upper_bound": _first_present(row, ("upper_bound", "yhat_upper", "upper_ci")),
        })

    # Step 4: Forecast results table
    table = Table(title="📊 Forecast Results")
    table.add_column("Date", style="cyan")
    table.add_column("Actual", justify="right")
    table.add_column("Predicted", justify="right", style="green")
    table.add_column("Lower Bound", justify="right")
    table.add_column("Upper Bound", justify="right")

    for row in normalized_rows:
        # "-" marks genuinely absent values (e.g. no actuals in the future
        # horizon); numeric values — including 0.0 — are formatted.
        actual_display = "-" if row["actual"] is None else f"{float(row['actual']):.4f}"
        predicted_display = "-" if row["predicted"] is None else f"{float(row['predicted']):.4f}"
        lower_display = "-" if row["lower_bound"] is None else f"{float(row['lower_bound']):.4f}"
        upper_display = "-" if row["upper_bound"] is None else f"{float(row['upper_bound']):.4f}"

        table.add_row(
            str(row["date"]),
            actual_display,
            predicted_display,
            lower_display,
            upper_display,
        )

    console.print()
    console.print(table)

    # Step 5: Save normalized forecast CSV in output directory.
    with open(output_csv, "w", newline="") as csv_file:
        writer = csv.DictWriter(
            csv_file,
            fieldnames=["date", "actual", "predicted", "lower_bound", "upper_bound"]
        )
        writer.writeheader()
        writer.writerows(normalized_rows)

    console.print(f"\n💾 Forecast CSV saved to: [cyan]{output_csv}[/cyan]")

    # Step 6: Print model accuracy metrics if available. Metrics may live at
    # the top level of the result or inside a nested "metrics" dict; the
    # nested values win when both are present.
    metrics = {}
    for metric_key in ["mae", "rmse", "mape"]:
        if metric_key in forecast_result:
            metrics[metric_key] = forecast_result.get(metric_key)

    nested_metrics = forecast_result.get("metrics")
    if isinstance(nested_metrics, dict):
        for metric_key in ["mae", "rmse", "mape"]:
            if metric_key in nested_metrics:
                metrics[metric_key] = nested_metrics.get(metric_key)

    if metrics:
        metrics_table = Table(title="📏 Forecast Accuracy")
        metrics_table.add_column("Metric", style="cyan")
        metrics_table.add_column("Value", style="white", justify="right")

        if metrics.get("mae") is not None:
            metrics_table.add_row("MAE", f"{float(metrics['mae']):.4f}")
        if metrics.get("rmse") is not None:
            metrics_table.add_row("RMSE", f"{float(metrics['rmse']):.4f}")
        if metrics.get("mape") is not None:
            metrics_table.add_row("MAPE", f"{float(metrics['mape']):.2f}%")

        console.print()
        console.print(metrics_table)
    else:
        console.print("[yellow]ℹ Accuracy metrics (MAE/RMSE/MAPE) were not returned by this model run.[/yellow]")
1976
@app.command()
def nlp(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    text_col: str = typer.Option(None, "--text", "-t", help="Text column name (required)"),
    task: str = typer.Option("all", "--task", help="NLP task: all, topics, sentiment, entities, classify"),
    n_topics: int = typer.Option(5, "--n-topics", help="Number of topics for topic modeling"),
    output: str = typer.Option("./outputs/nlp/", "--output", "-o", help="Output directory for NLP artifacts")
):
    """
    Run NLP analysis tasks such as sentiment, topics, and entity extraction.

    Example:
        ds-agent nlp reviews.csv --text review_text --task sentiment
        ds-agent nlp news.csv --text content --task topics --n-topics 8
        ds-agent nlp articles.csv --text body
    """
    # Deferred imports; tool functions are aliased to the short names used
    # throughout the command body.
    from ds_agent.utils.polars_helpers import load_dataframe
    from ds_agent.tools.nlp_text_analytics import (
        analyze_sentiment_advanced as sentiment_analysis,
        perform_topic_modeling,
        perform_named_entity_recognition as extract_entities,
    )

    console.print(Panel.fit(
        "📝 NLP Text Analytics",
        style="bold blue"
    ))

    # Step 1: Validate file and text column.
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not text_col:
        console.print("[red]✗ Error: --text (-t) is required[/red]")
        raise typer.Exit(1)

    valid_tasks = {"all", "topics", "sentiment", "entities", "classify"}
    task = task.strip().lower()
    if task not in valid_tasks:
        console.print("[red]✗ Error: --task must be one of all, topics, sentiment, entities, classify[/red]")
        raise typer.Exit(1)

    if n_topics <= 0:
        console.print("[red]✗ Error: --n-topics must be a positive integer[/red]")
        raise typer.Exit(1)

    try:
        df = load_dataframe(file_path)
    except Exception as e:
        console.print(f"[red]✗ Error loading dataset: {e}[/red]")
        raise typer.Exit(1)

    if text_col not in df.columns:
        console.print(f"[red]✗ Error: Text column '{text_col}' not found in dataset[/red]")
        raise typer.Exit(1)

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 2: Preview first 5 rows of text column.
    preview_table = Table(title="🔎 Text Preview (first 5 rows)")
    preview_table.add_column("Row", style="cyan", width=6)
    preview_table.add_column(text_col, style="white")

    preview_values = df[text_col].head(5).to_list()
    for idx, value in enumerate(preview_values, start=1):
        # Flatten newlines and truncate long cells to 140 chars for display.
        text_value = "" if value is None else str(value)
        text_value = text_value.replace("\n", " ").strip()
        if len(text_value) > 140:
            text_value = text_value[:137] + "..."
        preview_table.add_row(str(idx), text_value)

    console.print()
    console.print(preview_table)

    sentiment_result = None
    topics_result = None
    entities_result = None
    saved_files = []

    # Step 3A: Sentiment analysis
    if task in {"all", "sentiment"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            sentiment_task = progress.add_task("Running sentiment analysis...", total=None)
            sentiment_result = sentiment_analysis(df, text_col)
            progress.update(sentiment_task, completed=True)

        # Prefer the tool's precomputed distribution; otherwise rebuild it
        # from the per-row labels using substring matching (POS/NEG/NEU) so
        # label variants like "POSITIVE" vs "pos" are all counted.
        sentiment_dist = sentiment_result.get("statistics", {}).get("sentiment_distribution")
        if not sentiment_dist:
            sentiments = sentiment_result.get("sentiments", [])
            labels = [str(item.get("label", "NEUTRAL")).upper() for item in sentiments]
            total = len(labels)
            sentiment_dist = {
                "POSITIVE": sum(1 for label in labels if "POS" in label),
                "NEGATIVE": sum(1 for label in labels if "NEG" in label),
                "NEUTRAL": sum(1 for label in labels if "NEU" in label),
            } if total > 0 else {"POSITIVE": 0, "NEGATIVE": 0, "NEUTRAL": 0}

        total_count = sum(sentiment_dist.values()) if isinstance(sentiment_dist, dict) else 0

        sentiment_table = Table(title="😊 Sentiment Distribution")
        sentiment_table.add_column("Label", style="cyan")
        sentiment_table.add_column("Count", style="white", justify="right")
        sentiment_table.add_column("Percentage", style="green", justify="right")

        for label in ["POSITIVE", "NEGATIVE", "NEUTRAL"]:
            count = int(sentiment_dist.get(label, 0)) if isinstance(sentiment_dist, dict) else 0
            pct = (count / total_count * 100) if total_count > 0 else 0.0
            sentiment_table.add_row(label, str(count), f"{pct:.2f}%")

        console.print()
        console.print(sentiment_table)

        # default=str makes non-JSON-serializable values (dates etc.) safe.
        sentiment_path = output_dir / "sentiment_results.json"
        with open(sentiment_path, "w") as f:
            json.dump(sentiment_result, f, indent=2, default=str)
        saved_files.append(str(sentiment_path))

    # Step 3B: Topic modeling
    if task in {"all", "topics"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            topics_task = progress.add_task("Running topic modeling...", total=None)
            topics_result = perform_topic_modeling(
                data=df,
                text_column=text_col,
                n_topics=n_topics
            )
            progress.update(topics_task, completed=True)

        topics_table = Table(title="🧠 Topic Modeling Results")
        topics_table.add_column("Topic", style="cyan", width=8)
        topics_table.add_column("Top 5 Keywords", style="white")

        for topic in topics_result.get("topics", []):
            topic_id = topic.get("topic_id", "-")
            words = topic.get("words", [])[:5]
            topics_table.add_row(str(topic_id), ", ".join(words) if words else "-")

        console.print()
        console.print(topics_table)

        topics_path = output_dir / "topics_results.json"
        with open(topics_path, "w") as f:
            json.dump(topics_result, f, indent=2, default=str)
        saved_files.append(str(topics_path))

    # Step 3C: Entity extraction
    if task in {"all", "entities"}:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            entities_task = progress.add_task("Extracting entities...", total=None)
            entities_result = extract_entities(df, text_col)
            progress.update(entities_task, completed=True)

        entities_table = Table(title="🏷️ Top Entities by Type")
        entities_table.add_column("Entity Type", style="cyan")
        entities_table.add_column("Top 10 Entities", style="white")

        # The tool's "by_type" payload may expose either ranked
        # top_entities (with counts) or plain example strings; handle both.
        by_type = entities_result.get("by_type", {}) if isinstance(entities_result, dict) else {}
        for entity_type, details in by_type.items():
            top_items = details.get("top_entities") if isinstance(details, dict) else None
            if isinstance(top_items, list) and top_items:
                top_10 = [
                    f"{item.get('text', '')} ({item.get('count', 0)})"
                    for item in top_items[:10]
                ]
            else:
                examples = details.get("examples", []) if isinstance(details, dict) else []
                top_10 = [str(ex) for ex in examples[:10]]

            entities_table.add_row(entity_type, ", ".join(top_10) if top_10 else "-")

        console.print()
        console.print(entities_table)

        entities_path = output_dir / "entities_results.json"
        with open(entities_path, "w") as f:
            json.dump(entities_result, f, indent=2, default=str)
        saved_files.append(str(entities_path))

    # Step 4: Save classify placeholder when classify-only requested.
    if task == "classify":
        classify_result = {
            "status": "not_implemented",
            "message": "Classify task is reserved but no classification routine is currently wired in nlp_text_analytics."
        }
        classify_path = output_dir / "classify_results.json"
        with open(classify_path, "w") as f:
            json.dump(classify_result, f, indent=2)
        saved_files.append(str(classify_path))
        console.print("[yellow]ℹ Classify task placeholder saved (feature not implemented in tools module).[/yellow]")

    # Step 5: Final summary panel
    summary_lines = [
        f"Task: {task}",
        f"Dataset: {file_path}",
        f"Text column: {text_col}",
        f"Rows processed: {len(df)}",
        "Saved files:"
    ]
    summary_lines.extend([f"- {path}" for path in saved_files])

    console.print(Panel(
        "\n".join(summary_lines),
        title="✅ NLP Analysis Complete",
        border_style="green"
    ))
2197
@app.command()
def pipeline(
    file_path: str = typer.Argument(..., help="Path to dataset file (CSV or Parquet)"),
    target_col: str = typer.Option(None, "--target", "-y", help="Target column name (required)"),
    level: str = typer.Option("basic", "--level", "-l", help="Feature engineering level: basic, intermediate, advanced"),
    task_type: str = typer.Option("auto", "--task-type", help="Task type: auto, classification, regression"),
    output: str = typer.Option("./outputs/data/pipeline_output.csv", "--output", "-o", help="Output file path")
):
    """
    Run the full automated ML pipeline.

    Example:
        ds-agent pipeline titanic.csv --target Survived
        ds-agent pipeline sales.csv --target revenue --level advanced --task-type regression
    """
    from ds_agent.tools.auto_pipeline import auto_ml_pipeline

    # --- Input validation -------------------------------------------------
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not target_col:
        console.print("[red]✗ Error: --target (-y) is required[/red]")
        raise typer.Exit(1)

    level = level.strip().lower()
    valid_levels = {"basic", "intermediate", "advanced"}
    if level not in valid_levels:
        console.print("[red]✗ Error: --level must be one of basic, intermediate, advanced[/red]")
        raise typer.Exit(1)

    task_type = task_type.strip().lower()
    valid_task_types = {"auto", "classification", "regression"}
    if task_type not in valid_task_types:
        console.print("[red]✗ Error: --task-type must be one of auto, classification, regression[/red]")
        raise typer.Exit(1)

    output_path = str(Path(output))
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    overview_lines = [
        "1. Type detection",
        "2. Cleaning",
        "3. Outliers",
        "4. Encoding",
        "5. Feature engineering",
        "6. Feature selection",
        "",
        f"Dataset: {file_path}",
        f"Target: {target_col}",
        f"Feature engineering level: {level}",
        f"Task type: {task_type}",
    ]

    # Step 1: Pipeline Overview
    console.print(Panel(
        "\n".join(overview_lines),
        title="🔄 Pipeline Overview",
        border_style="blue"
    ))

    # Maps "Stage N" markers printed by auto_ml_pipeline to progress-bar
    # steps 1..6. Stage 4 is intentionally absent and stages 5-7 map to
    # steps 4-6 — presumably the tool's stage numbering skips one stage;
    # TODO confirm against auto_ml_pipeline's print statements.
    stage_to_progress = {
        "Stage 1": 1,
        "Stage 2": 2,
        "Stage 3": 3,
        "Stage 5": 4,
        "Stage 6": 5,
        "Stage 7": 6,
    }

    class _PipelineStream(io.TextIOBase):
        """Write-only text stream that turns the pipeline's stdout lines
        into progress-bar updates.

        Installed via contextlib.redirect_stdout so that any "Stage N"
        marker printed by auto_ml_pipeline advances the Rich progress task
        instead of cluttering the console.
        """

        def __init__(self, progress, task_id):
            self.progress = progress
            self.task_id = task_id
            self.buffer = ""          # holds a partial line between write() calls
            self.max_completed = 0    # progress is monotonic; never move backwards

        def write(self, text):
            # Accumulate text and process only complete lines; a stage
            # marker split across two write() calls is still detected.
            self.buffer += text
            while "\n" in self.buffer:
                line, self.buffer = self.buffer.split("\n", 1)
                line = line.strip()
                for key, step in stage_to_progress.items():
                    if key in line and step > self.max_completed:
                        self.max_completed = step
                        self.progress.update(self.task_id, completed=step)
                        break
            return len(text)

        def flush(self):
            # TextIOBase requires flush(); nothing is buffered downstream.
            return None

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            pipeline_task = progress.add_task("Running automated pipeline...", total=6)
            stream = _PipelineStream(progress, pipeline_task)

            # Capture the tool's prints for the duration of the run only.
            with contextlib.redirect_stdout(stream):
                result = auto_ml_pipeline(
                    file_path,
                    target_col,
                    task_type,
                    output_path,
                    level,
                )

            progress.update(pipeline_task, completed=6)

    except Exception as e:
        console.print(f"[red]✗ Pipeline failed: {e}[/red]")
        raise typer.Exit(1)

    # NOTE(review): result is assumed to be a dict; an error status returned
    # by auto_ml_pipeline (rather than raised) would render an empty table
    # here — verify the tool's error contract.
    original_shape = result.get("original_shape", {})
    final_shape = result.get("final_shape", {})
    original_cols = int(original_shape.get("columns", 0))
    final_cols = int(final_shape.get("columns", 0))
    # Clamp at 0: a net change in one direction shows 0 for the other.
    features_added = max(final_cols - original_cols, 0)
    features_removed = max(original_cols - final_cols, 0)
    transformations_applied = result.get("transformations_applied", [])

    # Step 4: Comparison table
    comparison_table = Table(title="📊 Pipeline Comparison", show_header=False)
    comparison_table.add_column("Metric", style="cyan", width=28)
    comparison_table.add_column("Value", style="white")

    comparison_table.add_row(
        "Original Shape",
        f"{original_shape.get('rows', '-')} rows × {original_shape.get('columns', '-')} cols"
    )
    comparison_table.add_row(
        "Final Shape",
        f"{final_shape.get('rows', '-')} rows × {final_shape.get('columns', '-')} cols"
    )
    comparison_table.add_row("Features Added", str(features_added))
    comparison_table.add_row("Features Removed", str(features_removed))
    comparison_table.add_row(
        "Transformations Applied",
        str(len(transformations_applied))
    )

    transformation_names = [t.get("stage", "Unknown") for t in transformations_applied]
    comparison_table.add_row(
        "Transformation Stages",
        ", ".join(transformation_names) if transformation_names else "-"
    )

    console.print()
    console.print(comparison_table)

    # Step 5: Top selected features
    selected_features = result.get("selected_features", [])
    feature_importance = result.get("feature_importance", {})

    if selected_features:
        feature_table = Table(title="🏆 Top Selected Features")
        feature_table.add_column("Rank", style="cyan", width=6)
        feature_table.add_column("Feature", style="white")
        feature_table.add_column("Importance", style="green", justify="right")

        if isinstance(feature_importance, dict) and feature_importance:
            # Pair each selected feature with its importance; retry with the
            # stringified name in case keys were serialized as strings.
            ranked = []
            for feature in selected_features:
                score = feature_importance.get(feature)
                if score is None:
                    score = feature_importance.get(str(feature))
                ranked.append((feature, score))

            # Features with no score sort last (-inf key).
            ranked.sort(key=lambda x: float(x[1]) if x[1] is not None else float("-inf"), reverse=True)
            top_features = ranked[:10]
        else:
            top_features = [(feature, None) for feature in selected_features[:10]]

        for idx, (feature, score) in enumerate(top_features, start=1):
            score_text = "-" if score is None else f"{float(score):.6f}"
            feature_table.add_row(str(idx), str(feature), score_text)

        console.print()
        console.print(feature_table)

    # Step 6: Output file path
    output_file = result.get("output_path", output_path)
    console.print(f"\n💾 Pipeline output saved to: [cyan]{output_file}[/cyan]")
2385
@app.command()
def tune(
    file_path: str = typer.Argument(..., help="Path to pre-processed dataset (CSV or Parquet)"),
    target_col: str = typer.Option(None, "--target", "-y", help="Target column name (required)"),
    model_type: str = typer.Option("xgboost", "--model", "-m", help="Model to tune: random_forest, xgboost, lightgbm"),
    trials: int = typer.Option(30, "--trials", "-n", help="Number of Optuna trials"),
    task_type: str = typer.Option("auto", "--task-type", help="Task type: auto, classification, regression"),
    output: str = typer.Option("./outputs/models/", "--output", "-o", help="Output directory or model file path")
):
    """
    Run Bayesian hyperparameter optimization with Optuna.

    Example:
        ds-agent tune cleaned.csv --target Survived --model xgboost --trials 50
    """
    # Heavy imports are deferred to command invocation so `--help` stays fast.
    from ds_agent.tools.advanced_training import hyperparameter_tuning
    from ds_agent.utils.polars_helpers import load_dataframe
    import ds_agent.tools.advanced_training as advanced_training_module
    import numpy as np
    from sklearn.model_selection import cross_val_score, StratifiedKFold, KFold
    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
    from xgboost import XGBClassifier, XGBRegressor

    # --- Input validation -------------------------------------------------
    if not Path(file_path).exists():
        console.print(f"[red]✗ Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if not target_col:
        console.print("[red]✗ Error: --target (-y) is required[/red]")
        raise typer.Exit(1)

    model_type = model_type.strip().lower()
    valid_models = {"random_forest", "xgboost", "lightgbm"}
    if model_type not in valid_models:
        console.print("[red]✗ Error: --model must be one of random_forest, xgboost, lightgbm[/red]")
        raise typer.Exit(1)

    task_type = task_type.strip().lower()
    valid_task_types = {"auto", "classification", "regression"}
    if task_type not in valid_task_types:
        console.print("[red]✗ Error: --task-type must be one of auto, classification, regression[/red]")
        raise typer.Exit(1)

    if trials <= 0:
        console.print("[red]✗ Error: --trials (-n) must be a positive integer[/red]")
        raise typer.Exit(1)

    # --- Rough runtime estimate (best-effort; only row count is used) -----
    try:
        df_for_estimate = load_dataframe(file_path)
        dataset_rows = len(df_for_estimate)
    except Exception:
        dataset_rows = 0  # estimate falls back to the trial count alone

    base_estimated_seconds = int(trials * 10)  # heuristic: ~10 s per trial
    size_factor = max(1.0, (dataset_rows / 50000.0)) if dataset_rows > 0 else 1.0
    estimated_seconds = int(base_estimated_seconds * size_factor)

    # --- Resolve output path: .pkl/.joblib means a file, otherwise a dir --
    output_path_obj = Path(output)
    if output_path_obj.suffix.lower() in {".pkl", ".joblib"}:
        model_output_path = output_path_obj
        model_output_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        output_path_obj.mkdir(parents=True, exist_ok=True)
        model_output_path = output_path_obj / f"tuned_{model_type}_{Path(file_path).stem}.pkl"

    # Upfront warning so the user can abort before the expensive run starts.
    console.print(Panel(
        f"⚠️ Hyperparameter tuning is compute-intensive.\n"
        f"Estimated time: {base_estimated_seconds} seconds. Press Ctrl+C to stop early.\n"
        f"Dataset-size adjusted estimate: {estimated_seconds} seconds.",
        title="Tuning Warning",
        border_style="yellow"
    ))

    def _compute_baseline_score() -> Optional[float]:
        """Cross-validate an untuned model to get a comparison baseline.

        Returns the mean CV score (accuracy for classification,
        negative RMSE for regression) or None when a baseline cannot
        be computed (missing target, no numeric features, any error).
        Best-effort by design: failures must not abort the tuning run.
        """
        try:
            df = load_dataframe(file_path).to_pandas()
            if target_col not in df.columns:
                return None

            X = df.drop(columns=[target_col])
            y = df[target_col]
            # Baseline uses numeric columns only, NaNs filled with 0.
            X = X.select_dtypes(include=[np.number]).fillna(0)
            if X.shape[1] == 0:
                return None

            inferred_task = task_type
            if inferred_task == "auto":
                # Heuristic: fewer than 20 distinct target values -> classification.
                inferred_task = "classification" if y.nunique() < 20 else "regression"

            if model_type == "random_forest":
                model = RandomForestClassifier(random_state=42) if inferred_task == "classification" else RandomForestRegressor(random_state=42)
            elif model_type == "xgboost":
                # NOTE(review): use_label_encoder is deprecated/removed in newer
                # xgboost releases — confirm against the pinned xgboost version.
                model = XGBClassifier(random_state=42, use_label_encoder=False, eval_metric="logloss") if inferred_task == "classification" else XGBRegressor(random_state=42)
            elif model_type == "lightgbm":
                from lightgbm import LGBMClassifier, LGBMRegressor
                model = LGBMClassifier(random_state=42, verbosity=-1) if inferred_task == "classification" else LGBMRegressor(random_state=42, verbosity=-1)
            else:
                return None

            if inferred_task == "classification":
                cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
                scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy", n_jobs=-1)
            else:
                cv = KFold(n_splits=5, shuffle=True, random_state=42)
                scores = cross_val_score(model, X, y, cv=cv, scoring="neg_root_mean_squared_error", n_jobs=-1)

            return float(scores.mean())
        except Exception:
            return None

    baseline_score = _compute_baseline_score()

    # --- Per-trial progress via a temporary monkeypatch -------------------
    # hyperparameter_tuning() drives Optuna internally, so we patch
    # Study.optimize to inject a callback that advances the progress bar
    # after each finished trial. The original method is restored in
    # `finally` below, whatever happens.
    trial_counter = {"count": 0}
    optuna_module = advanced_training_module.optuna
    original_optimize = optuna_module.study.study.Study.optimize

    def _patched_optimize(self, func, *args, **kwargs):
        def _trial_callback(study, trial):
            trial_counter["count"] += 1
            progress.update(
                progress_task,
                completed=min(trial_counter["count"], trials),
                description=f"Running Optuna trials ({trial_counter['count']} of {trials})..."
            )

        # Preserve any callbacks the callee already passed.
        callbacks = list(kwargs.get("callbacks") or [])
        callbacks.append(_trial_callback)
        kwargs["callbacks"] = callbacks
        # Optuna's own progress bar would fight with the rich Progress display.
        kwargs["show_progress_bar"] = False
        return original_optimize(self, func, *args, **kwargs)

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.completed:.0f}/{task.total:.0f}"),
            console=console
        ) as progress:
            # Fix: show the real trial total up front instead of "(0 of 0)".
            progress_task = progress.add_task(f"Running Optuna trials (0 of {trials})...", total=trials)

            optuna_module.study.study.Study.optimize = _patched_optimize
            tuning_result = hyperparameter_tuning(
                file_path,
                target_col,
                model_type,
                task_type,
                trials,
                output_path=str(model_output_path)
            )

            progress.update(
                progress_task,
                completed=trials,
                description=f"Running Optuna trials ({trial_counter['count']} of {trials})..."
            )
    except KeyboardInterrupt:
        console.print("[yellow]⚠ Tuning stopped by user (Ctrl+C).[/yellow]")
        raise typer.Exit(1)
    except Exception as e:
        console.print(f"[red]✗ Hyperparameter tuning failed: {e}[/red]")
        raise typer.Exit(1)
    finally:
        # Always undo the monkeypatch so other commands see stock Optuna.
        optuna_module.study.study.Study.optimize = original_optimize

    if tuning_result.get("status") != "success":
        console.print(f"[red]✗ Hyperparameter tuning failed: {tuning_result.get('message', 'Unknown error')}[/red]")
        raise typer.Exit(1)

    best_params = tuning_result.get("best_params", {})
    best_score = tuning_result.get("best_cv_score")

    # Relative improvement over the untuned baseline (same scoring metric).
    improvement_pct = None
    if baseline_score is not None and best_score is not None and baseline_score != 0:
        improvement_pct = ((float(best_score) - float(baseline_score)) / abs(float(baseline_score))) * 100.0

    # --- Results table ----------------------------------------------------
    summary_table = Table(title="🎯 Tuning Results", show_header=False)
    summary_table.add_column("Metric", style="cyan", width=28)
    summary_table.add_column("Value", style="white")

    summary_table.add_row("Best Parameters", json.dumps(best_params, indent=2) if best_params else "-")
    summary_table.add_row("Best Score", "-" if best_score is None else f"{float(best_score):.6f}")
    summary_table.add_row("Baseline Score", "-" if baseline_score is None else f"{float(baseline_score):.6f}")
    summary_table.add_row(
        "Improvement Percentage",
        "-" if improvement_pct is None else f"{improvement_pct:+.2f}%"
    )

    console.print()
    console.print(summary_table)

    # --- Saved model path -------------------------------------------------
    saved_model_path = tuning_result.get("model_path") or str(model_output_path)
    console.print(f"\n💾 Tuned model saved to: [cyan]{saved_model_path}[/cyan]")
2585
+
2586
@app.command()
def clean(
    file_path: str = typer.Argument(..., help="Path to dataset file"),
    output: str = typer.Option(None, "--output", "-o", help="Output file path"),
    strategy: str = typer.Option("auto", "--strategy", "-s", help="Cleaning strategy (auto/median/mean/mode/drop)")
):
    """
    Clean dataset (handle missing values and outliers).

    Example:
        python cli.py clean data.csv --output cleaned_data.csv
    """
    from ds_agent.tools.data_cleaning import clean_missing_values
    from ds_agent.tools.data_profiling import profile_dataset

    # Default output path mirrors the input filename under ./outputs/data/.
    if output is None:
        output = f"./outputs/data/cleaned_{Path(file_path).name}"

    console.print(f"\n🧹 [bold]Cleaning:[/bold] {file_path}\n")

    # Profile the dataset and map every null-bearing column to "auto" handling.
    dataset_profile = profile_dataset(file_path)
    null_strategy_map = {
        column: "auto"
        for column, column_info in dataset_profile["columns"].items()
        if column_info["null_count"] > 0
    }

    if not null_strategy_map:
        console.print("[green]✓ No missing values found - dataset is clean![/green]")
        return

    console.print(f"Found {len(null_strategy_map)} columns with missing values")

    # Run the cleaner behind a spinner; total=None keeps it indeterminate.
    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress:
        spinner_task = progress.add_task("Cleaning dataset...", total=None)
        cleaning_result = clean_missing_values(file_path, null_strategy_map, output)
        progress.update(spinner_task, completed=True)

    console.print(f"\n[green]✓ Cleaned dataset saved to: {output}[/green]")
    console.print(f" Rows: {cleaning_result['original_rows']} → {cleaning_result['final_rows']}")
2629
+
2630
@app.command()
def train(
    file_path: str = typer.Argument(..., help="Path to prepared dataset"),
    target: str = typer.Argument(..., help="Target column name"),
    task_type: str = typer.Option("auto", "--task-type", help="Task type (classification/regression/auto)")
):
    """
    Train baseline models on prepared dataset.

    Example:
        python cli.py train cleaned_data.csv Survived --task-type classification
    """
    from ds_agent.tools.model_training import train_baseline_models

    console.print(f"\n🤖 [bold]Training Models[/bold]\n")
    console.print(f"📊 Dataset: {file_path}")
    console.print(f"🎯 Target: {target}\n")

    # Train all baseline models behind an indeterminate spinner.
    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress:
        spinner_task = progress.add_task("Training baseline models...", total=None)
        result = train_baseline_models(file_path, target, task_type)
        progress.update(spinner_task, completed=True)

    if "error" in result:
        console.print(f"[red]✗ Error: {result['message']}[/red]")
        raise typer.Exit(1)

    # High-level run summary.
    console.print(f"\n[green]✓ Training Complete![/green]\n")
    console.print(f"Task Type: {result['task_type']}")
    console.print(f"Features: {result['n_features']}")
    console.print(f"Samples: {result['n_samples']}\n")

    # Column headers and metric keys depend on the resolved task type.
    is_classification = result["task_type"] == "classification"
    if is_classification:
        metric_columns = (("Accuracy", "accuracy"), ("F1 Score", "f1"))
    else:
        metric_columns = (("R² Score", "r2"), ("RMSE", "rmse"))

    comparison = Table(title="Model Performance")
    comparison.add_column("Model", style="cyan")
    for header, _ in metric_columns:
        comparison.add_column(header, justify="right")

    # One row per trained model that reported test-set metrics.
    for model_name, model_result in result["models"].items():
        if "test_metrics" not in model_result:
            continue
        metrics = model_result["test_metrics"]
        comparison.add_row(
            model_name,
            *(f"{metrics[key]:.4f}" for _, key in metric_columns)
        )

    console.print(comparison)

    # Winner announcement with score and artifact location.
    best = result["best_model"]
    console.print(f"\n🏆 [bold]Best Model:[/bold] {best['name']}")
    console.print(f" Score: {best['score']:.4f}")
    console.print(f" Path: {best['model_path']}")
2699
+
2700
@app.command()
def cache_stats():
    """Show cache statistics."""
    from ds_agent.orchestrator import DataScienceCopilot

    stats = DataScienceCopilot().get_cache_stats()

    stats_table = Table(title="Cache Statistics")
    stats_table.add_column("Metric", style="cyan")
    stats_table.add_column("Value", style="white")

    # Render each stat as a label/value pair.
    for label, value in (
        ("Total Entries", str(stats["total_entries"])),
        ("Valid Entries", str(stats["valid_entries"])),
        ("Expired Entries", str(stats["expired_entries"])),
        ("Size", f"{stats['size_mb']} MB"),
    ):
        stats_table.add_row(label, value)

    console.print()
    console.print(stats_table)
2720
+
2721
@app.command()
def clear_cache():
    """Clear all cached results."""
    from ds_agent.orchestrator import DataScienceCopilot

    # The copilot instance owns the cache; clearing goes through it.
    DataScienceCopilot().clear_cache()
    console.print("[green]✓ Cache cleared successfully[/green]")
2730
+
2731
+ @sessions_app.command("list")
2732
+ def sessions_list():
2733
+ """List saved sessions."""
2734
+ from ds_agent.session_store import SessionStore
2735
+
2736
+ store = SessionStore()
2737
+ list_fn = getattr(store, "list_sessions", None)
2738
+ load_fn = getattr(store, "load_session", None) or getattr(store, "load", None)
2739
+
2740
+ if not callable(list_fn):
2741
+ console.print("[red]✗ Session store does not provide list_sessions().[/red]")
2742
+ raise typer.Exit(1)
2743
+
2744
+ sessions = list_fn(limit=10000)
2745
+ if not sessions:
2746
+ console.print("[yellow]No saved sessions found.[/yellow]")
2747
+ return
2748
+
2749
+ table = Table(title="Saved Sessions")
2750
+ table.add_column("session_id", style="cyan")
2751
+ table.add_column("created_at", style="white")
2752
+ table.add_column("last_active", style="white")
2753
+ table.add_column("last_dataset", style="magenta")
2754
+ table.add_column("steps_completed", justify="right", style="green")
2755
+
2756
+ for item in sessions:
2757
+ sid = item.get("session_id", "-")
2758
+ created_at = item.get("created_at", "-")
2759
+ last_active = item.get("last_active", "-")
2760
+
2761
+ last_dataset = "-"
2762
+ steps_completed = 0
2763
+
2764
+ if callable(load_fn):
2765
+ try:
2766
+ session_obj = load_fn(sid)
2767
+ if session_obj is not None:
2768
+ last_dataset = getattr(session_obj, "last_dataset", None) or "-"
2769
+ workflow_history = getattr(session_obj, "workflow_history", []) or []
2770
+ steps_completed = len(workflow_history)
2771
+ except Exception:
2772
+ pass
2773
+
2774
+ table.add_row(
2775
+ str(sid),
2776
+ str(created_at),
2777
+ str(last_active),
2778
+ str(last_dataset),
2779
+ str(steps_completed)
2780
+ )
2781
+
2782
+ console.print()
2783
+ console.print(table)
2784
+
2785
+
2786
+ @sessions_app.command("resume")
2787
+ def sessions_resume(
2788
+ session_id: str = typer.Argument(..., help="Session ID to resume")
2789
+ ):
2790
+ """Load a saved session and print its context."""
2791
+ from ds_agent.session_store import SessionStore
2792
+
2793
+ store = SessionStore()
2794
+ load_fn = getattr(store, "load_session", None) or getattr(store, "load", None)
2795
+
2796
+ if not callable(load_fn):
2797
+ console.print("[red]✗ Session store does not provide load_session().[/red]")
2798
+ raise typer.Exit(1)
2799
+
2800
+ session = load_fn(session_id)
2801
+ if session is None:
2802
+ console.print(f"[red]✗ Session not found: {session_id}[/red]")
2803
+ raise typer.Exit(1)
2804
+
2805
+ context_text = ""
2806
+ get_context_summary = getattr(session, "get_context_summary", None)
2807
+ if callable(get_context_summary):
2808
+ context_text = get_context_summary()
2809
+
2810
+ if not context_text:
2811
+ context_lines = [
2812
+ f"Session ID: {getattr(session, 'session_id', session_id)}",
2813
+ f"Created: {getattr(session, 'created_at', '-')}",
2814
+ f"Last Active: {getattr(session, 'last_active', '-')}",
2815
+ f"Last Dataset: {getattr(session, 'last_dataset', '-')}",
2816
+ f"Last Target: {getattr(session, 'last_target_col', '-')}",
2817
+ f"Last Model: {getattr(session, 'last_model', '-')}",
2818
+ f"Steps Completed: {len(getattr(session, 'workflow_history', []) or [])}",
2819
+ ]
2820
+ context_text = "\n".join(context_lines)
2821
+
2822
+ console.print(Panel(
2823
+ context_text,
2824
+ title=f"Session Context: {session_id}",
2825
+ border_style="green"
2826
+ ))
2827
+
2828
+
2829
+ @sessions_app.command("delete")
2830
+ def sessions_delete(
2831
+ session_id: str = typer.Argument(..., help="Session ID to delete")
2832
+ ):
2833
+ """Delete one saved session from SQLite."""
2834
+ from ds_agent.session_store import SessionStore
2835
+
2836
+ store = SessionStore()
2837
+ delete_fn = getattr(store, "delete_session", None) or getattr(store, "delete", None)
2838
+
2839
+ if not callable(delete_fn):
2840
+ console.print("[red]✗ Session store does not provide delete_session().[/red]")
2841
+ raise typer.Exit(1)
2842
+
2843
+ deleted = bool(delete_fn(session_id))
2844
+ if deleted:
2845
+ console.print(f"[green]✓ Deleted session: {session_id}[/green]")
2846
+ else:
2847
+ console.print(f"[yellow]Session not found: {session_id}[/yellow]")
2848
+
2849
+
2850
+ @sessions_app.command("clear")
2851
+ def sessions_clear():
2852
+ """Delete all saved sessions."""
2853
+ from ds_agent.session_store import SessionStore
2854
+
2855
+ store = SessionStore()
2856
+ list_fn = getattr(store, "list_sessions", None)
2857
+ delete_fn = getattr(store, "delete_session", None) or getattr(store, "delete", None)
2858
+
2859
+ if not callable(list_fn):
2860
+ console.print("[red]✗ Session store does not provide list_sessions().[/red]")
2861
+ raise typer.Exit(1)
2862
+ if not callable(delete_fn):
2863
+ console.print("[red]✗ Session store does not provide delete_session().[/red]")
2864
+ raise typer.Exit(1)
2865
+
2866
+ sessions = list_fn(limit=10000)
2867
+ if not sessions:
2868
+ console.print("[yellow]No sessions to clear.[/yellow]")
2869
+ return
2870
+
2871
+ deleted_count = 0
2872
+ for item in sessions:
2873
+ sid = item.get("session_id")
2874
+ if sid and delete_fn(sid):
2875
+ deleted_count += 1
2876
+
2877
+ console.print(f"[green]✓ Cleared {deleted_count} session(s).[/green]")
2878
+
2879
+
2880
def main():
    """Console entry point for ds-agent.

    Referenced by the package's console-script configuration; simply
    dispatches to the Typer application, which parses argv and runs
    the selected command.
    """
    app()
2884
+
2885
+ if __name__ == "__main__":
2886
+ main()