ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,402 @@
1
+ """
2
+ Parallel Tool Execution with Dependency Detection
3
+
4
+ Enables concurrent execution of independent tools while respecting
5
+ dependencies and avoiding overwhelming system resources.
6
+ """
7
+
8
+ import asyncio
9
+ from typing import Dict, List, Any, Set, Optional, Tuple, Callable
10
+ from dataclasses import dataclass
11
+ from enum import Enum
12
+ import time
13
+
14
+
15
class ToolWeight(Enum):
    """Relative resource cost of running a tool.

    The parallel executor uses this to cap how many tools of each class
    may run concurrently.
    """
    LIGHT = 1   # sub-second operations: profiling, validation
    MEDIUM = 2  # moderate operations (1-10s): cleaning, encoding
    HEAVY = 3   # expensive operations (> 10s): ML training, large viz
20
+
21
+
22
# Tool weight classification
# Light tools (can run many in parallel)
_LIGHT_TOOL_NAMES = (
    "profile_dataset",
    "detect_data_quality_issues",
    "analyze_correlations",
    "get_smart_summary",
    "smart_type_inference",
)

# Medium tools (limit 2-3 concurrent)
_MEDIUM_TOOL_NAMES = (
    "clean_missing_values",
    "handle_outliers",
    "encode_categorical",
    "create_time_features",
    "create_interaction_features",
    "create_ratio_features",
    "create_statistical_features",
    "generate_interactive_scatter",
    "generate_interactive_histogram",
    "generate_interactive_box_plots",
    "generate_interactive_correlation_heatmap",
)

# Heavy tools (limit 1 concurrent) - NEVER RUN MULTIPLE HEAVY TOOLS IN PARALLEL
_HEAVY_TOOL_NAMES = (
    "train_baseline_models",
    "hyperparameter_tuning",
    "perform_cross_validation",
    "train_ensemble_models",
    "auto_ml_pipeline",
    "generate_ydata_profiling_report",
    "generate_combined_eda_report",
    "generate_plotly_dashboard",
    "execute_python_code",        # Unknown code complexity
    "auto_feature_engineering",   # ML-based feature generation
)

# Tool name -> weight. Tools not listed here default to MEDIUM when
# classified by the executor.
TOOL_WEIGHTS = {
    **dict.fromkeys(_LIGHT_TOOL_NAMES, ToolWeight.LIGHT),
    **dict.fromkeys(_MEDIUM_TOOL_NAMES, ToolWeight.MEDIUM),
    **dict.fromkeys(_HEAVY_TOOL_NAMES, ToolWeight.HEAVY),
}
56
+
57
+
58
@dataclass
class ToolExecution:
    """Represents a tool execution task."""
    tool_name: str                  # name of the tool to invoke
    arguments: Dict[str, Any]       # keyword arguments passed to the tool
    weight: ToolWeight              # resource class used for concurrency limits
    dependencies: Set[str]  # Other tool names that must complete first
    execution_id: str               # unique id, e.g. "<tool_name>_<index>"

    def __hash__(self):
        # Hash only on the unique execution id so instances can be used in
        # sets/dicts. NOTE(review): the dataclass-generated __eq__ compares
        # all fields, so eq and hash are intentionally asymmetric here —
        # acceptable because execution_id is unique per task.
        return hash(self.execution_id)
69
+
70
+
71
class ToolDependencyGraph:
    """
    Analyzes tool dependencies based on input/output files.

    Detects dependencies like:
    - clean_missing_values → encode_categorical (same file transformation)
    - profile_dataset → train_baseline_models (uses profiling results)
    - Multiple visualizations (can run in parallel)
    """

    # Heuristic groups: training tools should wait for prep tools that
    # operate on the same input file.
    _TRAINING_TOOLS = ("train_baseline_models", "hyperparameter_tuning", "train_ensemble_models")
    _PREP_TOOLS = ("profile_dataset", "clean_missing_values", "encode_categorical")

    def __init__(self):
        # Retained for API compatibility; not populated by current methods.
        self.graph: Dict[str, Set[str]] = {}

    def detect_dependencies(self, executions: List["ToolExecution"]) -> Dict[str, Set[str]]:
        """
        Detect dependencies between tool executions.

        Rules:
        1. If tool B reads output of tool A → B depends on A
        2. If tools read/write same file → sequential execution
        3. If tools are independent (different files/ops) → parallel

        Args:
            executions: List of tool executions

        Returns:
            Dict mapping execution_id → set of execution_ids it depends on
        """
        dependencies: Dict[str, Set[str]] = {ex.execution_id: set() for ex in executions}

        # Build file I/O map
        file_producers: Dict[str, str] = {}        # file_path → execution_id
        file_consumers: Dict[str, List[str]] = {}  # file_path → [execution_ids]

        for ex in executions:
            # Input files (readers)
            input_file = ex.arguments.get("file_path")
            if input_file:
                file_consumers.setdefault(input_file, []).append(ex.execution_id)

            # Output files (writers)
            output_file = ex.arguments.get("output_path") or ex.arguments.get("output_file")
            if output_file:
                file_producers[output_file] = ex.execution_id

        # Consumers of a file depend on its producer.
        for output_file, producer_id in file_producers.items():
            for consumer_id in file_consumers.get(output_file, []):
                if consumer_id != producer_id:
                    dependencies[consumer_id].add(producer_id)

        # Special rule: training tools depend on profiling/cleaning tools that
        # touch the same file.
        # BUGFIX: only link when the training tool actually specifies a
        # file_path; previously two tools with no file_path argument matched
        # spuriously on None == None.
        training_execs = [ex for ex in executions if ex.tool_name in self._TRAINING_TOOLS]
        prep_execs = [ex for ex in executions if ex.tool_name in self._PREP_TOOLS]

        for train_ex in training_execs:
            train_file = train_ex.arguments.get("file_path")
            if train_file is None:
                continue
            for prep_ex in prep_execs:
                if prep_ex.arguments.get("file_path") == train_file:
                    dependencies[train_ex.execution_id].add(prep_ex.execution_id)

        return dependencies

    def get_execution_batches(self, executions: List["ToolExecution"]) -> List[List["ToolExecution"]]:
        """
        Group executions into batches that can run in parallel.

        Returns:
            List of batches, where each batch contains independent tools.
            Batches must run in order; tools within a batch may run together.
        """
        dependencies = self.detect_dependencies(executions)

        # Kahn-style layered topological sort.
        batches: List[List["ToolExecution"]] = []
        completed: Set[str] = set()
        remaining = {ex.execution_id: ex for ex in executions}

        while remaining:
            # Everything whose dependencies are already satisfied is ready.
            ready = [
                ex for exec_id, ex in remaining.items()
                if dependencies[exec_id].issubset(completed)
            ]

            if not ready:
                # Circular dependency or error - run whatever is left as one
                # final batch rather than deadlocking.
                print("⚠️ Warning: Possible circular dependency detected")
                batches.append(list(remaining.values()))
                break

            batches.append(ready)

            for ex in ready:
                completed.add(ex.execution_id)
                del remaining[ex.execution_id]

        return batches
177
+
178
+
179
class ParallelToolExecutor:
    """
    Executes tools in parallel while respecting dependencies and resource limits.

    Features:
    - Automatic dependency detection
    - Weight-based resource management (limit heavy tools)
    - Progress reporting for parallel executions
    - Error isolation (one tool failure doesn't crash others)
    """

    def __init__(self, max_heavy_concurrent: int = 1, max_medium_concurrent: int = 2,
                 max_light_concurrent: int = 5):
        """
        Initialize parallel executor.

        Args:
            max_heavy_concurrent: Max heavy tools running simultaneously
            max_medium_concurrent: Max medium tools running simultaneously
            max_light_concurrent: Max light tools running simultaneously
        """
        self.max_heavy = max_heavy_concurrent
        self.max_medium = max_medium_concurrent
        self.max_light = max_light_concurrent

        # Semaphores for resource control, one per weight class.
        # NOTE(review): created outside a running event loop — fine on
        # Python >= 3.10 where asyncio primitives no longer bind a loop at
        # construction time; confirm minimum supported Python version.
        self.heavy_semaphore = asyncio.Semaphore(max_heavy_concurrent)
        self.medium_semaphore = asyncio.Semaphore(max_medium_concurrent)
        self.light_semaphore = asyncio.Semaphore(max_light_concurrent)

        self.dependency_graph = ToolDependencyGraph()

        # FIX: was an f-string with no placeholders.
        print("⚡ Parallel Executor initialized:")
        print(f" Heavy tools: {max_heavy_concurrent} concurrent")
        print(f" Medium tools: {max_medium_concurrent} concurrent")
        print(f" Light tools: {max_light_concurrent} concurrent")

    def _get_semaphore(self, weight: "ToolWeight") -> asyncio.Semaphore:
        """Get appropriate semaphore for tool weight."""
        if weight == ToolWeight.HEAVY:
            return self.heavy_semaphore
        elif weight == ToolWeight.MEDIUM:
            return self.medium_semaphore
        else:
            return self.light_semaphore

    async def _execute_single(self, execution: "ToolExecution",
                              execute_func: Callable,
                              progress_callback: Optional[Callable] = None) -> Dict[str, Any]:
        """
        Execute a single tool with resource management.

        Args:
            execution: Tool execution details
            execute_func: Function to execute tool (sync)
            progress_callback: Optional callback for progress updates

        Returns:
            Execution result dict with keys: execution_id, tool_name,
            success, duration, and either "result" (success) or "error".
        """
        semaphore = self._get_semaphore(execution.weight)

        async with semaphore:
            if progress_callback:
                await progress_callback(f"⚡ Executing {execution.tool_name}", "start")

            start_time = time.time()

            try:
                # Run sync function in the default thread-pool executor so it
                # does not block the event loop.
                # BUGFIX: use get_running_loop() — asyncio.get_event_loop()
                # is deprecated inside coroutines since Python 3.10.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    execute_func,
                    execution.tool_name,
                    execution.arguments
                )

                duration = time.time() - start_time

                if progress_callback:
                    await progress_callback(
                        f"✅ {execution.tool_name} completed ({duration:.1f}s)",
                        "complete"
                    )

                return {
                    "execution_id": execution.execution_id,
                    "tool_name": execution.tool_name,
                    "success": True,
                    "result": result,
                    "duration": duration
                }

            except Exception as e:
                # Error isolation: report the failure instead of propagating
                # so sibling tools in the same batch keep running.
                duration = time.time() - start_time

                if progress_callback:
                    await progress_callback(
                        f"❌ {execution.tool_name} failed: {str(e)[:100]}",
                        "error"
                    )

                return {
                    "execution_id": execution.execution_id,
                    "tool_name": execution.tool_name,
                    "success": False,
                    "error": str(e),
                    "duration": duration
                }

    async def execute_batch(self, batch: List["ToolExecution"],
                            execute_func: Callable,
                            progress_callback: Optional[Callable] = None) -> List[Dict[str, Any]]:
        """
        Execute a batch of independent tools in parallel.

        Args:
            batch: List of tool executions (no dependencies between them)
            execute_func: Sync function to execute tools
            progress_callback: Optional progress callback

        Returns:
            List of execution results
        """
        print(f"⚡ Parallel batch: {len(batch)} tools")
        for ex in batch:
            print(f" - {ex.tool_name} ({ex.weight.name})")

        # Execute all in parallel
        tasks = [
            self._execute_single(ex, execute_func, progress_callback)
            for ex in batch
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Defensive: _execute_single already catches exceptions, but convert
        # anything that still escaped into an error record.
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append({
                    "execution_id": batch[i].execution_id,
                    "tool_name": batch[i].tool_name,
                    "success": False,
                    "error": str(result)
                })
            else:
                processed_results.append(result)

        return processed_results

    async def execute_all(self, executions: List["ToolExecution"],
                          execute_func: Callable,
                          progress_callback: Optional[Callable] = None) -> List[Dict[str, Any]]:
        """
        Execute all tools with automatic dependency resolution and parallelization.

        Args:
            executions: List of all tool executions
            execute_func: Sync function to execute tools
            progress_callback: Optional progress callback

        Returns:
            List of all execution results in order
        """
        if not executions:
            return []

        # Get execution batches (respecting dependencies)
        batches = self.dependency_graph.get_execution_batches(executions)

        print(f"⚡ Execution plan: {len(batches)} batches for {len(executions)} tools")

        all_results = []

        for i, batch in enumerate(batches):
            print(f"\n📦 Batch {i+1}/{len(batches)}")
            batch_results = await self.execute_batch(batch, execute_func, progress_callback)
            all_results.extend(batch_results)

        return all_results

    def classify_tools(self, tool_calls: List[Dict[str, Any]]) -> List["ToolExecution"]:
        """
        Convert tool calls to ToolExecution objects with weights.

        Args:
            tool_calls: List of tool calls from LLM; each entry is expected to
                carry "name" (or "tool_name") and optionally "arguments".

        Returns:
            List of ToolExecution objects
        """
        executions = []

        for i, call in enumerate(tool_calls):
            tool_name = call.get("name") or call.get("tool_name")
            arguments = call.get("arguments", {})

            # Unknown tools default to MEDIUM weight.
            weight = TOOL_WEIGHTS.get(tool_name, ToolWeight.MEDIUM)

            execution = ToolExecution(
                tool_name=tool_name,
                arguments=arguments,
                weight=weight,
                dependencies=set(),  # Will be computed by dependency graph
                execution_id=f"{tool_name}_{i}"
            )

            executions.append(execution)

        return executions
392
+
393
+
394
# Global parallel executor (lazily created process-wide singleton)
_parallel_executor = None

def get_parallel_executor() -> ParallelToolExecutor:
    """Get or create global parallel executor.

    Returns:
        The shared ParallelToolExecutor, constructed with default
        concurrency limits on first call.
    """
    global _parallel_executor
    if _parallel_executor is None:
        _parallel_executor = ParallelToolExecutor()
    return _parallel_executor
@@ -0,0 +1,248 @@
1
+ """
2
+ Polars utility functions for data manipulation.
3
+ """
4
+
5
+ import polars as pl
6
+ from typing import List, Dict, Any, Optional
7
+
8
+
9
def load_dataframe(file_path: str) -> "pl.DataFrame":
    """
    Load a dataframe from CSV or Parquet file.

    Args:
        file_path: Path to file; format inferred from the extension
            (case-insensitive, e.g. both ".csv" and ".CSV" work).

    Returns:
        Polars DataFrame

    Raises:
        ValueError: If the extension is neither .csv nor .parquet.
    """
    # FIX: compare extensions case-insensitively; the previous check
    # rejected files like "data.CSV".
    suffix = file_path.lower()
    if suffix.endswith('.parquet'):
        return pl.read_parquet(file_path)
    elif suffix.endswith('.csv'):
        # Use longer schema inference to handle mixed types better
        # and ignore errors to handle problematic rows gracefully
        return pl.read_csv(
            file_path,
            try_parse_dates=True,
            infer_schema_length=10000,  # Scan more rows for better type inference
            ignore_errors=True  # Skip problematic rows instead of failing
        )
    else:
        raise ValueError(f"Unsupported file format: {file_path}")
32
+
33
+
34
def save_dataframe(df: "pl.DataFrame", file_path: str) -> None:
    """
    Save dataframe to CSV or Parquet file.

    Args:
        df: Polars DataFrame
        file_path: Output path; format inferred from the extension
            (case-insensitive).

    Raises:
        ValueError: If the extension is neither .csv nor .parquet.
    """
    # FIX: compare extensions case-insensitively, matching load_dataframe.
    suffix = file_path.lower()
    if suffix.endswith('.parquet'):
        df.write_parquet(file_path)
    elif suffix.endswith('.csv'):
        df.write_csv(file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_path}")
48
+
49
+
50
def get_numeric_columns(df: pl.DataFrame) -> List[str]:
    """
    Get list of numeric column names.

    Args:
        df: Polars DataFrame

    Returns:
        Names of all numeric columns, in frame order.
    """
    return [name for name, dtype in df.schema.items() if dtype in pl.NUMERIC_DTYPES]
61
+
62
+
63
def get_categorical_columns(df: pl.DataFrame) -> List[str]:
    """
    Get list of categorical/string column names.

    Args:
        df: Polars DataFrame

    Returns:
        Names of all Utf8/Categorical columns, in frame order.
    """
    return [name for name, dtype in df.schema.items() if dtype in (pl.Utf8, pl.Categorical)]
74
+
75
+
76
def get_datetime_columns(df: pl.DataFrame) -> List[str]:
    """
    Get list of datetime column names.

    Args:
        df: Polars DataFrame

    Returns:
        Names of all Date/Datetime columns, in frame order.
    """
    return [name for name, dtype in df.schema.items() if dtype in (pl.Date, pl.Datetime)]
87
+
88
+
89
def detect_id_columns(df: "pl.DataFrame") -> List[str]:
    """
    Detect columns that are likely IDs (unique values, low information content).

    A column is flagged when its name contains an identifier-like token
    ("id", "key", "index" — matched on underscore-separated word boundaries)
    or when more than 95% of its values are unique.

    Args:
        df: Polars DataFrame

    Returns:
        List of likely ID column names
    """
    # FIX: the previous substring test ("id" in name) false-positived on
    # names like "holiday", "width" or "paid". Match whole underscore-
    # separated tokens instead, so "user_id" and "key" match but "holiday"
    # does not.
    id_tokens = {"id", "key", "index"}

    id_columns = []
    n_total = len(df)  # hoisted out of the loop

    for col in df.columns:
        # Check if the column name suggests it's an ID
        if id_tokens & set(col.lower().split("_")):
            id_columns.append(col)
            continue

        # Check if column has mostly unique values (>95% unique)
        if n_total > 0 and (df[col].n_unique() / n_total) > 0.95:
            id_columns.append(col)

    return id_columns
115
+
116
+
117
def safe_cast_numeric(df: pl.DataFrame, columns: List[str]) -> pl.DataFrame:
    """
    Safely cast columns to numeric, handling errors gracefully.

    Args:
        df: Polars DataFrame
        columns: Names of the columns to attempt casting

    Returns:
        DataFrame with each column cast to Float64 where the cast succeeded;
        columns that cannot be cast are left unchanged.
    """
    result = df.clone()

    for name in columns:
        try:
            casted = result.with_columns(pl.col(name).cast(pl.Float64).alias(name))
        except Exception:
            # Cast failed (e.g. non-numeric strings) — keep the original column.
            continue
        result = casted

    return result
140
+
141
+
142
def get_column_info(df: pl.DataFrame, col: str) -> Dict[str, Any]:
    """
    Get comprehensive information about a column.

    Args:
        df: Polars DataFrame
        col: Column name

    Returns:
        Dictionary with column statistics: name, dtype, null/unique counts
        and percentages, plus mean/std/min/max/median for numeric columns
        and the top-5 most frequent values for string/categorical columns.
    """
    col_data = df[col]
    n_rows = len(df)
    null_count = col_data.null_count()
    unique_count = col_data.n_unique()

    info = {
        "name": col,
        "dtype": str(col_data.dtype),
        "null_count": null_count,
        # FIX: guard against ZeroDivisionError on an empty dataframe.
        "null_percentage": round(null_count / n_rows * 100, 2) if n_rows else 0.0,
        "unique_count": unique_count,
        "unique_percentage": round(unique_count / n_rows * 100, 2) if n_rows else 0.0,
    }

    # Numeric-specific stats. Each aggregate is computed exactly once
    # (the original evaluated mean()/std()/... twice per statistic).
    if col_data.dtype in pl.NUMERIC_DTYPES:
        for stat_name, value in (
            ("mean", col_data.mean()),
            ("std", col_data.std()),
            ("min", col_data.min()),
            ("max", col_data.max()),
            ("median", col_data.median()),
        ):
            info[stat_name] = float(value) if value is not None else None

    # Categorical-specific stats: top-5 most frequent values.
    if col_data.dtype in [pl.Utf8, pl.Categorical]:
        value_counts = col_data.value_counts().limit(5)
        info["top_values"] = [
            {"value": str(row[0]), "count": int(row[1])}
            for row in value_counts.iter_rows()
        ]

    return info
183
+
184
+
185
def calculate_memory_usage(df: pl.DataFrame) -> Dict[str, Any]:
    """
    Calculate memory usage of dataframe.

    Args:
        df: Polars DataFrame

    Returns:
        Dict with total_mb, total_bytes, rows, columns and bytes_per_row
        (0 for an empty frame).
    """
    total_bytes = df.estimated_size()
    n_rows = len(df)

    stats = {
        "total_mb": round(total_bytes / (1024 * 1024), 2),
        "total_bytes": total_bytes,
        "rows": n_rows,
        "columns": len(df.columns),
    }
    stats["bytes_per_row"] = round(total_bytes / n_rows, 2) if n_rows > 0 else 0
    return stats
204
+
205
+
206
def split_features_target(df: pl.DataFrame, target_col: str) -> tuple:
    """
    Split dataframe into features and target.

    Args:
        df: Polars DataFrame
        target_col: Name of target column

    Returns:
        Tuple (X, y): X contains every column except the target,
        y is the target column.

    Raises:
        ValueError: If target_col is not present in the dataframe.
    """
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in dataframe")

    return df.drop(target_col), df[target_col]
224
+
225
+
226
def remove_low_variance_features(df: pl.DataFrame, threshold: float = 0.01) -> pl.DataFrame:
    """
    Remove numeric features with low variance.

    Args:
        df: Polars DataFrame
        threshold: Variance threshold (default 0.01); numeric columns whose
            variance is <= threshold (or undefined, e.g. all-null) are dropped.

    Returns:
        DataFrame with low-variance numeric columns removed. Non-numeric
        columns are always kept, and the original column order is preserved.
    """
    numeric_cols = get_numeric_columns(df)

    # Numeric columns that fail the variance test are dropped.
    dropped = set()
    for col in numeric_cols:
        variance = df[col].var()
        if variance is None or variance <= threshold:
            dropped.add(col)

    # FIX: preserve the original column order; the previous version returned
    # (kept numeric columns, then non-numeric columns), silently reordering
    # the frame.
    return df.select([col for col in df.columns if col not in dropped])