flowyml 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flowyml/assets/dataset.py +570 -17
- flowyml/assets/model.py +1052 -15
- flowyml/core/executor.py +70 -11
- flowyml/core/orchestrator.py +37 -2
- flowyml/core/pipeline.py +32 -4
- flowyml/core/scheduler.py +88 -5
- flowyml/integrations/keras.py +247 -82
- flowyml/storage/sql.py +24 -6
- flowyml/ui/backend/routers/runs.py +112 -0
- flowyml/ui/backend/routers/schedules.py +35 -15
- flowyml/ui/frontend/dist/assets/index-B40RsQDq.css +1 -0
- flowyml/ui/frontend/dist/assets/index-CjI0zKCn.js +685 -0
- flowyml/ui/frontend/dist/index.html +2 -2
- flowyml/ui/frontend/package-lock.json +11 -0
- flowyml/ui/frontend/package.json +1 -0
- flowyml/ui/frontend/src/app/assets/page.jsx +890 -321
- flowyml/ui/frontend/src/app/dashboard/page.jsx +1 -1
- flowyml/ui/frontend/src/app/experiments/[experimentId]/page.jsx +1 -1
- flowyml/ui/frontend/src/app/leaderboard/page.jsx +1 -1
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectMetricsPanel.jsx +1 -1
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectRunsList.jsx +3 -3
- flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +590 -102
- flowyml/ui/frontend/src/components/ArtifactViewer.jsx +62 -2
- flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +401 -28
- flowyml/ui/frontend/src/components/AssetTreeHierarchy.jsx +119 -11
- flowyml/ui/frontend/src/components/DatasetViewer.jsx +753 -0
- flowyml/ui/frontend/src/components/TrainingHistoryChart.jsx +514 -0
- flowyml/ui/frontend/src/components/TrainingMetricsPanel.jsx +175 -0
- {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/METADATA +1 -1
- {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/RECORD +33 -30
- flowyml/ui/frontend/dist/assets/index-By4trVyv.css +0 -1
- flowyml/ui/frontend/dist/assets/index-CX5RV2C9.js +0 -630
- {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/WHEEL +0 -0
- {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/entry_points.txt +0 -0
- {flowyml-1.7.0.dist-info → flowyml-1.7.2.dist-info}/licenses/LICENSE +0 -0
flowyml/assets/model.py
CHANGED
@@ -1,20 +1,755 @@
-"""Model Asset - Represents ML models with metadata
+"""Model Asset - Represents ML models with automatic metadata extraction."""

 from typing import Any
+import contextlib
+import logging
+
 from flowyml.assets.base import Asset

+logger = logging.getLogger(__name__)
+
+
+class ModelInspector:
+    """Utility class for extracting model metadata from various frameworks.
+
+    This class provides robust, fail-safe extraction of model metadata.
+    It will never raise exceptions - if extraction fails, it returns
+    partial results with whatever could be extracted.
+
+    Supported frameworks with rich extraction:
+    - Keras/TensorFlow: Full architecture, optimizer, loss, metrics
+    - PyTorch: Architecture, parameters, layer info
+    - Scikit-learn: Hyperparameters, feature importance
+    - XGBoost/LightGBM/CatBoost: Trees, hyperparameters
+
+    All other model types are supported with basic metadata (type name, etc.)
+    """
+
+    @staticmethod
+    def detect_framework(model: Any) -> str | None:
+        """Detect the ML framework of a model.
+
+        Returns one of: 'keras', 'tensorflow', 'pytorch', 'sklearn', 'xgboost',
+        'lightgbm', 'catboost', 'huggingface', 'onnx', 'custom', or None
+
+        This method is safe to call on any object and will never raise.
+        """
+        if model is None:
+            return None
+
+        try:
+            type_name = type(model).__name__
+            module_name = type(model).__module__
+
+            # Keras/TensorFlow - check multiple indicators
+            if any(x in module_name.lower() for x in ["keras", "tf.keras"]):
+                return "keras"
+            if type_name in ("Sequential", "Functional", "Model") and "tensorflow" in module_name.lower():
+                return "keras"
+            if "tensorflow" in module_name.lower():
+                return "tensorflow"
+
+            # PyTorch - check for nn.Module inheritance (handles user-defined models)
+            # Check if any base class is from torch
+            try:
+                for base in type(model).__mro__:
+                    base_module = base.__module__
+                    if "torch" in base_module.lower():
+                        return "pytorch"
+            except Exception:
+                pass
+
+            # Also check module name directly (for torch tensors, etc.)
+            if "torch" in module_name.lower():
+                return "pytorch"
+
+            # Check for PyTorch-specific attributes
+            if hasattr(model, "forward") and hasattr(model, "parameters") and hasattr(model, "state_dict"):
+                # Likely a PyTorch model
+                return "pytorch"
+
+            # Scikit-learn - check for common base classes
+            if "sklearn" in module_name.lower():
+                return "sklearn"
+
+            # XGBoost
+            if "xgboost" in module_name.lower() or type_name.startswith("XGB"):
+                return "xgboost"
+
+            # LightGBM
+            if "lightgbm" in module_name.lower() or type_name.startswith("LGB"):
+                return "lightgbm"
+
+            # CatBoost
+            if "catboost" in module_name.lower():
+                return "catboost"
+
+            # Hugging Face Transformers
+            if "transformers" in module_name.lower():
+                return "huggingface"
+
+            # ONNX
+            if "onnx" in module_name.lower():
+                return "onnx"
+
+            # JAX/Flax
+            if "flax" in module_name.lower() or "jax" in module_name.lower():
+                return "jax"
+
+            # Check for common ML model attributes
+            if hasattr(model, "predict") and hasattr(model, "fit"):
+                return "sklearn"  # Sklearn-like API
+
+            return "custom"
+
+        except Exception as e:
+            logger.debug(f"Error detecting framework: {e}")
+            return "unknown"
+
+    @staticmethod
+    def extract_keras_info(model: Any) -> dict[str, Any]:
+        """Extract metadata from a Keras/TensorFlow model.
+
+        This method is robust and will extract as much as possible,
+        even if some attributes are not available (e.g., uncompiled model).
+        """
+        result = {"framework": "keras", "_auto_extracted": True}
+
+        # Parameter count - handle both built and unbuilt models
+        try:
+            if hasattr(model, "count_params"):
+                with contextlib.suppress(Exception):
+                    result["parameters"] = model.count_params()
+        except Exception as e:
+            logger.debug(f"Error getting param count: {e}")
+
+        # Trainable parameters
+        try:
+            if hasattr(model, "trainable_weights") and model.trainable_weights:
+                trainable = sum(int(w.numpy().size) if hasattr(w, "numpy") else 0 for w in model.trainable_weights)
+                if trainable > 0:
+                    result["trainable_parameters"] = trainable
+        except Exception as e:
+            logger.debug(f"Error getting trainable params: {e}")
+
+        # Architecture name
+        try:
+            if hasattr(model, "name") and model.name:
+                result["architecture"] = model.name
+            result["model_class"] = type(model).__name__
+        except Exception as e:
+            logger.debug(f"Error getting architecture: {e}")
+
+        # Layer info - handle models without layers attribute
+        try:
+            if hasattr(model, "layers"):
+                layers = model.layers
+                if layers:
+                    result["num_layers"] = len(layers)
+                    result["layer_types"] = list({type(layer).__name__ for layer in layers})
+        except Exception as e:
+            logger.debug(f"Error getting layer info: {e}")
+
+        # Input/Output shapes - handle unbuilt models
+        with contextlib.suppress(Exception):
+            if hasattr(model, "input_shape"):
+                input_shape = model.input_shape
+                if input_shape is not None:
+                    result["input_shape"] = str(input_shape)
+
+        with contextlib.suppress(Exception):
+            if hasattr(model, "output_shape"):
+                output_shape = model.output_shape
+                if output_shape is not None:
+                    result["output_shape"] = str(output_shape)
+
+        # Optimizer info (only for compiled models)
+        try:
+            if hasattr(model, "optimizer") and model.optimizer is not None:
+                opt = model.optimizer
+                result["optimizer"] = type(opt).__name__
+
+                # Get learning rate - handle different Keras versions
+                if hasattr(opt, "learning_rate"):
+                    lr = opt.learning_rate
+                    if hasattr(lr, "numpy"):
+                        lr = float(lr.numpy())
+                    elif callable(lr):
+                        # Learning rate schedule
+                        result["lr_schedule"] = type(lr).__name__
+                    else:
+                        lr = float(lr)
+                    if isinstance(lr, (int, float)):
+                        result["learning_rate"] = lr
+                elif hasattr(opt, "lr"):
+                    # Older Keras versions
+                    result["learning_rate"] = float(opt.lr)
+        except Exception as e:
+            logger.debug(f"Error getting optimizer info: {e}")
+
+        # Loss function (only for compiled models)
+        try:
+            if hasattr(model, "loss") and model.loss is not None:
+                loss = model.loss
+                if isinstance(loss, str):
+                    result["loss_function"] = loss
+                elif hasattr(loss, "__name__"):
+                    result["loss_function"] = loss.__name__
+                elif hasattr(loss, "name"):
+                    result["loss_function"] = loss.name
+                elif hasattr(loss, "__class__"):
+                    result["loss_function"] = type(loss).__name__
+        except Exception as e:
+            logger.debug(f"Error getting loss function: {e}")
+
+        # Metrics (only for compiled models)
+        try:
+            if hasattr(model, "metrics_names") and model.metrics_names:
+                result["metrics"] = list(model.metrics_names)
+            elif hasattr(model, "metrics") and model.metrics:
+                result["metrics"] = [m.name if hasattr(m, "name") else str(m) for m in model.metrics]
+        except Exception as e:
+            logger.debug(f"Error getting metrics: {e}")
+
+        # Check if model is compiled
+        with contextlib.suppress(Exception):
+            result["is_compiled"] = hasattr(model, "optimizer") and model.optimizer is not None
+
+        # Check if model is built
+        with contextlib.suppress(Exception):
+            if hasattr(model, "built"):
+                result["is_built"] = model.built
+
+        return result
+
+    @staticmethod
+    def extract_pytorch_info(model: Any) -> dict[str, Any]:
+        """Extract metadata from a PyTorch model.
+
+        This method is robust and handles:
+        - nn.Module models
+        - Custom modules
+        - Pretrained models from torchvision, transformers, etc.
+        - Models in eval or train mode
+        """
+        result = {"framework": "pytorch", "_auto_extracted": True}
+
+        # Model class name
+        try:
+            result["model_class"] = type(model).__name__
+            result["architecture"] = type(model).__name__
+        except Exception:
+            pass
+
+        # Parameter count
+        try:
+            if hasattr(model, "parameters"):
+                params = list(model.parameters())
+                if params:
+                    total_params = sum(p.numel() for p in params)
+                    trainable_params = sum(p.numel() for p in params if p.requires_grad)
+                    result["parameters"] = total_params
+                    result["trainable_parameters"] = trainable_params
+                    result["frozen_parameters"] = total_params - trainable_params
+        except Exception as e:
+            logger.debug(f"Error getting PyTorch param count: {e}")
+
+        # Layer info from modules
+        try:
+            if hasattr(model, "modules"):
+                modules = list(model.modules())
+                if modules:
+                    # Skip the first module (the model itself)
+                    layer_modules = modules[1:] if len(modules) > 1 else modules
+                    result["num_layers"] = len(layer_modules)
+
+                    # Get unique layer types
+                    layer_types = set()
+                    for m in layer_modules:
+                        layer_type = type(m).__name__
+                        # Skip container modules
+                        if layer_type not in ("Sequential", "ModuleList", "ModuleDict"):
+                            layer_types.add(layer_type)
+                    result["layer_types"] = list(layer_types)
+        except Exception as e:
+            logger.debug(f"Error getting PyTorch layer info: {e}")
+
+        # Named modules for architecture insights
+        try:
+            if hasattr(model, "named_modules"):
+                named = dict(model.named_modules())
+                if named:
+                    result["num_named_modules"] = len(named)
+        except Exception:
+            pass
+
+        # State dict info
+        try:
+            if hasattr(model, "state_dict"):
+                state_dict = model.state_dict()
+                if state_dict:
+                    result["num_tensors"] = len(state_dict)
+                    # Get tensor shapes for key layers
+                    tensor_shapes = {}
+                    for name, tensor in list(state_dict.items())[:10]:  # First 10 only
+                        tensor_shapes[name] = list(tensor.shape)
+                    result["tensor_shapes_sample"] = tensor_shapes
+        except Exception as e:
+            logger.debug(f"Error getting PyTorch state dict: {e}")
+
+        # Training mode
+        try:
+            if hasattr(model, "training"):
+                result["training_mode"] = model.training
+        except Exception:
+            pass
+
+        # Device info
+        try:
+            if hasattr(model, "parameters"):
+                first_param = next(model.parameters(), None)
+                if first_param is not None:
+                    result["device"] = str(first_param.device)
+                    result["dtype"] = str(first_param.dtype)
+        except Exception:
+            pass
+
+        # Check for common PyTorch model attributes
+        try:
+            # Input features (common in many models)
+            for attr in ["in_features", "in_channels", "input_size", "num_features"]:
+                if hasattr(model, attr):
+                    val = getattr(model, attr, None)
+                    if val is not None:
+                        result[attr] = val
+                        break
+
+            # Output features
+            for attr in ["out_features", "out_channels", "output_size", "num_classes"]:
+                if hasattr(model, attr):
+                    val = getattr(model, attr, None)
+                    if val is not None:
+                        result[attr] = val
+                        break
+        except Exception:
+            pass
+
+        return result
+
+    @staticmethod
+    def extract_sklearn_info(model: Any) -> dict[str, Any]:
+        """Extract metadata from a scikit-learn model.
+
+        Handles all sklearn estimators including:
+        - Classifiers, regressors, transformers
+        - Ensemble methods (RandomForest, GradientBoosting, etc.)
+        - Linear models
+        - Pipelines
+        """
+        result = {"framework": "sklearn", "_auto_extracted": True}
+
+        # Model type
+        try:
+            result["model_class"] = type(model).__name__
+            result["architecture"] = type(model).__name__
+        except Exception:
+            pass
+
+        # Get parameters (safe extraction)
+        try:
+            if hasattr(model, "get_params"):
+                params = model.get_params(deep=False)  # Shallow to avoid recursion
+                # Filter to serializable values
+                filtered_params = {}
+                for k, v in params.items():
+                    if v is None:
+                        continue
+                    if isinstance(v, (str, int, float, bool)):
+                        filtered_params[k] = v
+                    elif isinstance(v, (list, tuple)) and len(v) < 10:
+                        # Small lists/tuples of primitives
+                        if all(isinstance(x, (str, int, float, bool, type(None))) for x in v):
+                            filtered_params[k] = list(v)
+                if filtered_params:
+                    result["hyperparameters"] = filtered_params
+        except Exception as e:
+            logger.debug(f"Error getting sklearn params: {e}")
+
+        # Feature importance (tree-based models)
+        with contextlib.suppress(Exception):
+            if hasattr(model, "feature_importances_"):
+                importances = model.feature_importances_
+                result["has_feature_importances"] = True
+                result["num_features"] = len(importances)
+                # Store top 10 feature importances
+                if len(importances) <= 20:
+                    result["feature_importances"] = list(importances)
+
+        # Coefficients (linear models)
+        with contextlib.suppress(Exception):
+            if hasattr(model, "coef_"):
+                coef = model.coef_
+                if hasattr(coef, "shape"):
+                    result["coef_shape"] = str(coef.shape)
+                if hasattr(coef, "size") and coef.size <= 100:
+                    result["num_coefficients"] = int(coef.size)
+
+        # Intercept
+        with contextlib.suppress(Exception):
+            if hasattr(model, "intercept_"):
+                intercept = model.intercept_
+                if hasattr(intercept, "tolist"):
+                    intercept = intercept.tolist()
+                if isinstance(intercept, (int, float)) or (isinstance(intercept, list) and len(intercept) <= 10):
+                    result["intercept"] = intercept
+
+        # Classes (classifiers)
+        try:
+            if hasattr(model, "classes_"):
+                classes = model.classes_
+                result["num_classes"] = len(classes)
+                if len(classes) <= 20:
+                    # Convert to list if numpy array
+                    if hasattr(classes, "tolist"):
+                        classes = classes.tolist()
+                    result["classes"] = list(classes)
+        except Exception:
+            pass
+
+        # Number of estimators (ensemble models)
+        try:
+            if hasattr(model, "n_estimators"):
+                result["n_estimators"] = model.n_estimators
+            if hasattr(model, "estimators_"):
+                result["num_estimators_fitted"] = len(model.estimators_)
+        except Exception:
+            pass
+
+        # Tree-specific attributes
+        try:
+            if hasattr(model, "max_depth"):
+                result["max_depth"] = model.max_depth
+            if hasattr(model, "n_features_in_"):
+                result["n_features_in"] = model.n_features_in_
+        except Exception:
+            pass
+
+        # Check if fitted
+        try:
+            # Common sklearn pattern for checking if fitted
+            from sklearn.utils.validation import check_is_fitted
+
+            check_is_fitted(model)
+            result["is_fitted"] = True
+        except Exception:
+            result["is_fitted"] = False
+
+        return result
+
+    @staticmethod
+    def extract_xgboost_info(model: Any) -> dict[str, Any]:
+        """Extract metadata from an XGBoost model."""
+        result = {"framework": "xgboost", "_auto_extracted": True}
+
+        try:
+            result["model_class"] = type(model).__name__
+            result["architecture"] = "XGBoost"
+        except Exception:
+            pass
+
+        # Get hyperparameters
+        try:
+            if hasattr(model, "get_params"):
+                params = model.get_params()
+                result["hyperparameters"] = {
+                    k: v for k, v in params.items() if v is not None and isinstance(v, (str, int, float, bool))
+                }
+        except Exception as e:
+            logger.debug(f"Error getting XGBoost params: {e}")
+
+        # Booster info
+        try:
+            if hasattr(model, "get_booster"):
+                booster = model.get_booster()
+                if hasattr(booster, "num_trees"):
+                    result["num_trees"] = booster.num_trees()
+                if hasattr(booster, "num_features"):
+                    result["num_features"] = booster.num_features()
+            elif hasattr(model, "n_estimators"):
+                result["n_estimators"] = model.n_estimators
+        except Exception:
+            pass
+
+        # Feature importance
+        try:
+            if hasattr(model, "feature_importances_"):
+                result["num_features"] = len(model.feature_importances_)
+                result["has_feature_importances"] = True
+        except Exception:
+            pass
+
+        # Best iteration (for early stopping)
+        try:
+            if hasattr(model, "best_iteration"):
+                result["best_iteration"] = model.best_iteration
+            if hasattr(model, "best_score"):
+                result["best_score"] = model.best_score
+        except Exception:
+            pass
+
+        return result
+
+    @staticmethod
+    def extract_lightgbm_info(model: Any) -> dict[str, Any]:
+        """Extract metadata from a LightGBM model."""
+        result = {"framework": "lightgbm", "_auto_extracted": True}
+
+        try:
+            result["model_class"] = type(model).__name__
+            result["architecture"] = "LightGBM"
+        except Exception:
+            pass
+
+        # Get hyperparameters
+        try:
+            if hasattr(model, "get_params"):
+                params = model.get_params()
+                result["hyperparameters"] = {
+                    k: v for k, v in params.items() if v is not None and isinstance(v, (str, int, float, bool))
+                }
+        except Exception:
+            pass
+
+        # Booster info
+        try:
+            if hasattr(model, "booster_"):
+                booster = model.booster_
+                if hasattr(booster, "num_trees"):
+                    result["num_trees"] = booster.num_trees()
+            elif hasattr(model, "n_estimators"):
+                result["n_estimators"] = model.n_estimators
+        except Exception:
+            pass
+
+        # Feature importance
+        try:
+            if hasattr(model, "feature_importances_"):
+                result["num_features"] = len(model.feature_importances_)
+                result["has_feature_importances"] = True
+        except Exception:
+            pass
+
+        # Best iteration
+        try:
+            if hasattr(model, "best_iteration_"):
+                result["best_iteration"] = model.best_iteration_
+            if hasattr(model, "best_score_"):
+                result["best_score"] = model.best_score_
+        except Exception:
+            pass
+
+        return result
+
+    @staticmethod
+    def extract_huggingface_info(model: Any) -> dict[str, Any]:
+        """Extract metadata from a Hugging Face Transformers model."""
+        result = {"framework": "huggingface", "_auto_extracted": True}
+
+        try:
+            result["model_class"] = type(model).__name__
+            result["architecture"] = type(model).__name__
+        except Exception:
+            pass
+
+        # Config info
+        try:
+            if hasattr(model, "config"):
+                config = model.config
+                result["model_type"] = getattr(config, "model_type", None)
+                result["hidden_size"] = getattr(config, "hidden_size", None)
+                result["num_attention_heads"] = getattr(config, "num_attention_heads", None)
+                result["num_hidden_layers"] = getattr(config, "num_hidden_layers", None)
+                result["vocab_size"] = getattr(config, "vocab_size", None)
+                # Clean up None values
+                result = {k: v for k, v in result.items() if v is not None}
+        except Exception:
+            pass
+
+        # Parameter count
+        try:
+            if hasattr(model, "num_parameters"):
+                result["parameters"] = model.num_parameters()
+            elif hasattr(model, "parameters"):
+                result["parameters"] = sum(p.numel() for p in model.parameters())
+        except Exception:
+            pass
+
+        # Device
+        try:
+            if hasattr(model, "device"):
+                result["device"] = str(model.device)
+        except Exception:
+            pass
+
+        return result
+
+    @staticmethod
+    def extract_generic_info(model: Any) -> dict[str, Any]:
+        """Extract basic metadata from any model type.
+
+        This is the fallback for unknown/custom models.
+        """
+        result = {"framework": "custom", "_auto_extracted": True}
+
+        try:
+            result["model_class"] = type(model).__name__
+            result["module"] = type(model).__module__
+        except Exception:
+            pass
+
+        # Check for common model attributes
+        try:
+            # Fit/predict API (sklearn-like)
+            result["has_fit"] = hasattr(model, "fit")
+            result["has_predict"] = hasattr(model, "predict")
+            result["has_transform"] = hasattr(model, "transform")
+
+            # Parameters
+            if hasattr(model, "get_params"):
+                result["has_get_params"] = True
+
+            # State dict (PyTorch-like)
+            if hasattr(model, "state_dict"):
+                result["has_state_dict"] = True
+        except Exception:
+            pass
+
+        return result
+
+    @staticmethod
+    def extract_info(model: Any) -> dict[str, Any]:
+        """Auto-detect framework and extract model metadata.
+
+        This method is the main entry point for model metadata extraction.
+        It is designed to NEVER fail - if extraction fails for any reason,
+        it returns a minimal result with whatever could be extracted.
+
+        Args:
+            model: Any model object from any framework
+
+        Returns:
+            Dict with extracted metadata. Always includes:
+            - framework: Detected framework name
+            - _auto_extracted: Whether extraction succeeded
+            - model_class: Class name of the model (if available)
+        """
+        # Handle None model
+        if model is None:
+            return {"framework": None, "_auto_extracted": False, "error": "Model is None"}
+
+        # Detect framework
+        framework = ModelInspector.detect_framework(model)
+
+        # Extract based on framework
+        try:
+            if framework == "keras":
+                return ModelInspector.extract_keras_info(model)
+            elif framework == "tensorflow":
+                # TensorFlow models that aren't Keras
+                result = ModelInspector.extract_keras_info(model)
+                result["framework"] = "tensorflow"
+                return result
+            elif framework == "pytorch":
+                return ModelInspector.extract_pytorch_info(model)
+            elif framework == "sklearn":
+                return ModelInspector.extract_sklearn_info(model)
+            elif framework == "xgboost":
+                return ModelInspector.extract_xgboost_info(model)
+            elif framework == "lightgbm":
+                return ModelInspector.extract_lightgbm_info(model)
+            elif framework == "catboost":
+                # CatBoost is similar to XGBoost
+                result = ModelInspector.extract_xgboost_info(model)
+                result["framework"] = "catboost"
+                return result
+            elif framework == "huggingface":
+                return ModelInspector.extract_huggingface_info(model)
+            elif framework in ("jax", "onnx"):
+                # Basic extraction for JAX/ONNX
+                result = ModelInspector.extract_generic_info(model)
+                result["framework"] = framework
+                return result
+            else:
+                # Unknown or custom framework - use generic extraction
+                return ModelInspector.extract_generic_info(model)
+
+        except Exception as e:
+            # If all else fails, return minimal info
+            logger.debug(f"Error extracting model info: {e}")
+            return {
+                "framework": framework or "unknown",
+                "model_class": type(model).__name__,
+                "_auto_extracted": False,
+                "_extraction_error": str(e),
+            }
+
+    @staticmethod
+    def extract_training_history_from_callback(callback: Any) -> dict[str, list] | None:
+        """Extract training history from a FlowyML callback or Keras History object."""
+        try:
+            # FlowyML callback
+            if hasattr(callback, "get_training_history"):
+                return callback.get_training_history()
+
+            # Keras History object
+            if hasattr(callback, "history"):
+                history = callback.history
+                if isinstance(history, dict):
+                    # Add epochs if not present
+                    if "epochs" not in history and history:
+                        first_key = next(iter(history.keys()))
+                        history["epochs"] = list(range(1, len(history[first_key]) + 1))
+                    return history
+
+            # Training history dict directly
+            if isinstance(callback, dict) and any(isinstance(v, list) for v in callback.values()):
+                return callback
+
+        except Exception as e:
+            logger.debug(f"Error extracting training history: {e}")
+
+        return None
+

 class Model(Asset):
-    """Model asset with
+    """Model asset with automatic metadata extraction and training history.
+
+    The Model class automatically extracts metadata from various ML frameworks,
+    reducing boilerplate code and improving UX. It also captures training history
+    for visualization in the FlowyML dashboard.
+
+    Supported frameworks:
+    - Keras/TensorFlow: Auto-extracts layers, parameters, optimizer, loss
+    - PyTorch: Auto-extracts modules, parameters, training mode
+    - Scikit-learn: Auto-extracts hyperparameters, feature importance
+    - XGBoost/LightGBM: Auto-extracts trees, hyperparameters

     Example:
-        >>>
-
-        ...
-        ...
-        ...
-
-
+        >>> # Minimal usage - properties auto-extracted!
+        >>> model_asset = Model.create(
+        ...     data=trained_keras_model,
+        ...     name="my_model",
+        ... )
+        >>> print(model_asset.parameters)  # Auto-extracted
+        >>> print(model_asset.framework)  # Auto-detected
+
+        >>> # With FlowyML callback - training history auto-captured
+        >>> callback = FlowymlKerasCallback(experiment_name="demo")
+        >>> model.fit(X, y, callbacks=[callback])
+        >>> model_asset = Model.create(
+        ...     data=model,
+        ...     name="trained_model",
+        ...     flowyml_callback=callback,  # Auto-extracts training history!
         ... )
     """

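Note: the `ModelInspector` added in this hunk is usable on its own. Below is a minimal illustrative sketch, not taken from the package docs; the scikit-learn model and the printed keys are assumptions based on the extraction code above (the import path follows the diffed file `flowyml/assets/model.py`).

# Illustrative sketch: exercising the new ModelInspector directly.
# Assumes scikit-learn is installed; keys follow extract_sklearn_info above.
from sklearn.ensemble import RandomForestClassifier
from flowyml.assets.model import ModelInspector

clf = RandomForestClassifier(n_estimators=50, max_depth=4).fit([[0], [1]], [0, 1])

print(ModelInspector.detect_framework(clf))  # "sklearn" (module name match)
info = ModelInspector.extract_info(clf)      # dispatches to extract_sklearn_info
print(info["n_estimators"], info["is_fitted"])  # 50 True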
@@ -31,35 +766,302 @@ class Model(Asset):
         parent: Asset | None = None,
         tags: dict[str, str] | None = None,
         properties: dict[str, Any] | None = None,
+        training_history: dict[str, list] | None = None,
+        auto_extract: bool = True,
     ):
+        """Initialize Model with automatic metadata extraction.
+
+        Args:
+            name: Model name
+            version: Version string
+            data: The model object (Keras, PyTorch, sklearn, etc.)
+            architecture: Architecture name (auto-detected if not provided)
+            framework: Framework name (auto-detected if not provided)
+            input_shape: Input shape (auto-detected for Keras)
+            output_shape: Output shape (auto-detected for Keras)
+            trained_on: Dataset this model was trained on
+            parent: Parent asset for lineage
+            tags: Metadata tags
+            properties: Additional properties (merged with auto-extracted)
+            training_history: Training metrics per epoch
+            auto_extract: Whether to auto-extract model metadata
+        """
+        # Initialize properties
+        final_properties = properties.copy() if properties else {}
+
+        # Auto-extract model metadata if enabled
+        if auto_extract and data is not None:
+            extracted = ModelInspector.extract_info(data)
+            # Merge - user-provided values take precedence
+            for key, value in extracted.items():
+                if key not in final_properties:
+                    final_properties[key] = value
+
+            # Set framework from extracted if not provided
+            if framework is None and "framework" in extracted:
+                framework = extracted["framework"]
+
+            # Set architecture from extracted if not provided
+            if architecture is None and "architecture" in extracted:
+                architecture = extracted["architecture"]
+
         super().__init__(
             name=name,
             version=version,
             data=data,
             parent=parent,
             tags=tags,
-            properties=
+            properties=final_properties,
         )

         self.architecture = architecture
         self.framework = framework
         self.input_shape = input_shape
         self.output_shape = output_shape
+        self.training_history = training_history

         # Track training dataset
         if trained_on:
             self.parents.append(trained_on)
             trained_on.children.append(self)

-        # Add model-specific properties
+        # Add model-specific properties (explicit ones override extracted)
         if architecture:
             self.metadata.properties["architecture"] = architecture
         if framework:
             self.metadata.properties["framework"] = framework
         if input_shape:
-            self.metadata.properties["input_shape"] = input_shape
+            self.metadata.properties["input_shape"] = str(input_shape)
         if output_shape:
-            self.metadata.properties["output_shape"] = output_shape
+            self.metadata.properties["output_shape"] = str(output_shape)
+
+    @classmethod
+    def create(
+        cls,
+        data: Any,
+        name: str | None = None,
+        version: str | None = None,
+        parent: "Asset | None" = None,
+        flowyml_callback: Any = None,
+        keras_history: Any = None,
+        auto_extract: bool = True,
+        **kwargs: Any,
+    ) -> "Model":
+        """Create a Model asset with automatic metadata extraction.
+
+        This is the preferred way to create Model objects. Metadata is
+        automatically extracted from the model, and training history can
+        be captured from FlowyML callbacks.
+
+        Args:
+            data: The model object (Keras, PyTorch, sklearn, etc.)
+            name: Asset name (auto-generated if not provided)
+            version: Asset version
+            parent: Parent asset for lineage
+            flowyml_callback: FlowymlKerasCallback for auto-capturing training history
+            keras_history: Keras History object from model.fit()
+            auto_extract: Whether to auto-extract model metadata
+            **kwargs: Additional parameters including:
+                - training_history: Dict of training metrics per epoch
+                - architecture: Model architecture name
+                - framework: ML framework (keras, pytorch, etc.)
+                - properties: Additional properties
+                - tags: Metadata tags
+
+        Returns:
+            New Model instance with auto-extracted metadata
+
+        Example:
+            >>> # Simple usage - everything auto-extracted
+            >>> model_asset = Model.create(data=model, name="my_model")
+
+            >>> # With FlowyML callback
+            >>> callback = FlowymlKerasCallback(experiment_name="demo")
+            >>> model.fit(X, y, callbacks=[callback])
+            >>> model_asset = Model.create(
+            ...     data=model,
+            ...     name="trained_model",
+            ...     flowyml_callback=callback,
+            ... )
+
+            >>> # With Keras History
+            >>> history = model.fit(X, y)
+            >>> model_asset = Model.create(
+            ...     data=model,
+            ...     name="trained_model",
+            ...     keras_history=history,
+            ... )
+        """
+        from datetime import datetime
+
+        asset_name = name or f"Model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+        # Extract Model-specific parameters
+        training_history = kwargs.pop("training_history", None)
+        architecture = kwargs.pop("architecture", None)
+        framework = kwargs.pop("framework", None)
+        input_shape = kwargs.pop("input_shape", None)
+        output_shape = kwargs.pop("output_shape", None)
+        trained_on = kwargs.pop("trained_on", None)
+
+        # Auto-extract training history from callback or history object
+        if training_history is None:
+            if flowyml_callback is not None:
+                training_history = ModelInspector.extract_training_history_from_callback(
+                    flowyml_callback,
+                )
+            elif keras_history is not None:
+                training_history = ModelInspector.extract_training_history_from_callback(
+                    keras_history,
+                )
+
+        # Extract tags and properties
+        tags = kwargs.pop("tags", {})
+        props = kwargs.pop("properties", {})
+        # Merge remaining kwargs into properties
+        props.update(kwargs)
+
+        return cls(
+            name=asset_name,
+            version=version,
+            data=data,
+            architecture=architecture,
+            framework=framework,
+            input_shape=input_shape,
+            output_shape=output_shape,
+            trained_on=trained_on,
+            parent=parent,
+            tags=tags,
+            properties=props,
+            training_history=training_history,
+            auto_extract=auto_extract,
+        )
+
+    @classmethod
+    def from_keras(
+        cls,
+        model: Any,
+        name: str | None = None,
+        callback: Any = None,
+        history: Any = None,
+        **kwargs: Any,
+    ) -> "Model":
+        """Create a Model asset from a Keras model with full auto-extraction.
+
+        Args:
+            model: Keras model object
+            name: Asset name
+            callback: FlowymlKerasCallback for training history
+            history: Keras History object from model.fit()
+            **kwargs: Additional properties
+
+        Returns:
+            Model asset with auto-extracted Keras metadata
+        """
+        return cls.create(
+            data=model,
+            name=name,
+            framework="keras",
+            flowyml_callback=callback,
+            keras_history=history,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_pytorch(
+        cls,
+        model: Any,
+        name: str | None = None,
+        training_history: dict | None = None,
+        **kwargs: Any,
+    ) -> "Model":
+        """Create a Model asset from a PyTorch model with full auto-extraction.
+
+        Args:
+            model: PyTorch model object (nn.Module)
+            name: Asset name
+            training_history: Training metrics dict
+            **kwargs: Additional properties
+
+        Returns:
+            Model asset with auto-extracted PyTorch metadata
+        """
+        return cls.create(
+            data=model,
+            name=name,
+            framework="pytorch",
+            training_history=training_history,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_sklearn(
+        cls,
+        model: Any,
+        name: str | None = None,
+        **kwargs: Any,
+    ) -> "Model":
+        """Create a Model asset from a scikit-learn model with full auto-extraction.
+
+        Args:
+            model: Scikit-learn model object
+            name: Asset name
+            **kwargs: Additional properties
+
+        Returns:
+            Model asset with auto-extracted sklearn metadata
+        """
+        return cls.create(
+            data=model,
+            name=name,
+            framework="sklearn",
+            **kwargs,
+        )
+
+    @property
+    def parameters(self) -> int | None:
+        """Get number of model parameters (auto-extracted)."""
+        return self.metadata.properties.get("parameters") or self.metadata.properties.get("params")
+
+    @property
+    def trainable_parameters(self) -> int | None:
+        """Get number of trainable parameters (auto-extracted)."""
+        return self.metadata.properties.get("trainable_parameters")
+
+    @property
+    def num_layers(self) -> int | None:
+        """Get number of layers (auto-extracted)."""
+        return self.metadata.properties.get("num_layers")
+
+    @property
+    def layer_types(self) -> list[str] | None:
+        """Get list of layer types (auto-extracted)."""
+        return self.metadata.properties.get("layer_types")
+
+    @property
+    def optimizer(self) -> str | None:
+        """Get optimizer name (auto-extracted from Keras)."""
+        return self.metadata.properties.get("optimizer")
+
+    @property
+    def learning_rate(self) -> float | None:
+        """Get learning rate (auto-extracted from Keras)."""
+        return self.metadata.properties.get("learning_rate")
+
+    @property
+    def loss_function(self) -> str | None:
+        """Get loss function (auto-extracted from Keras)."""
+        return self.metadata.properties.get("loss_function")
+
+    @property
+    def metrics(self) -> list[str] | None:
+        """Get metrics (auto-extracted from Keras)."""
+        return self.metadata.properties.get("metrics")
+
+    @property
+    def hyperparameters(self) -> dict | None:
+        """Get hyperparameters (auto-extracted from sklearn/xgboost)."""
+        return self.metadata.properties.get("hyperparameters")

     def get_training_datasets(self):
         """Get all datasets this model was trained on."""
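Note: the hunk above introduces the `Model.create` factory and the property accessors backed by `metadata.properties`. A minimal usage sketch follows; the `LinearRegression` model is illustrative, and the printed values assume the merge behavior shown in `__init__` above.

# Illustrative sketch of the new factory API.
from sklearn.linear_model import LinearRegression
from flowyml.assets.model import Model

reg = LinearRegression().fit([[0.0], [1.0], [2.0]], [0.0, 1.0, 2.0])

asset = Model.create(data=reg, name="toy_regressor")
print(asset.framework)        # "sklearn", auto-detected
print(asset.hyperparameters)  # auto-extracted via get_params()
print(asset.parameters)       # None for sklearn; populated for Keras/PyTorch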
@@ -69,7 +1071,7 @@ class Model(Asset):

     def get_parameters_count(self) -> int | None:
         """Get number of model parameters if available."""
-        return self.
+        return self.parameters

     def get_architecture_info(self) -> dict[str, Any]:
         """Get architecture information."""
@@ -78,5 +1080,40 @@ class Model(Asset):
             "framework": self.framework,
             "input_shape": self.input_shape,
             "output_shape": self.output_shape,
-            "parameters": self.
+            "parameters": self.parameters,
+            "trainable_parameters": self.trainable_parameters,
+            "num_layers": self.num_layers,
+            "layer_types": self.layer_types,
+        }
+
+    def get_training_info(self) -> dict[str, Any]:
+        """Get training information."""
+        result = {
+            "optimizer": self.optimizer,
+            "learning_rate": self.learning_rate,
+            "loss_function": self.loss_function,
+            "metrics": self.metrics,
         }
+
+        if self.training_history:
+            epochs = self.training_history.get("epochs", [])
+            result["epochs_trained"] = len(epochs)
+
+            # Get final metrics
+            for key, values in self.training_history.items():
+                if key != "epochs" and values:
+                    result[f"final_{key}"] = values[-1]
+
+        return {k: v for k, v in result.items() if v is not None}
+
+    def __repr__(self) -> str:
+        """String representation with key info."""
+        parts = [f"Model(name='{self.name}'"]
+        if self.framework:
+            parts.append(f"framework='{self.framework}'")
+        if self.parameters:
+            parts.append(f"params={self.parameters:,}")
+        if self.training_history:
+            epochs = len(self.training_history.get("epochs", []))
+            parts.append(f"epochs={epochs}")
+        return ", ".join(parts) + ")"
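Note: a sketch of the training-history plumbing added in the hunks above. The metrics dict is invented for illustration; any object works as `data`, since unknown types fall back to the generic extractor and report `framework="custom"`.

# Illustrative sketch of training history capture and the new __repr__.
from flowyml.assets.model import Model

history = {"epochs": [1, 2, 3], "loss": [0.9, 0.5, 0.3], "accuracy": [0.5, 0.7, 0.85]}

asset = Model.create(data=object(), name="trained_model", training_history=history)
print(asset.get_training_info())
# e.g. {'epochs_trained': 3, 'final_loss': 0.3, 'final_accuracy': 0.85}
print(repr(asset))  # Model(name='trained_model', framework='custom', epochs=3)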