flowyml 1.7.1__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. flowyml/assets/dataset.py +570 -17
  2. flowyml/assets/model.py +1052 -15
  3. flowyml/core/executor.py +70 -11
  4. flowyml/core/orchestrator.py +37 -2
  5. flowyml/core/pipeline.py +32 -4
  6. flowyml/core/scheduler.py +88 -5
  7. flowyml/integrations/keras.py +247 -82
  8. flowyml/ui/backend/routers/runs.py +112 -0
  9. flowyml/ui/backend/routers/schedules.py +35 -15
  10. flowyml/ui/frontend/dist/assets/index-B40RsQDq.css +1 -0
  11. flowyml/ui/frontend/dist/assets/index-CjI0zKCn.js +685 -0
  12. flowyml/ui/frontend/dist/index.html +2 -2
  13. flowyml/ui/frontend/package-lock.json +11 -0
  14. flowyml/ui/frontend/package.json +1 -0
  15. flowyml/ui/frontend/src/app/assets/page.jsx +890 -321
  16. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectMetricsPanel.jsx +1 -1
  17. flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +589 -101
  18. flowyml/ui/frontend/src/components/ArtifactViewer.jsx +62 -2
  19. flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +401 -28
  20. flowyml/ui/frontend/src/components/AssetTreeHierarchy.jsx +119 -11
  21. flowyml/ui/frontend/src/components/DatasetViewer.jsx +753 -0
  22. flowyml/ui/frontend/src/components/TrainingHistoryChart.jsx +514 -0
  23. flowyml/ui/frontend/src/components/TrainingMetricsPanel.jsx +175 -0
  24. {flowyml-1.7.1.dist-info → flowyml-1.7.2.dist-info}/METADATA +1 -1
  25. {flowyml-1.7.1.dist-info → flowyml-1.7.2.dist-info}/RECORD +28 -25
  26. flowyml/ui/frontend/dist/assets/index-BqDQvp63.js +0 -630
  27. flowyml/ui/frontend/dist/assets/index-By4trVyv.css +0 -1
  28. {flowyml-1.7.1.dist-info → flowyml-1.7.2.dist-info}/WHEEL +0 -0
  29. {flowyml-1.7.1.dist-info → flowyml-1.7.2.dist-info}/entry_points.txt +0 -0
  30. {flowyml-1.7.1.dist-info → flowyml-1.7.2.dist-info}/licenses/LICENSE +0 -0
flowyml/assets/model.py CHANGED
@@ -1,20 +1,755 @@
1
- """Model Asset - Represents ML models with metadata and lineage."""
1
+ """Model Asset - Represents ML models with automatic metadata extraction."""
2
2
 
3
3
  from typing import Any
4
+ import contextlib
5
+ import logging
6
+
4
7
  from flowyml.assets.base import Asset
5
8
 
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class ModelInspector:
13
+ """Utility class for extracting model metadata from various frameworks.
14
+
15
+ This class provides robust, fail-safe extraction of model metadata.
16
+ It will never raise exceptions - if extraction fails, it returns
17
+ partial results with whatever could be extracted.
18
+
19
+ Supported frameworks with rich extraction:
20
+ - Keras/TensorFlow: Full architecture, optimizer, loss, metrics
21
+ - PyTorch: Architecture, parameters, layer info
22
+ - Scikit-learn: Hyperparameters, feature importance
23
+ - XGBoost/LightGBM/CatBoost: Trees, hyperparameters
24
+
25
+ All other model types are supported with basic metadata (type name, etc.)
26
+ """
27
+
28
+ @staticmethod
29
+ def detect_framework(model: Any) -> str | None:
30
+ """Detect the ML framework of a model.
31
+
32
+ Returns one of: 'keras', 'tensorflow', 'pytorch', 'sklearn', 'xgboost',
33
+ 'lightgbm', 'catboost', 'huggingface', 'onnx', 'custom', or None
34
+
35
+ This method is safe to call on any object and will never raise.
36
+ """
37
+ if model is None:
38
+ return None
39
+
40
+ try:
41
+ type_name = type(model).__name__
42
+ module_name = type(model).__module__
43
+
44
+ # Keras/TensorFlow - check multiple indicators
45
+ if any(x in module_name.lower() for x in ["keras", "tf.keras"]):
46
+ return "keras"
47
+ if type_name in ("Sequential", "Functional", "Model") and "tensorflow" in module_name.lower():
48
+ return "keras"
49
+ if "tensorflow" in module_name.lower():
50
+ return "tensorflow"
51
+
52
+ # PyTorch - check for nn.Module inheritance (handles user-defined models)
53
+ # Check if any base class is from torch
54
+ try:
55
+ for base in type(model).__mro__:
56
+ base_module = base.__module__
57
+ if "torch" in base_module.lower():
58
+ return "pytorch"
59
+ except Exception:
60
+ pass
61
+
62
+ # Also check module name directly (for torch tensors, etc.)
63
+ if "torch" in module_name.lower():
64
+ return "pytorch"
65
+
66
+ # Check for PyTorch-specific attributes
67
+ if hasattr(model, "forward") and hasattr(model, "parameters") and hasattr(model, "state_dict"):
68
+ # Likely a PyTorch model
69
+ return "pytorch"
70
+
71
+ # Scikit-learn - check for common base classes
72
+ if "sklearn" in module_name.lower():
73
+ return "sklearn"
74
+
75
+ # XGBoost
76
+ if "xgboost" in module_name.lower() or type_name.startswith("XGB"):
77
+ return "xgboost"
78
+
79
+ # LightGBM
80
+ if "lightgbm" in module_name.lower() or type_name.startswith("LGB"):
81
+ return "lightgbm"
82
+
83
+ # CatBoost
84
+ if "catboost" in module_name.lower():
85
+ return "catboost"
86
+
87
+ # Hugging Face Transformers
88
+ if "transformers" in module_name.lower():
89
+ return "huggingface"
90
+
91
+ # ONNX
92
+ if "onnx" in module_name.lower():
93
+ return "onnx"
94
+
95
+ # JAX/Flax
96
+ if "flax" in module_name.lower() or "jax" in module_name.lower():
97
+ return "jax"
98
+
99
+ # Check for common ML model attributes
100
+ if hasattr(model, "predict") and hasattr(model, "fit"):
101
+ return "sklearn" # Sklearn-like API
102
+
103
+ return "custom"
104
+
105
+ except Exception as e:
106
+ logger.debug(f"Error detecting framework: {e}")
107
+ return "unknown"
108
+
109
+ @staticmethod
110
+ def extract_keras_info(model: Any) -> dict[str, Any]:
111
+ """Extract metadata from a Keras/TensorFlow model.
112
+
113
+ This method is robust and will extract as much as possible,
114
+ even if some attributes are not available (e.g., uncompiled model).
115
+ """
116
+ result = {"framework": "keras", "_auto_extracted": True}
117
+
118
+ # Parameter count - handle both built and unbuilt models
119
+ try:
120
+ if hasattr(model, "count_params"):
121
+ with contextlib.suppress(Exception):
122
+ result["parameters"] = model.count_params()
123
+ except Exception as e:
124
+ logger.debug(f"Error getting param count: {e}")
125
+
126
+ # Trainable parameters
127
+ try:
128
+ if hasattr(model, "trainable_weights") and model.trainable_weights:
129
+ trainable = sum(int(w.numpy().size) if hasattr(w, "numpy") else 0 for w in model.trainable_weights)
130
+ if trainable > 0:
131
+ result["trainable_parameters"] = trainable
132
+ except Exception as e:
133
+ logger.debug(f"Error getting trainable params: {e}")
134
+
135
+ # Architecture name
136
+ try:
137
+ if hasattr(model, "name") and model.name:
138
+ result["architecture"] = model.name
139
+ result["model_class"] = type(model).__name__
140
+ except Exception as e:
141
+ logger.debug(f"Error getting architecture: {e}")
142
+
143
+ # Layer info - handle models without layers attribute
144
+ try:
145
+ if hasattr(model, "layers"):
146
+ layers = model.layers
147
+ if layers:
148
+ result["num_layers"] = len(layers)
149
+ result["layer_types"] = list({type(layer).__name__ for layer in layers})
150
+ except Exception as e:
151
+ logger.debug(f"Error getting layer info: {e}")
152
+
153
+ # Input/Output shapes - handle unbuilt models
154
+ with contextlib.suppress(Exception):
155
+ if hasattr(model, "input_shape"):
156
+ input_shape = model.input_shape
157
+ if input_shape is not None:
158
+ result["input_shape"] = str(input_shape)
159
+
160
+ with contextlib.suppress(Exception):
161
+ if hasattr(model, "output_shape"):
162
+ output_shape = model.output_shape
163
+ if output_shape is not None:
164
+ result["output_shape"] = str(output_shape)
165
+
166
+ # Optimizer info (only for compiled models)
167
+ try:
168
+ if hasattr(model, "optimizer") and model.optimizer is not None:
169
+ opt = model.optimizer
170
+ result["optimizer"] = type(opt).__name__
171
+
172
+ # Get learning rate - handle different Keras versions
173
+ if hasattr(opt, "learning_rate"):
174
+ lr = opt.learning_rate
175
+ if hasattr(lr, "numpy"):
176
+ lr = float(lr.numpy())
177
+ elif callable(lr):
178
+ # Learning rate schedule
179
+ result["lr_schedule"] = type(lr).__name__
180
+ else:
181
+ lr = float(lr)
182
+ if isinstance(lr, (int, float)):
183
+ result["learning_rate"] = lr
184
+ elif hasattr(opt, "lr"):
185
+ # Older Keras versions
186
+ result["learning_rate"] = float(opt.lr)
187
+ except Exception as e:
188
+ logger.debug(f"Error getting optimizer info: {e}")
189
+
190
+ # Loss function (only for compiled models)
191
+ try:
192
+ if hasattr(model, "loss") and model.loss is not None:
193
+ loss = model.loss
194
+ if isinstance(loss, str):
195
+ result["loss_function"] = loss
196
+ elif hasattr(loss, "__name__"):
197
+ result["loss_function"] = loss.__name__
198
+ elif hasattr(loss, "name"):
199
+ result["loss_function"] = loss.name
200
+ elif hasattr(loss, "__class__"):
201
+ result["loss_function"] = type(loss).__name__
202
+ except Exception as e:
203
+ logger.debug(f"Error getting loss function: {e}")
204
+
205
+ # Metrics (only for compiled models)
206
+ try:
207
+ if hasattr(model, "metrics_names") and model.metrics_names:
208
+ result["metrics"] = list(model.metrics_names)
209
+ elif hasattr(model, "metrics") and model.metrics:
210
+ result["metrics"] = [m.name if hasattr(m, "name") else str(m) for m in model.metrics]
211
+ except Exception as e:
212
+ logger.debug(f"Error getting metrics: {e}")
213
+
214
+ # Check if model is compiled
215
+ with contextlib.suppress(Exception):
216
+ result["is_compiled"] = hasattr(model, "optimizer") and model.optimizer is not None
217
+
218
+ # Check if model is built
219
+ with contextlib.suppress(Exception):
220
+ if hasattr(model, "built"):
221
+ result["is_built"] = model.built
222
+
223
+ return result
224
+
225
+ @staticmethod
226
+ def extract_pytorch_info(model: Any) -> dict[str, Any]:
227
+ """Extract metadata from a PyTorch model.
228
+
229
+ This method is robust and handles:
230
+ - nn.Module models
231
+ - Custom modules
232
+ - Pretrained models from torchvision, transformers, etc.
233
+ - Models in eval or train mode
234
+ """
235
+ result = {"framework": "pytorch", "_auto_extracted": True}
236
+
237
+ # Model class name
238
+ try:
239
+ result["model_class"] = type(model).__name__
240
+ result["architecture"] = type(model).__name__
241
+ except Exception:
242
+ pass
243
+
244
+ # Parameter count
245
+ try:
246
+ if hasattr(model, "parameters"):
247
+ params = list(model.parameters())
248
+ if params:
249
+ total_params = sum(p.numel() for p in params)
250
+ trainable_params = sum(p.numel() for p in params if p.requires_grad)
251
+ result["parameters"] = total_params
252
+ result["trainable_parameters"] = trainable_params
253
+ result["frozen_parameters"] = total_params - trainable_params
254
+ except Exception as e:
255
+ logger.debug(f"Error getting PyTorch param count: {e}")
256
+
257
+ # Layer info from modules
258
+ try:
259
+ if hasattr(model, "modules"):
260
+ modules = list(model.modules())
261
+ if modules:
262
+ # Skip the first module (the model itself)
263
+ layer_modules = modules[1:] if len(modules) > 1 else modules
264
+ result["num_layers"] = len(layer_modules)
265
+
266
+ # Get unique layer types
267
+ layer_types = set()
268
+ for m in layer_modules:
269
+ layer_type = type(m).__name__
270
+ # Skip container modules
271
+ if layer_type not in ("Sequential", "ModuleList", "ModuleDict"):
272
+ layer_types.add(layer_type)
273
+ result["layer_types"] = list(layer_types)
274
+ except Exception as e:
275
+ logger.debug(f"Error getting PyTorch layer info: {e}")
276
+
277
+ # Named modules for architecture insights
278
+ try:
279
+ if hasattr(model, "named_modules"):
280
+ named = dict(model.named_modules())
281
+ if named:
282
+ result["num_named_modules"] = len(named)
283
+ except Exception:
284
+ pass
285
+
286
+ # State dict info
287
+ try:
288
+ if hasattr(model, "state_dict"):
289
+ state_dict = model.state_dict()
290
+ if state_dict:
291
+ result["num_tensors"] = len(state_dict)
292
+ # Get tensor shapes for key layers
293
+ tensor_shapes = {}
294
+ for name, tensor in list(state_dict.items())[:10]: # First 10 only
295
+ tensor_shapes[name] = list(tensor.shape)
296
+ result["tensor_shapes_sample"] = tensor_shapes
297
+ except Exception as e:
298
+ logger.debug(f"Error getting PyTorch state dict: {e}")
299
+
300
+ # Training mode
301
+ try:
302
+ if hasattr(model, "training"):
303
+ result["training_mode"] = model.training
304
+ except Exception:
305
+ pass
306
+
307
+ # Device info
308
+ try:
309
+ if hasattr(model, "parameters"):
310
+ first_param = next(model.parameters(), None)
311
+ if first_param is not None:
312
+ result["device"] = str(first_param.device)
313
+ result["dtype"] = str(first_param.dtype)
314
+ except Exception:
315
+ pass
316
+
317
+ # Check for common PyTorch model attributes
318
+ try:
319
+ # Input features (common in many models)
320
+ for attr in ["in_features", "in_channels", "input_size", "num_features"]:
321
+ if hasattr(model, attr):
322
+ val = getattr(model, attr, None)
323
+ if val is not None:
324
+ result[attr] = val
325
+ break
326
+
327
+ # Output features
328
+ for attr in ["out_features", "out_channels", "output_size", "num_classes"]:
329
+ if hasattr(model, attr):
330
+ val = getattr(model, attr, None)
331
+ if val is not None:
332
+ result[attr] = val
333
+ break
334
+ except Exception:
335
+ pass
336
+
337
+ return result
338
+
339
+ @staticmethod
340
+ def extract_sklearn_info(model: Any) -> dict[str, Any]:
341
+ """Extract metadata from a scikit-learn model.
342
+
343
+ Handles all sklearn estimators including:
344
+ - Classifiers, regressors, transformers
345
+ - Ensemble methods (RandomForest, GradientBoosting, etc.)
346
+ - Linear models
347
+ - Pipelines
348
+ """
349
+ result = {"framework": "sklearn", "_auto_extracted": True}
350
+
351
+ # Model type
352
+ try:
353
+ result["model_class"] = type(model).__name__
354
+ result["architecture"] = type(model).__name__
355
+ except Exception:
356
+ pass
357
+
358
+ # Get parameters (safe extraction)
359
+ try:
360
+ if hasattr(model, "get_params"):
361
+ params = model.get_params(deep=False) # Shallow to avoid recursion
362
+ # Filter to serializable values
363
+ filtered_params = {}
364
+ for k, v in params.items():
365
+ if v is None:
366
+ continue
367
+ if isinstance(v, (str, int, float, bool)):
368
+ filtered_params[k] = v
369
+ elif isinstance(v, (list, tuple)) and len(v) < 10:
370
+ # Small lists/tuples of primitives
371
+ if all(isinstance(x, (str, int, float, bool, type(None))) for x in v):
372
+ filtered_params[k] = list(v)
373
+ if filtered_params:
374
+ result["hyperparameters"] = filtered_params
375
+ except Exception as e:
376
+ logger.debug(f"Error getting sklearn params: {e}")
377
+
378
+ # Feature importance (tree-based models)
379
+ with contextlib.suppress(Exception):
380
+ if hasattr(model, "feature_importances_"):
381
+ importances = model.feature_importances_
382
+ result["has_feature_importances"] = True
383
+ result["num_features"] = len(importances)
384
+ # Store top 10 feature importances
385
+ if len(importances) <= 20:
386
+ result["feature_importances"] = list(importances)
387
+
388
+ # Coefficients (linear models)
389
+ with contextlib.suppress(Exception):
390
+ if hasattr(model, "coef_"):
391
+ coef = model.coef_
392
+ if hasattr(coef, "shape"):
393
+ result["coef_shape"] = str(coef.shape)
394
+ if hasattr(coef, "size") and coef.size <= 100:
395
+ result["num_coefficients"] = int(coef.size)
396
+
397
+ # Intercept
398
+ with contextlib.suppress(Exception):
399
+ if hasattr(model, "intercept_"):
400
+ intercept = model.intercept_
401
+ if hasattr(intercept, "tolist"):
402
+ intercept = intercept.tolist()
403
+ if isinstance(intercept, (int, float)) or (isinstance(intercept, list) and len(intercept) <= 10):
404
+ result["intercept"] = intercept
405
+
406
+ # Classes (classifiers)
407
+ try:
408
+ if hasattr(model, "classes_"):
409
+ classes = model.classes_
410
+ result["num_classes"] = len(classes)
411
+ if len(classes) <= 20:
412
+ # Convert to list if numpy array
413
+ if hasattr(classes, "tolist"):
414
+ classes = classes.tolist()
415
+ result["classes"] = list(classes)
416
+ except Exception:
417
+ pass
418
+
419
+ # Number of estimators (ensemble models)
420
+ try:
421
+ if hasattr(model, "n_estimators"):
422
+ result["n_estimators"] = model.n_estimators
423
+ if hasattr(model, "estimators_"):
424
+ result["num_estimators_fitted"] = len(model.estimators_)
425
+ except Exception:
426
+ pass
427
+
428
+ # Tree-specific attributes
429
+ try:
430
+ if hasattr(model, "max_depth"):
431
+ result["max_depth"] = model.max_depth
432
+ if hasattr(model, "n_features_in_"):
433
+ result["n_features_in"] = model.n_features_in_
434
+ except Exception:
435
+ pass
436
+
437
+ # Check if fitted
438
+ try:
439
+ # Common sklearn pattern for checking if fitted
440
+ from sklearn.utils.validation import check_is_fitted
441
+
442
+ check_is_fitted(model)
443
+ result["is_fitted"] = True
444
+ except Exception:
445
+ result["is_fitted"] = False
446
+
447
+ return result
448
+
449
+ @staticmethod
450
+ def extract_xgboost_info(model: Any) -> dict[str, Any]:
451
+ """Extract metadata from an XGBoost model."""
452
+ result = {"framework": "xgboost", "_auto_extracted": True}
453
+
454
+ try:
455
+ result["model_class"] = type(model).__name__
456
+ result["architecture"] = "XGBoost"
457
+ except Exception:
458
+ pass
459
+
460
+ # Get hyperparameters
461
+ try:
462
+ if hasattr(model, "get_params"):
463
+ params = model.get_params()
464
+ result["hyperparameters"] = {
465
+ k: v for k, v in params.items() if v is not None and isinstance(v, (str, int, float, bool))
466
+ }
467
+ except Exception as e:
468
+ logger.debug(f"Error getting XGBoost params: {e}")
469
+
470
+ # Booster info
471
+ try:
472
+ if hasattr(model, "get_booster"):
473
+ booster = model.get_booster()
474
+ if hasattr(booster, "num_trees"):
475
+ result["num_trees"] = booster.num_trees()
476
+ if hasattr(booster, "num_features"):
477
+ result["num_features"] = booster.num_features()
478
+ elif hasattr(model, "n_estimators"):
479
+ result["n_estimators"] = model.n_estimators
480
+ except Exception:
481
+ pass
482
+
483
+ # Feature importance
484
+ try:
485
+ if hasattr(model, "feature_importances_"):
486
+ result["num_features"] = len(model.feature_importances_)
487
+ result["has_feature_importances"] = True
488
+ except Exception:
489
+ pass
490
+
491
+ # Best iteration (for early stopping)
492
+ try:
493
+ if hasattr(model, "best_iteration"):
494
+ result["best_iteration"] = model.best_iteration
495
+ if hasattr(model, "best_score"):
496
+ result["best_score"] = model.best_score
497
+ except Exception:
498
+ pass
499
+
500
+ return result
501
+
502
+ @staticmethod
503
+ def extract_lightgbm_info(model: Any) -> dict[str, Any]:
504
+ """Extract metadata from a LightGBM model."""
505
+ result = {"framework": "lightgbm", "_auto_extracted": True}
506
+
507
+ try:
508
+ result["model_class"] = type(model).__name__
509
+ result["architecture"] = "LightGBM"
510
+ except Exception:
511
+ pass
512
+
513
+ # Get hyperparameters
514
+ try:
515
+ if hasattr(model, "get_params"):
516
+ params = model.get_params()
517
+ result["hyperparameters"] = {
518
+ k: v for k, v in params.items() if v is not None and isinstance(v, (str, int, float, bool))
519
+ }
520
+ except Exception:
521
+ pass
522
+
523
+ # Booster info
524
+ try:
525
+ if hasattr(model, "booster_"):
526
+ booster = model.booster_
527
+ if hasattr(booster, "num_trees"):
528
+ result["num_trees"] = booster.num_trees()
529
+ elif hasattr(model, "n_estimators"):
530
+ result["n_estimators"] = model.n_estimators
531
+ except Exception:
532
+ pass
533
+
534
+ # Feature importance
535
+ try:
536
+ if hasattr(model, "feature_importances_"):
537
+ result["num_features"] = len(model.feature_importances_)
538
+ result["has_feature_importances"] = True
539
+ except Exception:
540
+ pass
541
+
542
+ # Best iteration
543
+ try:
544
+ if hasattr(model, "best_iteration_"):
545
+ result["best_iteration"] = model.best_iteration_
546
+ if hasattr(model, "best_score_"):
547
+ result["best_score"] = model.best_score_
548
+ except Exception:
549
+ pass
550
+
551
+ return result
552
+
553
+ @staticmethod
554
+ def extract_huggingface_info(model: Any) -> dict[str, Any]:
555
+ """Extract metadata from a Hugging Face Transformers model."""
556
+ result = {"framework": "huggingface", "_auto_extracted": True}
557
+
558
+ try:
559
+ result["model_class"] = type(model).__name__
560
+ result["architecture"] = type(model).__name__
561
+ except Exception:
562
+ pass
563
+
564
+ # Config info
565
+ try:
566
+ if hasattr(model, "config"):
567
+ config = model.config
568
+ result["model_type"] = getattr(config, "model_type", None)
569
+ result["hidden_size"] = getattr(config, "hidden_size", None)
570
+ result["num_attention_heads"] = getattr(config, "num_attention_heads", None)
571
+ result["num_hidden_layers"] = getattr(config, "num_hidden_layers", None)
572
+ result["vocab_size"] = getattr(config, "vocab_size", None)
573
+ # Clean up None values
574
+ result = {k: v for k, v in result.items() if v is not None}
575
+ except Exception:
576
+ pass
577
+
578
+ # Parameter count
579
+ try:
580
+ if hasattr(model, "num_parameters"):
581
+ result["parameters"] = model.num_parameters()
582
+ elif hasattr(model, "parameters"):
583
+ result["parameters"] = sum(p.numel() for p in model.parameters())
584
+ except Exception:
585
+ pass
586
+
587
+ # Device
588
+ try:
589
+ if hasattr(model, "device"):
590
+ result["device"] = str(model.device)
591
+ except Exception:
592
+ pass
593
+
594
+ return result
595
+
596
+ @staticmethod
597
+ def extract_generic_info(model: Any) -> dict[str, Any]:
598
+ """Extract basic metadata from any model type.
599
+
600
+ This is the fallback for unknown/custom models.
601
+ """
602
+ result = {"framework": "custom", "_auto_extracted": True}
603
+
604
+ try:
605
+ result["model_class"] = type(model).__name__
606
+ result["module"] = type(model).__module__
607
+ except Exception:
608
+ pass
609
+
610
+ # Check for common model attributes
611
+ try:
612
+ # Fit/predict API (sklearn-like)
613
+ result["has_fit"] = hasattr(model, "fit")
614
+ result["has_predict"] = hasattr(model, "predict")
615
+ result["has_transform"] = hasattr(model, "transform")
616
+
617
+ # Parameters
618
+ if hasattr(model, "get_params"):
619
+ result["has_get_params"] = True
620
+
621
+ # State dict (PyTorch-like)
622
+ if hasattr(model, "state_dict"):
623
+ result["has_state_dict"] = True
624
+ except Exception:
625
+ pass
626
+
627
+ return result
628
+
629
+ @staticmethod
630
+ def extract_info(model: Any) -> dict[str, Any]:
631
+ """Auto-detect framework and extract model metadata.
632
+
633
+ This method is the main entry point for model metadata extraction.
634
+ It is designed to NEVER fail - if extraction fails for any reason,
635
+ it returns a minimal result with whatever could be extracted.
636
+
637
+ Args:
638
+ model: Any model object from any framework
639
+
640
+ Returns:
641
+ Dict with extracted metadata. Always includes:
642
+ - framework: Detected framework name
643
+ - _auto_extracted: Whether extraction succeeded
644
+ - model_class: Class name of the model (if available)
645
+ """
646
+ # Handle None model
647
+ if model is None:
648
+ return {"framework": None, "_auto_extracted": False, "error": "Model is None"}
649
+
650
+ # Detect framework
651
+ framework = ModelInspector.detect_framework(model)
652
+
653
+ # Extract based on framework
654
+ try:
655
+ if framework == "keras":
656
+ return ModelInspector.extract_keras_info(model)
657
+ elif framework == "tensorflow":
658
+ # TensorFlow models that aren't Keras
659
+ result = ModelInspector.extract_keras_info(model)
660
+ result["framework"] = "tensorflow"
661
+ return result
662
+ elif framework == "pytorch":
663
+ return ModelInspector.extract_pytorch_info(model)
664
+ elif framework == "sklearn":
665
+ return ModelInspector.extract_sklearn_info(model)
666
+ elif framework == "xgboost":
667
+ return ModelInspector.extract_xgboost_info(model)
668
+ elif framework == "lightgbm":
669
+ return ModelInspector.extract_lightgbm_info(model)
670
+ elif framework == "catboost":
671
+ # CatBoost is similar to XGBoost
672
+ result = ModelInspector.extract_xgboost_info(model)
673
+ result["framework"] = "catboost"
674
+ return result
675
+ elif framework == "huggingface":
676
+ return ModelInspector.extract_huggingface_info(model)
677
+ elif framework in ("jax", "onnx"):
678
+ # Basic extraction for JAX/ONNX
679
+ result = ModelInspector.extract_generic_info(model)
680
+ result["framework"] = framework
681
+ return result
682
+ else:
683
+ # Unknown or custom framework - use generic extraction
684
+ return ModelInspector.extract_generic_info(model)
685
+
686
+ except Exception as e:
687
+ # If all else fails, return minimal info
688
+ logger.debug(f"Error extracting model info: {e}")
689
+ return {
690
+ "framework": framework or "unknown",
691
+ "model_class": type(model).__name__,
692
+ "_auto_extracted": False,
693
+ "_extraction_error": str(e),
694
+ }
695
+
696
+ @staticmethod
697
+ def extract_training_history_from_callback(callback: Any) -> dict[str, list] | None:
698
+ """Extract training history from a FlowyML callback or Keras History object."""
699
+ try:
700
+ # FlowyML callback
701
+ if hasattr(callback, "get_training_history"):
702
+ return callback.get_training_history()
703
+
704
+ # Keras History object
705
+ if hasattr(callback, "history"):
706
+ history = callback.history
707
+ if isinstance(history, dict):
708
+ # Add epochs if not present
709
+ if "epochs" not in history and history:
710
+ first_key = next(iter(history.keys()))
711
+ history["epochs"] = list(range(1, len(history[first_key]) + 1))
712
+ return history
713
+
714
+ # Training history dict directly
715
+ if isinstance(callback, dict) and any(isinstance(v, list) for v in callback.values()):
716
+ return callback
717
+
718
+ except Exception as e:
719
+ logger.debug(f"Error extracting training history: {e}")
720
+
721
+ return None
722
+
6
723
 
7
724
  class Model(Asset):
8
- """Model asset with training metadata and lineage.
725
+ """Model asset with automatic metadata extraction and training history.
726
+
727
+ The Model class automatically extracts metadata from various ML frameworks,
728
+ reducing boilerplate code and improving UX. It also captures training history
729
+ for visualization in the FlowyML dashboard.
730
+
731
+ Supported frameworks:
732
+ - Keras/TensorFlow: Auto-extracts layers, parameters, optimizer, loss
733
+ - PyTorch: Auto-extracts modules, parameters, training mode
734
+ - Scikit-learn: Auto-extracts hyperparameters, feature importance
735
+ - XGBoost/LightGBM: Auto-extracts trees, hyperparameters
9
736
 
10
737
  Example:
11
- >>> model = Model(
12
- ... name="resnet50_v1",
13
- ... version="v1.0.0",
14
- ... data=trained_model,
15
- ... architecture="resnet50",
16
- ... framework="pytorch",
17
- ... properties={"params": 25_557_032},
738
+ >>> # Minimal usage - properties auto-extracted!
739
+ >>> model_asset = Model.create(
740
+ ... data=trained_keras_model,
741
+ ... name="my_model",
742
+ ... )
743
+ >>> print(model_asset.parameters) # Auto-extracted
744
+ >>> print(model_asset.framework) # Auto-detected
745
+
746
+ >>> # With FlowyML callback - training history auto-captured
747
+ >>> callback = FlowymlKerasCallback(experiment_name="demo")
748
+ >>> model.fit(X, y, callbacks=[callback])
749
+ >>> model_asset = Model.create(
750
+ ... data=model,
751
+ ... name="trained_model",
752
+ ... flowyml_callback=callback, # Auto-extracts training history!
18
753
  ... )
19
754
  """
20
755
 
@@ -31,35 +766,302 @@ class Model(Asset):
31
766
  parent: Asset | None = None,
32
767
  tags: dict[str, str] | None = None,
33
768
  properties: dict[str, Any] | None = None,
769
+ training_history: dict[str, list] | None = None,
770
+ auto_extract: bool = True,
34
771
  ):
772
+ """Initialize Model with automatic metadata extraction.
773
+
774
+ Args:
775
+ name: Model name
776
+ version: Version string
777
+ data: The model object (Keras, PyTorch, sklearn, etc.)
778
+ architecture: Architecture name (auto-detected if not provided)
779
+ framework: Framework name (auto-detected if not provided)
780
+ input_shape: Input shape (auto-detected for Keras)
781
+ output_shape: Output shape (auto-detected for Keras)
782
+ trained_on: Dataset this model was trained on
783
+ parent: Parent asset for lineage
784
+ tags: Metadata tags
785
+ properties: Additional properties (merged with auto-extracted)
786
+ training_history: Mapping of metric name to list of per-epoch metric values
787
+ auto_extract: Whether to auto-extract model metadata
788
+ """
789
+ # Initialize properties
790
+ final_properties = properties.copy() if properties else {}
791
+
792
+ # Auto-extract model metadata if enabled
793
+ if auto_extract and data is not None:
794
+ extracted = ModelInspector.extract_info(data)
795
+ # Merge - user-provided values take precedence
796
+ for key, value in extracted.items():
797
+ if key not in final_properties:
798
+ final_properties[key] = value
799
+
800
+ # Set framework from extracted if not provided
801
+ if framework is None and "framework" in extracted:
802
+ framework = extracted["framework"]
803
+
804
+ # Set architecture from extracted if not provided
805
+ if architecture is None and "architecture" in extracted:
806
+ architecture = extracted["architecture"]
807
+
35
808
  super().__init__(
36
809
  name=name,
37
810
  version=version,
38
811
  data=data,
39
812
  parent=parent,
40
813
  tags=tags,
41
- properties=properties,
814
+ properties=final_properties,
42
815
  )
43
816
 
44
817
  self.architecture = architecture
45
818
  self.framework = framework
46
819
  self.input_shape = input_shape
47
820
  self.output_shape = output_shape
821
+ self.training_history = training_history
48
822
 
49
823
  # Track training dataset
50
824
  if trained_on:
51
825
  self.parents.append(trained_on)
52
826
  trained_on.children.append(self)
53
827
 
54
- # Add model-specific properties
828
+ # Add model-specific properties (explicit ones override extracted)
55
829
  if architecture:
56
830
  self.metadata.properties["architecture"] = architecture
57
831
  if framework:
58
832
  self.metadata.properties["framework"] = framework
59
833
  if input_shape:
60
- self.metadata.properties["input_shape"] = input_shape
834
+ self.metadata.properties["input_shape"] = str(input_shape)
61
835
  if output_shape:
62
- self.metadata.properties["output_shape"] = output_shape
836
+ self.metadata.properties["output_shape"] = str(output_shape)
837
+
838
+ @classmethod
839
+ def create(
840
+ cls,
841
+ data: Any,
842
+ name: str | None = None,
843
+ version: str | None = None,
844
+ parent: "Asset | None" = None,
845
+ flowyml_callback: Any = None,
846
+ keras_history: Any = None,
847
+ auto_extract: bool = True,
848
+ **kwargs: Any,
849
+ ) -> "Model":
850
+ """Create a Model asset with automatic metadata extraction.
851
+
852
+ This is the preferred way to create Model objects. Metadata is
853
+ automatically extracted from the model, and training history can
854
+ be captured from FlowyML callbacks.
855
+
856
+ Args:
857
+ data: The model object (Keras, PyTorch, sklearn, etc.)
858
+ name: Asset name (auto-generated if not provided)
859
+ version: Asset version
860
+ parent: Parent asset for lineage
861
+ flowyml_callback: FlowymlKerasCallback for auto-capturing training history
862
+ keras_history: Keras History object from model.fit()
863
+ auto_extract: Whether to auto-extract model metadata
864
+ **kwargs: Additional parameters including:
865
+ - training_history: Dict of training metrics per epoch
866
+ - architecture: Model architecture name
867
+ - framework: ML framework (keras, pytorch, etc.)
868
+ - properties: Additional properties
869
+ - tags: Metadata tags
870
+
871
+ Returns:
872
+ New Model instance with auto-extracted metadata
873
+
874
+ Example:
875
+ >>> # Simple usage - everything auto-extracted
876
+ >>> model_asset = Model.create(data=model, name="my_model")
877
+
878
+ >>> # With FlowyML callback
879
+ >>> callback = FlowymlKerasCallback(experiment_name="demo")
880
+ >>> model.fit(X, y, callbacks=[callback])
881
+ >>> model_asset = Model.create(
882
+ ... data=model,
883
+ ... name="trained_model",
884
+ ... flowyml_callback=callback,
885
+ ... )
886
+
887
+ >>> # With Keras History
888
+ >>> history = model.fit(X, y)
889
+ >>> model_asset = Model.create(
890
+ ... data=model,
891
+ ... name="trained_model",
892
+ ... keras_history=history,
893
+ ... )
894
+ """
895
+ from datetime import datetime
896
+
897
+ asset_name = name or f"Model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
898
+
899
+ # Extract Model-specific parameters
900
+ training_history = kwargs.pop("training_history", None)
901
+ architecture = kwargs.pop("architecture", None)
902
+ framework = kwargs.pop("framework", None)
903
+ input_shape = kwargs.pop("input_shape", None)
904
+ output_shape = kwargs.pop("output_shape", None)
905
+ trained_on = kwargs.pop("trained_on", None)
906
+
907
+ # Auto-extract training history from callback or history object
908
+ if training_history is None:
909
+ if flowyml_callback is not None:
910
+ training_history = ModelInspector.extract_training_history_from_callback(
911
+ flowyml_callback,
912
+ )
913
+ elif keras_history is not None:
914
+ training_history = ModelInspector.extract_training_history_from_callback(
915
+ keras_history,
916
+ )
917
+
918
+ # Extract tags and properties
919
+ tags = kwargs.pop("tags", {})
920
+ props = kwargs.pop("properties", {})
921
+ # Merge remaining kwargs into properties
922
+ props.update(kwargs)
923
+
924
+ return cls(
925
+ name=asset_name,
926
+ version=version,
927
+ data=data,
928
+ architecture=architecture,
929
+ framework=framework,
930
+ input_shape=input_shape,
931
+ output_shape=output_shape,
932
+ trained_on=trained_on,
933
+ parent=parent,
934
+ tags=tags,
935
+ properties=props,
936
+ training_history=training_history,
937
+ auto_extract=auto_extract,
938
+ )
939
+
940
+ @classmethod
941
+ def from_keras(
942
+ cls,
943
+ model: Any,
944
+ name: str | None = None,
945
+ callback: Any = None,
946
+ history: Any = None,
947
+ **kwargs: Any,
948
+ ) -> "Model":
949
+ """Create a Model asset from a Keras model with full auto-extraction.
950
+
951
+ Args:
952
+ model: Keras model object
953
+ name: Asset name
954
+ callback: FlowymlKerasCallback for training history
955
+ history: Keras History object from model.fit()
956
+ **kwargs: Additional properties
957
+
958
+ Returns:
959
+ Model asset with auto-extracted Keras metadata
960
+ """
961
+ return cls.create(
962
+ data=model,
963
+ name=name,
964
+ framework="keras",
965
+ flowyml_callback=callback,
966
+ keras_history=history,
967
+ **kwargs,
968
+ )
969
+
970
+ @classmethod
971
+ def from_pytorch(
972
+ cls,
973
+ model: Any,
974
+ name: str | None = None,
975
+ training_history: dict | None = None,
976
+ **kwargs: Any,
977
+ ) -> "Model":
978
+ """Create a Model asset from a PyTorch model with full auto-extraction.
979
+
980
+ Args:
981
+ model: PyTorch model object (nn.Module)
982
+ name: Asset name
983
+ training_history: Training metrics dict
984
+ **kwargs: Additional properties
985
+
986
+ Returns:
987
+ Model asset with auto-extracted PyTorch metadata
988
+ """
989
+ return cls.create(
990
+ data=model,
991
+ name=name,
992
+ framework="pytorch",
993
+ training_history=training_history,
994
+ **kwargs,
995
+ )
996
+
997
+ @classmethod
998
+ def from_sklearn(
999
+ cls,
1000
+ model: Any,
1001
+ name: str | None = None,
1002
+ **kwargs: Any,
1003
+ ) -> "Model":
1004
+ """Create a Model asset from a scikit-learn model with full auto-extraction.
1005
+
1006
+ Args:
1007
+ model: Scikit-learn model object
1008
+ name: Asset name
1009
+ **kwargs: Additional properties
1010
+
1011
+ Returns:
1012
+ Model asset with auto-extracted sklearn metadata
1013
+ """
1014
+ return cls.create(
1015
+ data=model,
1016
+ name=name,
1017
+ framework="sklearn",
1018
+ **kwargs,
1019
+ )
1020
+
1021
+ @property
1022
+ def parameters(self) -> int | None:
1023
+ """Get number of model parameters (auto-extracted)."""
1024
+ return self.metadata.properties.get("parameters") or self.metadata.properties.get("params")
1025
+
1026
+ @property
1027
+ def trainable_parameters(self) -> int | None:
1028
+ """Get number of trainable parameters (auto-extracted)."""
1029
+ return self.metadata.properties.get("trainable_parameters")
1030
+
1031
+ @property
1032
+ def num_layers(self) -> int | None:
1033
+ """Get number of layers (auto-extracted)."""
1034
+ return self.metadata.properties.get("num_layers")
1035
+
1036
+ @property
1037
+ def layer_types(self) -> list[str] | None:
1038
+ """Get list of layer types (auto-extracted)."""
1039
+ return self.metadata.properties.get("layer_types")
1040
+
1041
+ @property
1042
+ def optimizer(self) -> str | None:
1043
+ """Get optimizer name (auto-extracted from Keras)."""
1044
+ return self.metadata.properties.get("optimizer")
1045
+
1046
+ @property
1047
+ def learning_rate(self) -> float | None:
1048
+ """Get learning rate (auto-extracted from Keras)."""
1049
+ return self.metadata.properties.get("learning_rate")
1050
+
1051
+ @property
1052
+ def loss_function(self) -> str | None:
1053
+ """Get loss function (auto-extracted from Keras)."""
1054
+ return self.metadata.properties.get("loss_function")
1055
+
1056
+ @property
1057
+ def metrics(self) -> list[str] | None:
1058
+ """Get metrics (auto-extracted from Keras)."""
1059
+ return self.metadata.properties.get("metrics")
1060
+
1061
+ @property
1062
+ def hyperparameters(self) -> dict | None:
1063
+ """Get hyperparameters (auto-extracted from sklearn/xgboost)."""
1064
+ return self.metadata.properties.get("hyperparameters")
63
1065
 
64
1066
  def get_training_datasets(self):
65
1067
  """Get all datasets this model was trained on."""
@@ -69,7 +1071,7 @@ class Model(Asset):
69
1071
 
70
1072
  def get_parameters_count(self) -> int | None:
71
1073
  """Get number of model parameters if available."""
72
- return self.metadata.properties.get("params") or self.metadata.properties.get("parameters")
1074
+ return self.parameters
73
1075
 
74
1076
  def get_architecture_info(self) -> dict[str, Any]:
75
1077
  """Get architecture information."""
@@ -78,5 +1080,40 @@ class Model(Asset):
78
1080
  "framework": self.framework,
79
1081
  "input_shape": self.input_shape,
80
1082
  "output_shape": self.output_shape,
81
- "parameters": self.get_parameters_count(),
1083
+ "parameters": self.parameters,
1084
+ "trainable_parameters": self.trainable_parameters,
1085
+ "num_layers": self.num_layers,
1086
+ "layer_types": self.layer_types,
1087
+ }
1088
+
1089
+ def get_training_info(self) -> dict[str, Any]:
1090
+ """Get training information."""
1091
+ result = {
1092
+ "optimizer": self.optimizer,
1093
+ "learning_rate": self.learning_rate,
1094
+ "loss_function": self.loss_function,
1095
+ "metrics": self.metrics,
82
1096
  }
1097
+
1098
+ if self.training_history:
1099
+ epochs = self.training_history.get("epochs", [])
1100
+ result["epochs_trained"] = len(epochs)
1101
+
1102
+ # Get final metrics
1103
+ for key, values in self.training_history.items():
1104
+ if key != "epochs" and values:
1105
+ result[f"final_{key}"] = values[-1]
1106
+
1107
+ return {k: v for k, v in result.items() if v is not None}
1108
+
1109
+ def __repr__(self) -> str:
1110
+ """String representation with key info."""
1111
+ parts = [f"Model(name='{self.name}'"]
1112
+ if self.framework:
1113
+ parts.append(f"framework='{self.framework}'")
1114
+ if self.parameters:
1115
+ parts.append(f"params={self.parameters:,}")
1116
+ if self.training_history:
1117
+ epochs = len(self.training_history.get("epochs", []))
1118
+ parts.append(f"epochs={epochs}")
1119
+ return ", ".join(parts) + ")"