ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,593 @@
1
+ """
2
+ Data Profiling Tools
3
+ Tools for analyzing and understanding dataset characteristics.
4
+ """
5
+
6
+ import polars as pl
7
+ import numpy as np
8
+ from typing import Dict, Any, List, Optional
9
+ from pathlib import Path
10
+ import sys
11
+ import os
12
+
13
+ # Add parent directory to path for imports
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+
16
+ from ds_agent.utils.polars_helpers import (
17
+ load_dataframe,
18
+ get_numeric_columns,
19
+ get_categorical_columns,
20
+ get_datetime_columns,
21
+ get_column_info,
22
+ calculate_memory_usage,
23
+ detect_id_columns,
24
+ )
25
+ from ds_agent.utils.validation import (
26
+ validate_file_exists,
27
+ validate_file_format,
28
+ validate_dataframe,
29
+ )
30
+
31
+
32
def profile_dataset(file_path: str) -> Dict[str, Any]:
    """
    Get comprehensive statistics about a dataset.

    Args:
        file_path: Path to CSV or Parquet file

    Returns:
        Dictionary with dataset profile including:
        - shape (rows, columns)
        - column types
        - memory usage
        - null counts
        - unique values
        - missing value percentage per column
        - unique value counts per column
        - basic statistics for each column
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    # Hoist the row count; it is used in several per-column ratios below.
    n_rows = len(df)

    # Basic info
    profile = {
        "file_path": file_path,
        "shape": {
            "rows": n_rows,
            "columns": len(df.columns)
        },
        "memory_usage": calculate_memory_usage(df),
        "column_types": {
            "numeric": get_numeric_columns(df),
            "categorical": get_categorical_columns(df),
            "datetime": get_datetime_columns(df),
            "id_columns": detect_id_columns(df),
        },
        "columns": {},
        "missing_values_per_column": {},  # Per-column missing %
        "unique_counts_per_column": {}    # Per-column unique counts
    }

    # Per-column statistics with missing % and unique counts
    for col in df.columns:
        # Existing column info from the shared helper
        profile["columns"][col] = get_column_info(df, col)

        # Missing value percentage for this column
        null_count = df[col].null_count()
        missing_pct = round((null_count / n_rows) * 100, 2) if n_rows > 0 else 0
        profile["missing_values_per_column"][col] = {
            "count": int(null_count),
            "percentage": missing_pct
        }

        # Unique value counts. Columns holding unhashable values
        # (structs/lists) can make n_unique() raise; stringify and retry
        # before giving up. Catch Exception, never bare except, so
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            profile["unique_counts_per_column"][col] = int(df[col].n_unique())
        except Exception:
            try:
                profile["unique_counts_per_column"][col] = int(
                    df[col].cast(pl.Utf8).n_unique()
                )
            except Exception:
                profile["unique_counts_per_column"][col] = "N/A (unhashable type)"

    # Overall statistics
    total_nulls = sum(df[col].null_count() for col in df.columns)
    total_cells = n_rows * len(df.columns)
    # Compute the duplicate mask once instead of twice.
    dup_count = int(df.is_duplicated().sum())

    profile["overall_stats"] = {
        "total_cells": total_cells,
        "total_nulls": total_nulls,
        "null_percentage": round(total_nulls / total_cells * 100, 2) if total_cells > 0 else 0,
        "duplicate_rows": dup_count,
        "duplicate_percentage": round(dup_count / n_rows * 100, 2) if n_rows > 0 else 0,
    }

    return profile
117
+
118
+
119
def get_smart_summary(file_path: str, n_samples: int = 30) -> Dict[str, Any]:
    """
    Enhanced data summary with missing %, unique counts, and safe dict handling.

    This function provides a smarter, more LLM-friendly summary compared to
    profile_dataset(). It includes per-column missing percentages, unique value
    counts, and handles dictionary columns gracefully (converts to strings to
    avoid hashing errors).

    Args:
        file_path: Path to CSV or Parquet file
        n_samples: Number of sample rows to include (default: 30)

    Returns:
        Dictionary with comprehensive smart summary including:
        - Basic shape info
        - Column data types
        - Missing value percentage by column (sorted by % descending)
        - Unique value counts by column
        - First N sample rows
        - Descriptive statistics for numeric columns
        - Safe handling of dictionary/unhashable columns

    Example:
        >>> summary = get_smart_summary("data.csv")
        >>> print(summary["missing_summary"])
        >>> # Output: [{"column": "col_A", "count": 452, "percentage": 45.2}, ...]
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    # Convert dictionary-type columns to strings (prevents unhashable dict
    # errors further down). Detection is heuristic: inspect the first
    # non-null value of each column. Catch Exception, not bare except, so
    # interrupts still propagate.
    for col in df.columns:
        try:
            sample = df[col].drop_nulls().head(5)
            if len(sample) > 0 and isinstance(sample[0], (dict, list)):
                df = df.with_columns(pl.col(col).cast(pl.Utf8).alias(col))
        except Exception:
            # Best-effort normalization only; leave the column as-is.
            pass

    # Calculate missing value statistics (sorted by % descending)
    missing_stats = []
    for col in df.columns:
        null_count = df[col].null_count()
        null_pct = round((null_count / len(df)) * 100, 2) if len(df) > 0 else 0
        missing_stats.append({
            "column": col,
            "count": int(null_count),
            "percentage": null_pct
        })

    # Sort by percentage descending
    missing_stats.sort(key=lambda x: x["percentage"], reverse=True)

    # Calculate unique value counts
    unique_counts = {}
    for col in df.columns:
        try:
            unique_counts[col] = int(df[col].n_unique())
        except Exception:
            # Fallback for unhashable types: stringify first.
            try:
                unique_counts[col] = int(df[col].cast(pl.Utf8).n_unique())
            except Exception:
                unique_counts[col] = "N/A"

    # Get column data types
    column_types = {col: str(df[col].dtype) for col in df.columns}

    # Get sample rows (first n_samples)
    sample_data = df.head(n_samples).to_dicts()

    # Get descriptive statistics for numeric columns
    numeric_cols = get_numeric_columns(df)
    numeric_stats = {}

    if numeric_cols:
        df_numeric = df.select(numeric_cols)
        # Convert to pandas for describe() functionality
        df_pd = df_numeric.to_pandas()
        stats_df = df_pd.describe()
        numeric_stats = stats_df.to_dict()

    # Build comprehensive summary
    summary = {
        "file_path": file_path,
        "shape": {
            "rows": len(df),
            "columns": len(df.columns)
        },
        "column_types": column_types,
        "missing_summary": missing_stats,  # Sorted by % descending
        "unique_counts": unique_counts,
        "sample_data": sample_data,
        "numeric_statistics": numeric_stats,
        "memory_usage_mb": calculate_memory_usage(df),
        "summary_notes": []
    }

    # Add helpful notes for LLM
    high_missing_cols = [item for item in missing_stats if item["percentage"] > 40]
    if high_missing_cols:
        summary["summary_notes"].append(
            f"{len(high_missing_cols)} column(s) have >40% missing values (consider dropping)"
        )

    high_cardinality_cols = [col for col, count in unique_counts.items()
                             if isinstance(count, int) and count > len(df) * 0.5]
    if high_cardinality_cols:
        summary["summary_notes"].append(
            f"{len(high_cardinality_cols)} column(s) have very high cardinality (>50% unique values)"
        )

    return summary
244
+
245
+
246
def detect_data_quality_issues(file_path: str) -> Dict[str, Any]:
    """
    Detect data quality issues in the dataset.

    Args:
        file_path: Path to CSV or Parquet file

    Returns:
        Dictionary with detected issues organized by severity:
        - critical: Issues that will break model training
        - warning: Issues that may affect model performance
        - info: Observations that may be relevant
        Plus a "summary" key with per-severity counts.
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data (validate_dataframe is assumed to reject empty frames,
    # so len(df) divisions below are safe — confirm in validation module)
    df = load_dataframe(file_path)
    validate_dataframe(df)

    issues = {
        "critical": [],
        "warning": [],
        "info": []
    }

    # Check for completely null columns
    for col in df.columns:
        null_count = df[col].null_count()
        null_pct = (null_count / len(df)) * 100

        if null_count == len(df):
            issues["critical"].append({
                "type": "all_null_column",
                "column": col,
                "message": f"Column '{col}' has all null values"
            })
        elif null_pct > 50:
            issues["warning"].append({
                "type": "high_null_percentage",
                "column": col,
                "null_percentage": round(null_pct, 2),
                "message": f"Column '{col}' has {round(null_pct, 2)}% null values"
            })
        elif null_pct > 10:
            issues["info"].append({
                "type": "moderate_null_percentage",
                "column": col,
                "null_percentage": round(null_pct, 2),
                "message": f"Column '{col}' has {round(null_pct, 2)}% null values"
            })

    # Check for duplicate rows
    dup_count = df.is_duplicated().sum()
    if dup_count > 0:
        dup_pct = (dup_count / len(df)) * 100
        severity = "warning" if dup_pct > 10 else "info"
        issues[severity].append({
            "type": "duplicate_rows",
            "count": int(dup_count),
            "percentage": round(dup_pct, 2),
            "message": f"Dataset has {dup_count} duplicate rows ({round(dup_pct, 2)}%)"
        })

    # Check for outliers in numeric columns using the IQR (Tukey fence) method
    numeric_cols = get_numeric_columns(df)
    for col in numeric_cols:
        col_data = df[col].drop_nulls()
        if len(col_data) == 0:
            continue

        q1 = col_data.quantile(0.25)
        q3 = col_data.quantile(0.75)
        iqr = q3 - q1

        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr

        outliers = ((col_data < lower_bound) | (col_data > upper_bound)).sum()

        if outliers > 0:
            outlier_pct = (outliers / len(col_data)) * 100
            if outlier_pct > 10:
                issues["warning"].append({
                    "type": "outliers",
                    "column": col,
                    "count": int(outliers),
                    "percentage": round(outlier_pct, 2),
                    "bounds": {"lower": float(lower_bound), "upper": float(upper_bound)},
                    "message": f"Column '{col}' has {outliers} outliers ({round(outlier_pct, 2)}%)"
                })
            elif outlier_pct > 1:
                issues["info"].append({
                    "type": "outliers",
                    "column": col,
                    "count": int(outliers),
                    "percentage": round(outlier_pct, 2),
                    "bounds": {"lower": float(lower_bound), "upper": float(upper_bound)},
                    "message": f"Column '{col}' has {outliers} outliers ({round(outlier_pct, 2)}%)"
                })

    # Check for high cardinality in categorical columns
    categorical_cols = get_categorical_columns(df)
    for col in categorical_cols:
        n_unique = df[col].n_unique()
        cardinality_pct = (n_unique / len(df)) * 100

        if n_unique > 100 and cardinality_pct > 50:
            issues["warning"].append({
                "type": "high_cardinality",
                "column": col,
                "unique_values": int(n_unique),
                "percentage": round(cardinality_pct, 2),
                "message": f"Column '{col}' has very high cardinality ({n_unique} unique values, {round(cardinality_pct, 2)}%)"
            })

    # Check for constant columns (single unique value)
    for col in df.columns:
        n_unique = df[col].n_unique()
        if n_unique == 1:
            issues["warning"].append({
                "type": "constant_column",
                "column": col,
                "message": f"Column '{col}' has only one unique value (constant)"
            })

    # Check for imbalanced datasets (for potential target columns)
    for col in df.columns:
        col_data = df[col]
        n_unique = col_data.n_unique()

        # Check if this could be a target column (2-20 unique values)
        if 2 <= n_unique <= 20:
            value_counts = col_data.value_counts()
            if len(value_counts) >= 2:
                # BUGFIX: value_counts() does not guarantee descending order,
                # so taking row [0] may pick an arbitrary class. Take the max
                # of the counts column (second column of the result) instead.
                max_count = value_counts[value_counts.columns[1]].max()
                max_pct = (max_count / len(df)) * 100

                if max_pct > 90:
                    issues["warning"].append({
                        "type": "class_imbalance",
                        "column": col,
                        "dominant_class_percentage": round(max_pct, 2),
                        "message": f"Column '{col}' may be imbalanced (dominant class: {round(max_pct, 2)}%)"
                    })

    # Summary
    issues["summary"] = {
        "total_issues": len(issues["critical"]) + len(issues["warning"]) + len(issues["info"]),
        "critical_count": len(issues["critical"]),
        "warning_count": len(issues["warning"]),
        "info_count": len(issues["info"])
    }

    return issues
402
+
403
+
404
def analyze_correlations(file_path: str, target: Optional[str] = None) -> Dict[str, Any]:
    """
    Analyze correlations between features.

    Args:
        file_path: Path to CSV or Parquet file
        target: Optional target column to analyze correlations with

    Returns:
        Dictionary with correlation analysis including:
        - correlation matrix (for numeric columns)
        - top correlations with target (if specified)
        - highly correlated feature pairs
    """
    # Validation
    validate_file_exists(file_path)
    validate_file_format(file_path)

    # Load data
    df = load_dataframe(file_path)
    validate_dataframe(df)

    numeric_cols = get_numeric_columns(df)

    if len(numeric_cols) < 2:
        return {
            "error": "Dataset must have at least 2 numeric columns for correlation analysis",
            "numeric_columns_found": len(numeric_cols)
        }

    # Select only numeric columns for correlation
    df_numeric = df.select(numeric_cols)

    # Calculate correlation matrix using pandas (Polars doesn't have native corr yet)
    df_pd = df_numeric.to_pandas()
    corr_matrix = df_pd.corr()

    result = {
        "numeric_columns": numeric_cols,
        "correlation_matrix": corr_matrix.to_dict()
    }

    # Find highly correlated pairs (excluding diagonal)
    high_corr_pairs = []
    for i in range(len(corr_matrix.columns)):
        for j in range(i + 1, len(corr_matrix.columns)):
            col1 = corr_matrix.columns[i]
            col2 = corr_matrix.columns[j]
            corr_value = corr_matrix.iloc[i, j]

            # NaN (e.g. constant column) fails the > 0.7 test and is skipped
            if abs(corr_value) > 0.7:  # High correlation threshold
                high_corr_pairs.append({
                    "feature_1": col1,
                    "feature_2": col2,
                    "correlation": round(float(corr_value), 4)
                })

    # Sort by absolute correlation
    high_corr_pairs.sort(key=lambda x: abs(x["correlation"]), reverse=True)
    result["high_correlations"] = high_corr_pairs

    # If target specified, show top correlations with target
    if target:
        if target not in df.columns:
            result["target_correlations_error"] = f"Target column '{target}' not found"
        elif target not in numeric_cols:
            result["target_correlations_error"] = f"Target column '{target}' is not numeric"
        else:
            target_corrs = []
            for col in numeric_cols:
                if col == target:
                    continue
                corr_value = corr_matrix.loc[target, col]
                # BUGFIX: a constant column yields NaN here; NaN is not valid
                # JSON and breaks downstream serialization, so skip it.
                if np.isnan(corr_value):
                    continue
                target_corrs.append({
                    "feature": col,
                    "correlation": round(float(corr_value), 4)
                })

            # Sort by absolute correlation
            target_corrs.sort(key=lambda x: abs(x["correlation"]), reverse=True)
            result["target_correlations"] = {
                "target": target,
                "top_features": target_corrs[:20]  # Top 20
            }

    return result
489
+
490
+
491
def detect_label_errors(
    file_path: str,
    target_col: str,
    features: Optional[List[str]] = None,
    n_folds: int = 5,
    output_path: Optional[str] = None
) -> Dict[str, Any]:
    """
    Detect potential label errors in a classification dataset using cleanlab.

    Uses confident learning to find mislabeled examples by:
    1. Training cross-validated classifiers
    2. Computing out-of-sample predicted probabilities
    3. Identifying labels that disagree with model predictions

    Args:
        file_path: Path to dataset
        target_col: Target/label column name
        features: Feature columns to use (None = all numeric)
        n_folds: Number of cross-validation folds
        output_path: Optional path to save flagged rows

    Returns:
        Dictionary with label error analysis results (status/error dict on
        failure, matching the other error returns in this function).
    """
    try:
        from cleanlab.classification import CleanLearning
    except ImportError:
        return {
            'status': 'error',
            'message': 'cleanlab not installed. Install with: pip install cleanlab>=2.6'
        }

    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import LabelEncoder

    validate_file_exists(file_path)
    validate_file_format(file_path)

    df = load_dataframe(file_path)
    validate_dataframe(df)

    # BUGFIX: the original called validate_column_exists(), which is never
    # imported in this module and raised NameError. Check membership
    # explicitly and report via the function's error-dict convention.
    if target_col not in df.columns:
        return {
            'status': 'error',
            'message': f"Target column '{target_col}' not found in dataset"
        }

    print(f"🔍 Detecting label errors in '{target_col}' using cleanlab...")

    # Get features
    if features is None:
        features = get_numeric_columns(df)
        features = [f for f in features if f != target_col]

    if not features:
        return {'status': 'error', 'message': 'No numeric features found for label error detection'}

    # Convert to pandas/numpy
    df_pd = df.to_pandas()
    X = df_pd[features].fillna(0).values
    y_raw = df_pd[target_col].values

    # Encode labels
    le = LabelEncoder()
    y = le.fit_transform(y_raw)

    # Use CleanLearning to find label issues.
    # NOTE: multi_class='auto' was removed from the original call — it is
    # deprecated (and later removed) in scikit-learn; the automatic
    # multinomial handling is the default behavior anyway.
    cl = CleanLearning(
        clf=LogisticRegression(max_iter=500, solver='lbfgs'),
        cv_n_folds=n_folds
    )

    label_issues = cl.find_label_issues(X, y)

    # Extract results
    n_issues = int(label_issues['is_label_issue'].sum())
    issue_indices = label_issues[label_issues['is_label_issue']].index.tolist()

    # Get details for flagged rows
    flagged_rows = []
    for idx in issue_indices[:50]:  # Limit to top 50
        flagged_rows.append({
            'row_index': int(idx),
            'current_label': str(y_raw[idx]),
            'suggested_label': str(le.inverse_transform([label_issues.loc[idx, 'predicted_label']])[0]) if 'predicted_label' in label_issues.columns else 'unknown',
            'confidence': float(1 - label_issues.loc[idx, 'label_quality']) if 'label_quality' in label_issues.columns else None
        })

    print(f"   🚨 Found {n_issues} potential label errors ({n_issues/len(y)*100:.1f}%)")

    # Save flagged rows
    if output_path and issue_indices:
        flagged_df = df_pd.iloc[issue_indices]
        flagged_df.to_csv(output_path, index=False)
        print(f"   💾 Flagged rows saved to: {output_path}")

    return {
        'status': 'success',
        'total_samples': len(y),
        'label_errors_found': n_issues,
        'error_percentage': round(n_issues / len(y) * 100, 2),
        'flagged_rows': flagged_rows,
        'n_classes': len(le.classes_),
        'classes': le.classes_.tolist(),
        'output_path': output_path,
        'recommendation': f'Review {n_issues} flagged samples for potential mislabeling' if n_issues > 0 else 'No label errors detected'
    }