flowyml 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. flowyml/assets/dataset.py +570 -17
  2. flowyml/assets/model.py +1052 -15
  3. flowyml/core/executor.py +70 -11
  4. flowyml/core/orchestrator.py +37 -2
  5. flowyml/core/pipeline.py +32 -4
  6. flowyml/core/scheduler.py +88 -5
  7. flowyml/integrations/keras.py +247 -82
  8. flowyml/storage/sql.py +24 -6
  9. flowyml/ui/backend/routers/runs.py +112 -0
  10. flowyml/ui/backend/routers/schedules.py +35 -15
  11. flowyml/ui/frontend/dist/assets/index-B40RsQDq.css +1 -0
  12. flowyml/ui/frontend/dist/assets/index-CjI0zKCn.js +685 -0
  13. flowyml/ui/frontend/dist/index.html +2 -2
  14. flowyml/ui/frontend/package-lock.json +11 -0
  15. flowyml/ui/frontend/package.json +1 -0
  16. flowyml/ui/frontend/src/app/assets/page.jsx +890 -321
  17. flowyml/ui/frontend/src/app/dashboard/page.jsx +1 -1
  18. flowyml/ui/frontend/src/app/experiments/[experimentId]/page.jsx +1 -1
  19. flowyml/ui/frontend/src/app/leaderboard/page.jsx +1 -1
  20. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectMetricsPanel.jsx +1 -1
  21. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectRunsList.jsx +3 -3
  22. flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +590 -102
  23. flowyml/ui/frontend/src/components/ArtifactViewer.jsx +62 -2
  24. flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +401 -28
  25. flowyml/ui/frontend/src/components/AssetTreeHierarchy.jsx +119 -11
  26. flowyml/ui/frontend/src/components/DatasetViewer.jsx +753 -0
  27. flowyml/ui/frontend/src/components/TrainingHistoryChart.jsx +514 -0
  28. flowyml/ui/frontend/src/components/TrainingMetricsPanel.jsx +175 -0
  29. {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/METADATA +1 -1
  30. {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/RECORD +33 -30
  31. flowyml/ui/frontend/dist/assets/index-By4trVyv.css +0 -1
  32. flowyml/ui/frontend/dist/assets/index-CX5RV2C9.js +0 -630
  33. {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/WHEEL +0 -0
  34. {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/entry_points.txt +0 -0
  35. {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/licenses/LICENSE +0 -0
flowyml/assets/dataset.py CHANGED
@@ -1,19 +1,332 @@
-"""Dataset Asset - Represents ML datasets with schema validation."""
+"""Dataset Asset - Represents ML datasets with automatic statistics extraction."""
 
 from typing import Any
+import logging
+
 from flowyml.assets.base import Asset
 
+logger = logging.getLogger(__name__)
+
+
+class DatasetStats:
+    """Utility class for computing dataset statistics from various data formats."""
+
+    @staticmethod
+    def detect_data_type(data: Any) -> str:
+        """Detect the type of data structure.
+
+        Returns one of: 'pandas', 'numpy', 'tensorflow', 'torch', 'dict', 'list', 'unknown'
+        """
+        if data is None:
+            return "unknown"
+
+        # Check class name to avoid importing heavy libraries
+        type_name = type(data).__name__
+        module_name = type(data).__module__
+
+        # Pandas DataFrame
+        if type_name == "DataFrame" and "pandas" in module_name:
+            return "pandas"
+
+        # Numpy array
+        if type_name == "ndarray" and "numpy" in module_name:
+            return "numpy"
+
+        # TensorFlow Dataset
+        if "tensorflow" in module_name or "tf" in module_name:
+            if "Dataset" in type_name or "dataset" in module_name:
+                return "tensorflow"
+
+        # PyTorch Dataset/DataLoader
+        if "torch" in module_name:
+            if "Dataset" in type_name or "DataLoader" in type_name:
+                return "torch"
+
+        # Dictionary (common format for features/target)
+        if isinstance(data, dict):
+            return "dict"
+
+        # List/Tuple
+        if isinstance(data, (list, tuple)):
+            return "list"
+
+        return "unknown"
+
+    @staticmethod
+    def compute_numeric_stats(values: list) -> dict[str, Any]:
+        """Compute statistics for a numeric column."""
+        try:
+            # Filter to numeric values only
+            numeric_vals = [v for v in values if isinstance(v, (int, float)) and v == v]  # v == v filters NaN
+            if not numeric_vals:
+                return {}
+
+            n = len(numeric_vals)
+            sorted_vals = sorted(numeric_vals)
+
+            mean = sum(numeric_vals) / n
+            variance = sum((x - mean) ** 2 for x in numeric_vals) / n
+            std = variance**0.5
+
+            # Median
+            mid = n // 2
+            median = sorted_vals[mid] if n % 2 else (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
+
+            return {
+                "mean": round(mean, 6),
+                "std": round(std, 6),
+                "min": round(min(numeric_vals), 6),
+                "max": round(max(numeric_vals), 6),
+                "median": round(median, 6),
+                "count": n,
+                "unique": len(set(numeric_vals)),
+                "dtype": "numeric",
+            }
+        except Exception as e:
+            logger.debug(f"Could not compute numeric stats: {e}")
+            return {}
+
+    @staticmethod
+    def compute_categorical_stats(values: list) -> dict[str, Any]:
+        """Compute statistics for a categorical column."""
+        try:
+            n = len(values)
+            unique_vals = {str(v) for v in values if v is not None}
+
+            return {
+                "count": n,
+                "unique": len(unique_vals),
+                "dtype": "categorical",
+                "top_values": list(unique_vals)[:5],  # Sample of unique values
+            }
+        except Exception as e:
+            logger.debug(f"Could not compute categorical stats: {e}")
+            return {}
+
+    @staticmethod
+    def extract_from_pandas(df: Any) -> dict[str, Any]:
+        """Extract statistics from a pandas DataFrame."""
+        try:
+            columns = list(df.columns)
+            n_samples = len(df)
+
+            # Compute per-column stats
+            column_stats = {}
+            for col in columns:
+                col_data = df[col].tolist()
+                # Check if numeric
+                if df[col].dtype.kind in ("i", "f", "u"):  # int, float, unsigned
+                    column_stats[col] = DatasetStats.compute_numeric_stats(col_data)
+                else:
+                    column_stats[col] = DatasetStats.compute_categorical_stats(col_data)
+
+            # Detect target column (common naming conventions)
+            target_candidates = ["target", "label", "y", "class", "output"]
+            target_col = next((c for c in columns if c.lower() in target_candidates), None)
+            feature_cols = [c for c in columns if c != target_col] if target_col else columns
+
+            return {
+                "samples": n_samples,
+                "num_features": len(feature_cols),
+                "feature_columns": feature_cols,
+                "label_column": target_col,
+                "columns": columns,
+                "column_stats": column_stats,
+                "framework": "pandas",
+                "_auto_extracted": True,
+            }
+        except Exception as e:
+            logger.debug(f"Could not extract pandas stats: {e}")
+            return {}
+
+    @staticmethod
+    def extract_from_numpy(arr: Any) -> dict[str, Any]:
+        """Extract statistics from a numpy array."""
+        try:
+            shape = arr.shape
+            dtype = str(arr.dtype)
+
+            result = {
+                "shape": list(shape),
+                "dtype": dtype,
+                "samples": shape[0] if len(shape) > 0 else 1,
+                "num_features": shape[1] if len(shape) > 1 else 1,
+                "framework": "numpy",
+                "_auto_extracted": True,
+            }
+
+            # Compute stats if 1D or 2D numeric
+            if arr.dtype.kind in ("i", "f", "u") and len(shape) <= 2:
+                flat_data = arr.flatten().tolist()
+                stats = DatasetStats.compute_numeric_stats(flat_data)
+                result["stats"] = stats
+
+            return result
+        except Exception as e:
+            logger.debug(f"Could not extract numpy stats: {e}")
+            return {}
+
+    @staticmethod
+    def extract_from_dict(data: dict) -> dict[str, Any]:
+        """Extract statistics from a dict of arrays (common format)."""
+        try:
+            # Check if it's a features/target format
+            if "features" in data and isinstance(data["features"], dict):
+                features = data["features"]
+                target = data.get("target", [])
+
+                columns = list(features.keys())
+                if target:
+                    columns.append("target")
+
+                # Get sample count from first feature
+                first_key = next(iter(features.keys()))
+                n_samples = len(features[first_key]) if features else 0
+
+                # Compute per-column stats
+                column_stats = {}
+                for col, values in features.items():
+                    if values and isinstance(values[0], (int, float)):
+                        column_stats[col] = DatasetStats.compute_numeric_stats(values)
+                    else:
+                        column_stats[col] = DatasetStats.compute_categorical_stats(values)
+
+                if target:
+                    if target and isinstance(target[0], (int, float)):
+                        column_stats["target"] = DatasetStats.compute_numeric_stats(target)
+                    else:
+                        column_stats["target"] = DatasetStats.compute_categorical_stats(target)
+
+                return {
+                    "samples": n_samples,
+                    "num_features": len(features),
+                    "feature_columns": list(features.keys()),
+                    "label_column": "target" if target else None,
+                    "columns": columns,
+                    "column_stats": column_stats,
+                    "framework": "dict",
+                    "_auto_extracted": True,
+                }
+
+            # Generic dict of arrays
+            columns = list(data.keys())
+            n_samples = 0
+            column_stats = {}
+
+            for col, values in data.items():
+                if isinstance(values, (list, tuple)):
+                    n_samples = max(n_samples, len(values))
+                    if values and isinstance(values[0], (int, float)):
+                        column_stats[col] = DatasetStats.compute_numeric_stats(list(values))
+                    else:
+                        column_stats[col] = DatasetStats.compute_categorical_stats(list(values))
+
+            # Detect target column
+            target_candidates = ["target", "label", "y", "class", "output"]
+            target_col = next((c for c in columns if c.lower() in target_candidates), None)
+            feature_cols = [c for c in columns if c != target_col] if target_col else columns
+
+            return {
+                "samples": n_samples,
+                "num_features": len(feature_cols),
+                "feature_columns": feature_cols,
+                "label_column": target_col,
+                "columns": columns,
+                "column_stats": column_stats,
+                "framework": "dict",
+                "_auto_extracted": True,
+            }
+        except Exception as e:
+            logger.debug(f"Could not extract dict stats: {e}")
+            return {}
+
+    @staticmethod
+    def extract_from_tensorflow(dataset: Any) -> dict[str, Any]:
+        """Extract statistics from a TensorFlow dataset."""
+        try:
+            result = {
+                "framework": "tensorflow",
+                "_auto_extracted": True,
+            }
+
+            # Get cardinality if available
+            if hasattr(dataset, "cardinality"):
+                card = dataset.cardinality()
+                if hasattr(card, "numpy"):
+                    card = card.numpy()
+                result["cardinality"] = int(card) if card >= 0 else "unknown"
+                result["samples"] = int(card) if card >= 0 else None
+
+            # Get element spec if available
+            if hasattr(dataset, "element_spec"):
+                spec = dataset.element_spec
+
+                def spec_to_dict(s: Any) -> dict | str:
+                    if hasattr(s, "shape") and hasattr(s, "dtype"):
+                        return {"shape": str(s.shape), "dtype": str(s.dtype)}
+                    if isinstance(s, dict):
+                        return {k: spec_to_dict(v) for k, v in s.items()}
+                    if isinstance(s, (tuple, list)):
+                        return [spec_to_dict(x) for x in s]
+                    return str(s)
+
+                result["element_spec"] = spec_to_dict(spec)
+
+            return result
+        except Exception as e:
+            logger.debug(f"Could not extract tensorflow stats: {e}")
+            return {"framework": "tensorflow", "_auto_extracted": True}
+
+    @staticmethod
+    def extract_stats(data: Any) -> dict[str, Any]:
+        """Auto-detect data type and extract statistics."""
+        data_type = DatasetStats.detect_data_type(data)
+
+        if data_type == "pandas":
+            return DatasetStats.extract_from_pandas(data)
+        elif data_type == "numpy":
+            return DatasetStats.extract_from_numpy(data)
+        elif data_type == "dict":
+            return DatasetStats.extract_from_dict(data)
+        elif data_type == "tensorflow":
+            return DatasetStats.extract_from_tensorflow(data)
+        elif data_type == "list":
+            # Try to convert list to dict format
+            if data and isinstance(data[0], dict):
+                # List of dicts -> convert to dict of lists
+                keys = data[0].keys()
+                dict_data = {k: [row.get(k) for row in data] for k in keys}
+                return DatasetStats.extract_from_dict(dict_data)
+            return {"samples": len(data), "framework": "list", "_auto_extracted": True}
+
+        return {"_auto_extracted": False}
+
 
 class Dataset(Asset):
-    """Dataset asset with schema and lineage tracking.
+    """Dataset asset with automatic schema detection and statistics extraction.
+
+    The Dataset class automatically extracts statistics and metadata from various
+    data formats, reducing boilerplate code and improving UX.
+
+    Supported formats:
+    - pandas DataFrame: Auto-extracts columns, dtypes, statistics
+    - numpy array: Auto-extracts shape, dtype, statistics
+    - dict: Auto-extracts features/target structure, column stats
+    - TensorFlow Dataset: Auto-extracts element_spec, cardinality
+    - List of dicts: Converts to dict format and extracts stats
 
     Example:
-        >>> raw_data = Dataset(
-        ...     name="imagenet_train",
-        ...     version="v2.0",
-        ...     data=train_dataset,
-        ...     properties={"size": "150GB", "samples": 1_281_167},
-        ... )
+        >>> # Minimal usage - stats are extracted automatically!
+        >>> import pandas as pd
+        >>> df = pd.read_csv("data.csv")
+        >>> dataset = Dataset.create(data=df, name="my_dataset")
+        >>> print(dataset.num_samples)  # Auto-extracted
+        >>> print(dataset.feature_columns)  # Auto-detected
+
+        >>> # With dict format
+        >>> data = {"features": {"x": [1, 2, 3], "y": [4, 5, 6]}, "target": [0, 1, 0]}
+        >>> dataset = Dataset.create(data=data, name="my_dataset")
+        >>> # All stats computed automatically!
     """
 
     def __init__(
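Reviewer note: a quick, self-contained sanity check of the new DatasetStats heuristics added above. The input values are invented for illustration; the expected numbers follow directly from the code (population variance, results rounded to 6 decimal places):

```python
# Toy input (invented for illustration) exercising the dict branch of extract_stats().
from flowyml.assets.dataset import DatasetStats

data = {"features": {"x": [1, 2, 3], "y": [4, 5, 6]}, "target": [0, 1, 0]}
stats = DatasetStats.extract_stats(data)

assert stats["samples"] == 3
assert stats["feature_columns"] == ["x", "y"]
assert stats["label_column"] == "target"
assert stats["framework"] == "dict"
# compute_numeric_stats uses population variance: std of [1, 2, 3] is sqrt(2/3).
assert stats["column_stats"]["x"]["mean"] == 2.0
assert stats["column_stats"]["x"]["std"] == round((2 / 3) ** 0.5, 6)  # 0.816497
```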
@@ -26,14 +339,40 @@ class Dataset(Asset):
         parent: Asset | None = None,
         tags: dict[str, str] | None = None,
         properties: dict[str, Any] | None = None,
+        auto_extract_stats: bool = True,
     ):
+        """Initialize Dataset with automatic statistics extraction.
+
+        Args:
+            name: Dataset name
+            version: Version string
+            data: The actual data (DataFrame, array, dict, etc.)
+            schema: Optional schema definition
+            location: Storage location/path
+            parent: Parent asset for lineage
+            tags: Metadata tags
+            properties: Additional properties (merged with auto-extracted)
+            auto_extract_stats: Whether to automatically extract statistics
+        """
+        # Initialize properties dict
+        final_properties = properties.copy() if properties else {}
+
+        # Auto-extract statistics if enabled and data is provided
+        if auto_extract_stats and data is not None:
+            extracted = DatasetStats.extract_stats(data)
+            # Merge extracted stats with user-provided properties
+            # User properties take precedence
+            for key, value in extracted.items():
+                if key not in final_properties:
+                    final_properties[key] = value
+
         super().__init__(
             name=name,
             version=version,
             data=data,
             parent=parent,
             tags=tags,
-            properties=properties,
+            properties=final_properties,
         )
 
         self.schema = schema
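The merge above is strictly "user wins": auto-extracted keys are only written when absent from the caller's properties. A minimal sketch with toy values (version is passed explicitly, matching how create() below calls the constructor):

```python
# Sketch: user-supplied properties take precedence over auto-extracted ones.
from flowyml.assets.dataset import Dataset

data = {"features": {"x": [1, 2, 3]}, "target": [0, 1, 0]}

ds = Dataset(name="toy", version="v1", data=data, properties={"samples": 999})
assert ds.metadata.properties["samples"] == 999      # user value kept
assert ds.metadata.properties["num_features"] == 1   # filled in by the extractor

# Opting out of extraction leaves properties untouched.
raw = Dataset(name="toy_raw", version="v1", data=data, auto_extract_stats=False)
assert "num_features" not in raw.metadata.properties
```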
@@ -45,6 +384,129 @@ class Dataset(Asset):
         if location:
             self.metadata.properties["location"] = location
 
+    @classmethod
+    def create(
+        cls,
+        data: Any,
+        name: str,
+        version: str | None = None,
+        schema: Any | None = None,
+        location: str | None = None,
+        parent: Asset | None = None,
+        tags: dict[str, str] | None = None,
+        properties: dict[str, Any] | None = None,
+        auto_extract_stats: bool = True,
+        **kwargs: Any,
+    ) -> "Dataset":
+        """Create a Dataset with automatic statistics extraction.
+
+        This is the preferred way to create Dataset objects. Statistics are
+        automatically extracted from the data, reducing boilerplate code.
+
+        Args:
+            data: The actual data (DataFrame, array, dict, etc.)
+            name: Dataset name
+            version: Version string (optional)
+            schema: Optional schema definition
+            location: Storage location/path
+            parent: Parent asset for lineage
+            tags: Metadata tags
+            properties: Additional properties (merged with auto-extracted)
+            auto_extract_stats: Whether to automatically extract statistics
+            **kwargs: Additional properties to store
+
+        Returns:
+            Dataset instance with auto-extracted statistics
+
+        Example:
+            >>> df = pd.read_csv("data.csv")
+            >>> dataset = Dataset.create(data=df, name="my_data", source="data.csv")
+            >>> # Stats are automatically extracted!
+        """
+        # Merge kwargs into properties
+        final_props = properties.copy() if properties else {}
+        for key, value in kwargs.items():
+            if key not in final_props:
+                final_props[key] = value
+
+        return cls(
+            name=name,
+            version=version,
+            data=data,
+            schema=schema,
+            location=location,
+            parent=parent,
+            tags=tags,
+            properties=final_props,
+            auto_extract_stats=auto_extract_stats,
+        )
+
+    @classmethod
+    def from_csv(
+        cls,
+        path: str,
+        name: str | None = None,
+        **kwargs: Any,
+    ) -> "Dataset":
+        """Load a Dataset from a CSV file with automatic statistics.
+
+        Args:
+            path: Path to CSV file
+            name: Dataset name (defaults to filename)
+            **kwargs: Additional properties
+
+        Returns:
+            Dataset with auto-extracted statistics
+        """
+        try:
+            import pandas as pd
+
+            df = pd.read_csv(path)
+            dataset_name = name or path.split("/")[-1].replace(".csv", "")
+
+            return cls.create(
+                data=df,
+                name=dataset_name,
+                location=path,
+                properties={"source": path, "format": "csv"},
+                **kwargs,
+            )
+        except ImportError:
+            raise ImportError("pandas is required for from_csv(). Install with: pip install pandas")
+
+    @classmethod
+    def from_parquet(
+        cls,
+        path: str,
+        name: str | None = None,
+        **kwargs: Any,
+    ) -> "Dataset":
+        """Load a Dataset from a Parquet file with automatic statistics.
+
+        Args:
+            path: Path to Parquet file
+            name: Dataset name (defaults to filename)
+            **kwargs: Additional properties
+
+        Returns:
+            Dataset with auto-extracted statistics
+        """
+        try:
+            import pandas as pd
+
+            df = pd.read_parquet(path)
+            dataset_name = name or path.split("/")[-1].replace(".parquet", "")
+
+            return cls.create(
+                data=df,
+                name=dataset_name,
+                location=path,
+                properties={"source": path, "format": "parquet"},
+                **kwargs,
+            )
+        except ImportError:
+            raise ImportError("pandas and pyarrow are required for from_parquet()")
+
     @property
     def size(self) -> int | None:
         """Get dataset size if available."""
@@ -52,8 +514,57 @@ class Dataset(Asset):
 
     @property
     def num_samples(self) -> int | None:
-        """Get number of samples if available."""
-        return self.metadata.properties.get("samples") or self.metadata.properties.get("num_samples")
+        """Get number of samples (auto-extracted or user-provided)."""
+        return (
+            self.metadata.properties.get("samples")
+            or self.metadata.properties.get("num_samples")
+            or self.metadata.properties.get("cardinality")
+        )
+
+    @property
+    def num_features(self) -> int | None:
+        """Get number of features (auto-extracted or user-provided)."""
+        return self.metadata.properties.get("num_features")
+
+    @property
+    def feature_columns(self) -> list[str] | None:
+        """Get list of feature column names (auto-extracted or user-provided)."""
+        return self.metadata.properties.get("feature_columns")
+
+    @property
+    def label_column(self) -> str | None:
+        """Get the label/target column name (auto-detected or user-provided)."""
+        return self.metadata.properties.get("label_column")
+
+    @property
+    def columns(self) -> list[str] | None:
+        """Get all column names (auto-extracted or user-provided)."""
+        return self.metadata.properties.get("columns")
+
+    @property
+    def column_stats(self) -> dict[str, dict] | None:
+        """Get per-column statistics (auto-extracted)."""
+        return self.metadata.properties.get("column_stats")
+
+    @property
+    def framework(self) -> str | None:
+        """Get the data framework/format (auto-detected)."""
+        return self.metadata.properties.get("framework")
+
+    def get_column_stat(self, column: str, stat: str) -> Any:
+        """Get a specific statistic for a column.
+
+        Args:
+            column: Column name
+            stat: Statistic name (mean, std, min, max, median, count, unique)
+
+        Returns:
+            The statistic value or None
+        """
+        stats = self.column_stats
+        if stats and column in stats:
+            return stats[column].get(stat)
+        return None
 
     def validate_schema(self) -> bool:
         """Validate data against schema (placeholder)."""
@@ -62,22 +573,53 @@ class Dataset(Asset):
         # Schema validation would go here
         return True
 
-    def split(self, train_ratio: float = 0.8, name_prefix: str | None = None) -> tuple["Dataset", "Dataset"]:
-        """Split dataset into train/test.
+    def split(
+        self,
+        train_ratio: float = 0.8,
+        name_prefix: str | None = None,
+        random_state: int | None = 42,
+    ) -> tuple["Dataset", "Dataset"]:
+        """Split dataset into train/test with auto-extracted statistics.
 
         Args:
             train_ratio: Ratio for training split
             name_prefix: Prefix for split dataset names
+            random_state: Random seed for reproducibility
 
         Returns:
             Tuple of (train_dataset, test_dataset)
         """
         prefix = name_prefix or self.name
 
-        # Placeholder - actual splitting logic would depend on data type
-        _ = train_ratio  # Unused in placeholder
-        train_data = self.data  # Would actually split the data
-        test_data = self.data
+        # Try to split based on data type
+        data_type = DatasetStats.detect_data_type(self.data)
+
+        if data_type == "pandas":
+            try:
+                df = self.data.sample(frac=1, random_state=random_state).reset_index(drop=True)
+                train_size = int(len(df) * train_ratio)
+                train_data = df[:train_size]
+                test_data = df[train_size:]
+            except Exception:
+                train_data = self.data
+                test_data = self.data
+        elif data_type == "dict" and "features" in self.data:
+            # Split dict format
+            features = self.data["features"]
+            target = self.data.get("target", [])
+            first_key = next(iter(features.keys()))
+            n_samples = len(features[first_key])
+            train_size = int(n_samples * train_ratio)
+
+            train_features = {k: v[:train_size] for k, v in features.items()}
+            test_features = {k: v[train_size:] for k, v in features.items()}
+
+            train_data = {"features": train_features, "target": target[:train_size] if target else []}
+            test_data = {"features": test_features, "target": target[train_size:] if target else []}
+        else:
+            # Fallback - no actual splitting
+            train_data = self.data
+            test_data = self.data
 
         train_dataset = Dataset(
             name=f"{prefix}_train",
@@ -98,3 +640,14 @@ class Dataset(Asset):
         )
 
         return train_dataset, test_dataset
+
+    def __repr__(self) -> str:
+        """String representation with key stats."""
+        parts = [f"Dataset(name='{self.name}'"]
+        if self.num_samples:
+            parts.append(f"samples={self.num_samples}")
+        if self.num_features:
+            parts.append(f"features={self.num_features}")
+        if self.framework:
+            parts.append(f"framework='{self.framework}'")
+        return ", ".join(parts) + ")"