ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,497 @@
1
+ """
2
+ Plotly Interactive Visualization Tools
3
+ Create interactive, web-based visualizations that can be explored in browsers.
4
+ """
5
+
6
+ import polars as pl
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
+ import numpy as np
11
+ from typing import Dict, Any, List, Optional
12
+ from pathlib import Path
13
+ import sys
14
+ import os
15
+
16
+ # Add parent directory to path for imports
17
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ from ds_agent.utils.polars_helpers import (
20
+ load_dataframe,
21
+ get_numeric_columns,
22
+ get_categorical_columns,
23
+ )
24
+ from ds_agent.utils.validation import (
25
+ validate_file_exists,
26
+ validate_file_format,
27
+ validate_dataframe,
28
+ validate_column_exists,
29
+ )
30
+
31
+
32
def generate_interactive_scatter(
    file_path: str,
    x_col: str,
    y_col: str,
    color_col: Optional[str] = None,
    size_col: Optional[str] = None,
    output_path: str = "./outputs/plots/interactive/scatter.html"
) -> Dict[str, Any]:
    """
    Render an interactive Plotly scatter plot of ``y_col`` against ``x_col``
    and write it to a standalone HTML file.

    Args:
        file_path: Dataset to load via ``load_dataframe``.
        x_col: Column mapped to the horizontal axis.
        y_col: Column mapped to the vertical axis.
        color_col: Optional column used to colour the markers.
        size_col: Optional column used to scale the markers (bubble chart).
        output_path: Destination HTML file; missing parent directories are
            created.

    Returns:
        Dict with ``status``, ``plot_type``, ``output_path``, the column
        arguments, and ``num_points`` (row count of the loaded dataset).
    """
    validate_file_exists(file_path)
    validate_file_format(file_path)

    frame = load_dataframe(file_path)
    validate_dataframe(frame)

    # Required axes first, then whichever optional encodings were supplied.
    for required in (x_col, y_col):
        validate_column_exists(frame, required)
    for optional in (color_col, size_col):
        if optional:
            validate_column_exists(frame, optional)

    # plotly.express works on pandas, so convert once up front.
    pandas_frame = frame.to_pandas()

    figure = px.scatter(
        pandas_frame,
        x=x_col,
        y=y_col,
        color=color_col,
        size=size_col,
        # Every column is exposed in the hover tooltip.
        hover_data=list(pandas_frame.columns),
        title=f"Interactive Scatter: {y_col} vs {x_col}",
        template="plotly_white",
    )
    figure.update_layout(hovermode='closest', height=600, font=dict(size=12))

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)

    return {
        "status": "success",
        "plot_type": "interactive_scatter",
        "output_path": output_path,
        "x_col": x_col,
        "y_col": y_col,
        "color_col": color_col,
        "size_col": size_col,
        "num_points": len(frame),
    }
105
+
106
+
107
def generate_interactive_histogram(
    file_path: str,
    column: str,
    bins: int = 30,
    color_col: Optional[str] = None,
    output_path: str = "./outputs/plots/interactive/histogram.html"
) -> Dict[str, Any]:
    """
    Plot the distribution of a single column as an interactive Plotly
    histogram (with a marginal box plot) and save it as HTML.

    Args:
        file_path: Dataset to load via ``load_dataframe``.
        column: Column whose distribution is plotted.
        bins: Requested number of bins (passed to Plotly as ``nbins``).
        color_col: Optional column used to split the bars into coloured
            groups.
        output_path: Destination HTML file; missing parent directories are
            created.

    Returns:
        Dict describing the saved plot: ``status``, ``plot_type``,
        ``output_path``, ``column``, ``bins`` and ``color_col``.
    """
    validate_file_exists(file_path)
    validate_file_format(file_path)

    data = load_dataframe(file_path)
    validate_dataframe(data)
    validate_column_exists(data, column)
    if color_col:
        validate_column_exists(data, color_col)

    pandas_data = data.to_pandas()

    figure = px.histogram(
        pandas_data,
        x=column,
        nbins=bins,
        color=color_col,
        title=f"Distribution of {column}",
        template="plotly_white",
        marginal="box",  # summary box plot drawn above the bars
    )
    # A legend only carries information when bars are split by color_col.
    figure.update_layout(bargap=0.1, height=600, showlegend=bool(color_col))

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    figure.write_html(output_path)

    summary: Dict[str, Any] = {
        "status": "success",
        "plot_type": "interactive_histogram",
        "output_path": output_path,
    }
    summary.update(column=column, bins=bins, color_col=color_col)
    return summary
169
+
170
+
171
def generate_interactive_correlation_heatmap(
    file_path: str,
    output_path: Optional[str] = None
) -> Dict[str, Any]:
    """
    Create an interactive Plotly heatmap of the pairwise correlations between
    all numeric columns of a dataset and save it as HTML.

    Args:
        file_path: Dataset to load via ``load_dataframe``.
        output_path: Destination HTML file. When ``None``, defaults to
            ``$DS_AGENT_OUTPUT_DIR/plots/interactive/correlation_heatmap.html``
            (``DS_AGENT_OUTPUT_DIR`` falls back to ``./outputs``).

    Returns:
        On success: ``status``, ``plot_type``, ``output_path`` and
        ``num_features`` (number of numeric columns correlated).
        If fewer than two numeric columns exist: ``status="error"`` and a
        ``message``; nothing is written.
    """
    # Resolve the default location lazily so the environment variable is read
    # at call time rather than at import time.
    if output_path is None:
        output_base = os.getenv("DS_AGENT_OUTPUT_DIR", "./outputs")
        output_path = f"{output_base}/plots/interactive/correlation_heatmap.html"

    validate_file_exists(file_path)
    validate_file_format(file_path)

    df = load_dataframe(file_path)
    validate_dataframe(df)

    numeric_cols = get_numeric_columns(df)

    # A correlation matrix needs at least one pair of columns.
    if len(numeric_cols) < 2:
        return {
            "status": "error",
            "message": "Need at least 2 numeric columns for correlation"
        }

    # pandas .corr() computes the (Pearson, by default) correlation matrix.
    corr_matrix = df.select(numeric_cols).to_pandas().corr()

    fig = go.Figure(data=go.Heatmap(
        z=corr_matrix.values,
        x=corr_matrix.columns,
        y=corr_matrix.columns,
        colorscale='RdBu',
        zmid=0,  # centre the diverging colour scale on zero correlation
        text=np.round(corr_matrix.values, 2),
        texttemplate='%{text}',
        textfont={"size": 10},
        colorbar=dict(title="Correlation")
    ))

    # Grow the figure with the number of features so cell labels stay legible.
    side = max(600, len(numeric_cols) * 30)
    fig.update_layout(
        title="Interactive Correlation Heatmap",
        template="plotly_white",
        height=side,
        width=side,
        xaxis={'side': 'bottom'},
        yaxis={'side': 'left'}
    )

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(output_path)

    return {
        "status": "success",
        "plot_type": "interactive_correlation_heatmap",
        "output_path": output_path,
        "num_features": len(numeric_cols)
    }
242
+
243
+
244
def generate_interactive_box_plots(
    file_path: str,
    columns: Optional[List[str]] = None,
    group_by: Optional[str] = None,
    output_path: str = "./outputs/plots/interactive/box_plots.html"
) -> Dict[str, Any]:
    """
    Create a grid of interactive box plots (three per row) for outlier
    inspection and save it as HTML.

    Args:
        file_path: Dataset to load via ``load_dataframe``.
        columns: Columns to plot. When ``None``, all numeric columns are
            used; when given, each is validated to exist.
        group_by: Optional column; if set, each subplot shows one box per
            group value. Rows with a missing group value form their own
            ``"nan"`` group.
        output_path: Destination HTML file; missing parent directories are
            created.

    Returns:
        On success: ``status``, ``plot_type``, ``output_path``,
        ``columns_plotted`` and ``group_by``. If there is nothing to plot:
        ``status="error"`` and a ``message``.
    """
    validate_file_exists(file_path)
    validate_file_format(file_path)

    df = load_dataframe(file_path)
    validate_dataframe(df)

    if columns is None:
        columns = get_numeric_columns(df)
    else:
        for col in columns:
            validate_column_exists(df, col)

    if len(columns) == 0:
        return {
            "status": "error",
            "message": "No numeric columns to plot"
        }

    if group_by:
        validate_column_exists(df, group_by)

    df_pd = df.to_pandas()

    plots_per_row = 3
    rows = (len(columns) + plots_per_row - 1) // plots_per_row
    cols = min(plots_per_row, len(columns))

    # vertical_spacing is a fraction of the *whole* figure height, and
    # make_subplots raises ValueError if it exceeds 1 / (rows - 1). A fixed 0.1
    # therefore crashes at 12+ rows and leaves near-zero-height plots well
    # before that. 0.2 / rows keeps a roughly constant gap (the figure grows
    # 400 px per row), never exceeds the limit, and is capped at the old 0.1.
    vertical_spacing = min(0.1, 0.2 / rows)

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=columns,
        vertical_spacing=vertical_spacing
    )

    # Split once, not once per column. sort=False keeps first-appearance order
    # (as .unique() did); dropna=False keeps rows with a missing group value
    # (an `== nan` mask matches nothing); observed=True avoids empty boxes for
    # unused categories of a categorical column.
    groups = []
    if group_by:
        groups = [
            (str(key), sub)
            for key, sub in df_pd.groupby(
                group_by, sort=False, dropna=False, observed=True
            )
        ]
    # Fixed colour per group so every subplot matches the (single) legend.
    palette = px.colors.qualitative.Plotly

    for idx, col in enumerate(columns):
        row = idx // plots_per_row + 1
        col_idx = idx % plots_per_row + 1

        if group_by:
            for group_idx, (label, sub) in enumerate(groups):
                fig.add_trace(
                    go.Box(
                        y=sub[col],
                        name=label,
                        # Linking by legendgroup makes a legend click toggle
                        # the group in every subplot, not only the first.
                        legendgroup=label,
                        marker_color=palette[group_idx % len(palette)],
                        showlegend=(idx == 0),
                    ),
                    row=row,
                    col=col_idx
                )
        else:
            fig.add_trace(
                go.Box(y=df_pd[col], name=col, showlegend=False),
                row=row,
                col=col_idx
            )

    fig.update_layout(
        title="Interactive Box Plots - Outlier Detection",
        template="plotly_white",
        height=400 * rows,
        showlegend=bool(group_by)
    )

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(output_path)

    return {
        "status": "success",
        "plot_type": "interactive_box_plots",
        "output_path": output_path,
        "columns_plotted": columns,
        "group_by": group_by
    }
335
+
336
+
337
def generate_interactive_time_series(
    file_path: str,
    time_col: str,
    value_cols: List[str],
    output_path: str = "./outputs/plots/interactive/time_series.html"
) -> Dict[str, Any]:
    """
    Create an interactive multi-line time series plot (with range slider)
    and save it as HTML.

    Args:
        file_path: Dataset to load via ``load_dataframe``.
        time_col: Column holding the time axis. If it is a string column, it
            is parsed to ``pl.Datetime`` non-strictly (unparseable values
            become null).
        value_cols: Columns drawn as one line each over time.
        output_path: Destination HTML file; missing parent directories are
            created.

    Returns:
        On success: ``status``, ``plot_type``, ``output_path``, ``time_col``
        and ``value_cols``. If ``value_cols`` is empty: ``status="error"``
        and a ``message``; nothing is written.
    """
    validate_file_exists(file_path)
    validate_file_format(file_path)

    df = load_dataframe(file_path)
    validate_dataframe(df)
    validate_column_exists(df, time_col)

    for col in value_cols:
        validate_column_exists(df, col)

    # Mirror the other generators: report "nothing to plot" instead of writing
    # an empty chart and claiming success.
    if not value_cols:
        return {
            "status": "error",
            "message": "No value columns to plot"
        }

    # Parse string timestamps so Plotly gets a real date axis.
    if df[time_col].dtype == pl.Utf8:
        df = df.with_columns(
            pl.col(time_col).str.strptime(pl.Datetime, strict=False).alias(time_col)
        )

    # Lines connect points in row order, so unsorted input draws a zig-zag
    # back and forth across the chart. Order by time first; null (unparseable)
    # timestamps go last so they don't break the series at its start.
    df = df.sort(time_col, nulls_last=True)

    df_pd = df.to_pandas()

    fig = go.Figure()

    for col in value_cols:
        fig.add_trace(go.Scatter(
            x=df_pd[time_col],
            y=df_pd[col],
            mode='lines+markers',
            name=col,
            hovertemplate=f'<b>{col}</b><br>Time: %{{x}}<br>Value: %{{y:.2f}}<extra></extra>'
        ))

    fig.update_layout(
        title="Interactive Time Series",
        xaxis_title=time_col,
        yaxis_title="Value",
        template="plotly_white",
        height=600,
        hovermode='x unified',
        # Horizontal legend above the plot area, right-aligned.
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    fig.update_xaxes(rangeslider_visible=True)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(output_path)

    return {
        "status": "success",
        "plot_type": "interactive_time_series",
        "output_path": output_path,
        "time_col": time_col,
        "value_cols": value_cols
    }
416
+
417
+
418
def _safe_filename_part(name: Any) -> str:
    """Return ``str(name)`` with every character other than alphanumerics, '-', '_' and '.' replaced by '_'."""
    # Column names come from the dataset itself. Embedded raw in an output
    # path, a name like "../../x" would write outside output_dir, and
    # characters such as ':' or '?' are illegal in file names on some
    # platforms.
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(name)
    )
    return cleaned or "_"


def generate_plotly_dashboard(
    file_path: str,
    target_col: Optional[str] = None,
    output_dir: str = "./outputs/plots/interactive"
) -> Dict[str, Any]:
    """
    Generate a set of interactive HTML plots for a dataset.

    Produces, when applicable:
      1. a correlation heatmap (needs >= 2 numeric columns);
      2. box plots for up to the first 10 numeric columns;
      3. if ``target_col`` is numeric, scatter plots of up to 5 other numeric
         columns against it;
      4. histograms for up to the first 5 numeric columns.

    Column names are sanitised before being used in file names. Each
    sub-generator is called with ``file_path`` and loads the dataset again on
    its own.

    Args:
        file_path: Dataset to load via ``load_dataframe``.
        target_col: Optional target column for supervised analysis.
        output_dir: Directory that receives all HTML files.

    Returns:
        Dict with ``status``, ``plots_generated`` (count), ``plots`` (the
        result dicts of the successful sub-plots) and ``output_dir``.
    """
    validate_file_exists(file_path)
    validate_file_format(file_path)

    df = load_dataframe(file_path)
    validate_dataframe(df)

    if target_col:
        validate_column_exists(df, target_col)

    numeric_cols = get_numeric_columns(df)

    plots_generated = []

    # 1. Correlation heatmap
    if len(numeric_cols) >= 2:
        result = generate_interactive_correlation_heatmap(
            file_path,
            output_path=f"{output_dir}/correlation_heatmap.html"
        )
        if result["status"] == "success":
            plots_generated.append(result)

    # 2. Box plots for outliers
    if numeric_cols:
        result = generate_interactive_box_plots(
            file_path,
            columns=numeric_cols[:10],  # limit to 10 for performance
            output_path=f"{output_dir}/box_plots.html"
        )
        if result["status"] == "success":
            plots_generated.append(result)

    # 3. Target variable analysis (numeric targets only)
    if target_col and target_col in numeric_cols:
        # Exclude the target before taking the first 5, so it never displaces
        # a real feature (plotting the target against itself is useless).
        feature_cols = [c for c in numeric_cols if c != target_col][:5]
        target_tag = _safe_filename_part(target_col)
        for col in feature_cols:
            result = generate_interactive_scatter(
                file_path,
                x_col=col,
                y_col=target_col,
                output_path=(
                    f"{output_dir}/scatter_{_safe_filename_part(col)}"
                    f"_vs_{target_tag}.html"
                )
            )
            if result["status"] == "success":
                plots_generated.append(result)

    # 4. Distribution plots for the first 5 numeric columns
    for col in numeric_cols[:5]:
        result = generate_interactive_histogram(
            file_path,
            column=col,
            output_path=f"{output_dir}/histogram_{_safe_filename_part(col)}.html"
        )
        if result["status"] == "success":
            plots_generated.append(result)

    return {
        "status": "success",
        "plots_generated": len(plots_generated),
        "plots": plots_generated,
        "output_dir": output_dir
    }