flowyml-1.7.2-py3-none-any.whl → flowyml-1.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flowyml/assets/base.py +15 -0
- flowyml/assets/metrics.py +5 -0
- flowyml/cli/main.py +709 -0
- flowyml/cli/stack_cli.py +138 -25
- flowyml/core/__init__.py +17 -0
- flowyml/core/executor.py +161 -26
- flowyml/core/image_builder.py +129 -0
- flowyml/core/log_streamer.py +227 -0
- flowyml/core/orchestrator.py +22 -2
- flowyml/core/pipeline.py +34 -10
- flowyml/core/routing.py +558 -0
- flowyml/core/step.py +9 -1
- flowyml/core/step_grouping.py +49 -35
- flowyml/core/types.py +407 -0
- flowyml/monitoring/alerts.py +10 -0
- flowyml/monitoring/notifications.py +104 -25
- flowyml/monitoring/slack_blocks.py +323 -0
- flowyml/plugins/__init__.py +251 -0
- flowyml/plugins/alerters/__init__.py +1 -0
- flowyml/plugins/alerters/slack.py +168 -0
- flowyml/plugins/base.py +752 -0
- flowyml/plugins/config.py +478 -0
- flowyml/plugins/deployers/__init__.py +22 -0
- flowyml/plugins/deployers/gcp_cloud_run.py +200 -0
- flowyml/plugins/deployers/sagemaker.py +306 -0
- flowyml/plugins/deployers/vertex.py +290 -0
- flowyml/plugins/integration.py +369 -0
- flowyml/plugins/manager.py +510 -0
- flowyml/plugins/model_registries/__init__.py +22 -0
- flowyml/plugins/model_registries/mlflow.py +159 -0
- flowyml/plugins/model_registries/sagemaker.py +489 -0
- flowyml/plugins/model_registries/vertex.py +386 -0
- flowyml/plugins/orchestrators/__init__.py +13 -0
- flowyml/plugins/orchestrators/sagemaker.py +443 -0
- flowyml/plugins/orchestrators/vertex_ai.py +461 -0
- flowyml/plugins/registries/__init__.py +13 -0
- flowyml/plugins/registries/ecr.py +321 -0
- flowyml/plugins/registries/gcr.py +313 -0
- flowyml/plugins/registry.py +454 -0
- flowyml/plugins/stack.py +494 -0
- flowyml/plugins/stack_config.py +537 -0
- flowyml/plugins/stores/__init__.py +13 -0
- flowyml/plugins/stores/gcs.py +460 -0
- flowyml/plugins/stores/s3.py +453 -0
- flowyml/plugins/trackers/__init__.py +11 -0
- flowyml/plugins/trackers/mlflow.py +316 -0
- flowyml/plugins/validators/__init__.py +3 -0
- flowyml/plugins/validators/deepchecks.py +119 -0
- flowyml/registry/__init__.py +2 -1
- flowyml/registry/model_environment.py +109 -0
- flowyml/registry/model_registry.py +241 -96
- flowyml/serving/__init__.py +17 -0
- flowyml/serving/model_server.py +628 -0
- flowyml/stacks/__init__.py +60 -0
- flowyml/stacks/aws.py +93 -0
- flowyml/stacks/base.py +62 -0
- flowyml/stacks/components.py +12 -0
- flowyml/stacks/gcp.py +44 -9
- flowyml/stacks/plugins.py +115 -0
- flowyml/stacks/registry.py +2 -1
- flowyml/storage/sql.py +401 -12
- flowyml/tracking/experiment.py +8 -5
- flowyml/ui/backend/Dockerfile +87 -16
- flowyml/ui/backend/auth.py +12 -2
- flowyml/ui/backend/main.py +149 -5
- flowyml/ui/backend/routers/ai_context.py +226 -0
- flowyml/ui/backend/routers/assets.py +23 -4
- flowyml/ui/backend/routers/auth.py +96 -0
- flowyml/ui/backend/routers/deployments.py +660 -0
- flowyml/ui/backend/routers/model_explorer.py +597 -0
- flowyml/ui/backend/routers/plugins.py +103 -51
- flowyml/ui/backend/routers/projects.py +91 -8
- flowyml/ui/backend/routers/runs.py +20 -1
- flowyml/ui/backend/routers/schedules.py +22 -17
- flowyml/ui/backend/routers/templates.py +319 -0
- flowyml/ui/backend/routers/websocket.py +2 -2
- flowyml/ui/frontend/Dockerfile +55 -6
- flowyml/ui/frontend/dist/assets/index-B5AsPTSz.css +1 -0
- flowyml/ui/frontend/dist/assets/index-dFbZ8wD8.js +753 -0
- flowyml/ui/frontend/dist/index.html +2 -2
- flowyml/ui/frontend/dist/logo.png +0 -0
- flowyml/ui/frontend/nginx.conf +65 -4
- flowyml/ui/frontend/package-lock.json +1404 -74
- flowyml/ui/frontend/package.json +3 -0
- flowyml/ui/frontend/public/logo.png +0 -0
- flowyml/ui/frontend/src/App.jsx +10 -7
- flowyml/ui/frontend/src/app/auth/Login.jsx +90 -0
- flowyml/ui/frontend/src/app/dashboard/page.jsx +8 -8
- flowyml/ui/frontend/src/app/deployments/page.jsx +786 -0
- flowyml/ui/frontend/src/app/model-explorer/page.jsx +1031 -0
- flowyml/ui/frontend/src/app/pipelines/page.jsx +12 -2
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectExperimentsList.jsx +19 -6
- flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +36 -24
- flowyml/ui/frontend/src/app/runs/page.jsx +8 -2
- flowyml/ui/frontend/src/app/settings/page.jsx +267 -253
- flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +29 -7
- flowyml/ui/frontend/src/components/Layout.jsx +6 -0
- flowyml/ui/frontend/src/components/PipelineGraph.jsx +79 -29
- flowyml/ui/frontend/src/components/RunDetailsPanel.jsx +36 -6
- flowyml/ui/frontend/src/components/RunMetaPanel.jsx +113 -0
- flowyml/ui/frontend/src/components/ai/AIAssistantButton.jsx +71 -0
- flowyml/ui/frontend/src/components/ai/AIAssistantPanel.jsx +420 -0
- flowyml/ui/frontend/src/components/header/Header.jsx +22 -0
- flowyml/ui/frontend/src/components/plugins/PluginManager.jsx +4 -4
- flowyml/ui/frontend/src/components/plugins/{ZenMLIntegration.jsx → StackImport.jsx} +38 -12
- flowyml/ui/frontend/src/components/sidebar/Sidebar.jsx +36 -13
- flowyml/ui/frontend/src/contexts/AIAssistantContext.jsx +245 -0
- flowyml/ui/frontend/src/contexts/AuthContext.jsx +108 -0
- flowyml/ui/frontend/src/hooks/useAIContext.js +156 -0
- flowyml/ui/frontend/src/hooks/useWebGPU.js +54 -0
- flowyml/ui/frontend/src/layouts/MainLayout.jsx +6 -0
- flowyml/ui/frontend/src/router/index.jsx +47 -20
- flowyml/ui/frontend/src/services/pluginService.js +3 -1
- flowyml/ui/server_manager.py +5 -5
- flowyml/ui/utils.py +157 -39
- flowyml/utils/config.py +37 -15
- flowyml/utils/model_introspection.py +123 -0
- flowyml/utils/observability.py +30 -0
- flowyml-1.8.0.dist-info/METADATA +174 -0
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/RECORD +123 -65
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/WHEEL +1 -1
- flowyml/ui/frontend/dist/assets/index-B40RsQDq.css +0 -1
- flowyml/ui/frontend/dist/assets/index-CjI0zKCn.js +0 -685
- flowyml-1.7.2.dist-info/METADATA +0 -477
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/entry_points.txt +0 -0
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/licenses/LICENSE +0 -0
flowyml/serving/model_server.py
ADDED

@@ -0,0 +1,628 @@
+"""Model server implementation for deploying ML models as API endpoints.
+
+This module provides real model loading and prediction functionality for
+Keras, PyTorch, sklearn, TensorFlow, and other frameworks.
+"""
+
+import logging
+import subprocess
+import contextlib
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from collections import deque
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ServerConfig:
+    """Configuration for a model server."""
+
+    port: int
+    api_token: str
+    rate_limit: int = 100
+    timeout_seconds: int = 30
+    max_batch_size: int = 1
+    enable_cors: bool = True
+
+
+@dataclass
+class ModelServer:
+    """Represents a running model server process."""
+
+    deployment_id: str
+    model_artifact_id: str
+    model_path: str
+    framework: str
+    config: ServerConfig
+    process: subprocess.Popen | None = None
+    log_buffer: deque = field(default_factory=lambda: deque(maxlen=1000))
+    started_at: datetime | None = None
+    model: Any = None
+
+    def is_running(self) -> bool:
+        """Check if server process is running."""
+        return self.process is not None and self.process.poll() is None
+
+
+# Global server registry
+_servers: dict[str, ModelServer] = {}
+
+
+def _detect_framework(artifact_path: str) -> str:
+    """Detect the ML framework from the artifact path or structure."""
+    path = Path(artifact_path)
+
+    # Check directory contents for framework hints
+    if path.is_dir():
+        contents = list(path.iterdir())
+        content_names = [c.name for c in contents]
+
+        # TensorFlow SavedModel format
+        if "saved_model.pb" in content_names:
+            return "tensorflow"
+        # Keras H5 format
+        if any(c.suffix == ".h5" for c in contents):
+            return "keras"
+        # PyTorch
+        if any(c.suffix in [".pt", ".pth"] for c in contents):
+            return "pytorch"
+
+    # Check file extension
+    suffix = path.suffix.lower()
+    if suffix in [".h5", ".keras"]:
+        return "keras"
+    elif suffix in [".pt", ".pth"]:
+        return "pytorch"
+    elif suffix in [".pkl", ".joblib", ".pickle"]:
+        return "sklearn"
+    elif suffix == ".onnx":
+        return "onnx"
+
+    return "unknown"
+
+
+def _load_model_by_framework(model_path: str, framework: str) -> Any:
+    """Load a model based on its framework.
+
+    Always tries pickle/joblib first as the universal fallback,
+    then attempts framework-specific loading if that fails.
+
+    Args:
+        model_path: Path to the model file/directory
+        framework: The ML framework (keras, pytorch, sklearn, tensorflow)
+
+    Returns:
+        Loaded model object
+    """
+    path = Path(model_path)
+    errors = []
+
+    # First, always try pickle/joblib - works for most serialized models
+    try:
+        import joblib
+
+        model = joblib.load(str(path))
+        logger.info(f"Successfully loaded model with joblib from {path}")
+        return model
+    except Exception as e:
+        errors.append(f"joblib: {e}")
+
+    try:
+        import pickle
+
+        with open(str(path), "rb") as f:
+            model = pickle.load(f)
+        logger.info(f"Successfully loaded model with pickle from {path}")
+        return model
+    except Exception as e:
+        errors.append(f"pickle: {e}")
+
+    # If pickle/joblib failed, try framework-specific loaders
+    if framework == "keras":
+        try:
+            import keras
+
+            if path.is_dir():
+                for ext in [".keras", ".h5"]:
+                    candidates = list(path.glob(f"*{ext}"))
+                    if candidates:
+                        return keras.models.load_model(str(candidates[0]))
+            return keras.models.load_model(str(path))
+        except ImportError:
+            errors.append("keras: module not installed")
+        except Exception as e:
+            errors.append(f"keras: {e}")
+
+    elif framework == "tensorflow":
+        try:
+            import tensorflow as tf
+
+            return tf.saved_model.load(str(path))
+        except ImportError:
+            errors.append("tensorflow: module not installed")
+        except Exception as e:
+            errors.append(f"tensorflow: {e}")
+
+    elif framework == "pytorch":
+        try:
+            import torch
+
+            return torch.load(str(path), map_location=torch.device("cpu"))
+        except ImportError:
+            errors.append("pytorch: module not installed")
+        except Exception as e:
+            errors.append(f"pytorch: {e}")
+
+    elif framework == "onnx":
+        try:
+            import onnxruntime as ort
+
+            return ort.InferenceSession(str(path))
+        except ImportError:
+            errors.append("onnxruntime: module not installed")
+        except Exception as e:
+            errors.append(f"onnx: {e}")
+
+    # If all loading attempts failed, raise with detailed error
+    raise RuntimeError(f"Failed to load model from {path}. Attempted methods: {'; '.join(errors)}")
+
+
+def _predict_with_model(model: Any, data: dict, framework: str) -> dict:
+    """Run prediction using the loaded model.
+
+    Args:
+        model: Loaded model object
+        data: Input data dictionary
+        framework: The ML framework
+
+    Returns:
+        Prediction result dictionary
+    """
+    import numpy as np
+
+    def extract_numeric_values(obj, values=None):
+        """Recursively extract numeric values from nested structures."""
+        if values is None:
+            values = []
+
+        if isinstance(obj, (int, float)):
+            values.append(float(obj))
+        elif isinstance(obj, (list, tuple)):
+            for item in obj:
+                extract_numeric_values(item, values)
+        elif isinstance(obj, dict):
+            for v in obj.values():
+                extract_numeric_values(v, values)
+        elif isinstance(obj, str):
+            # Try to parse as number
+            with contextlib.suppress(ValueError):
+                values.append(float(obj))
+        return values
+
+    # Extract input from data - handle various input formats
+    input_data = data.get("input") or data.get("data") or data.get("X") or data
+
+    # Remove non-feature keys that might be in the dict
+    if isinstance(input_data, dict):
+        input_data = {k: v for k, v in input_data.items() if k not in ["deployment_id", "model_artifact_id", "inputs"]}
+
+    # Convert to numpy array
+    try:
+        if isinstance(input_data, list):
+            # Direct list input
+            flat_values = extract_numeric_values(input_data)
+            if flat_values:
+                input_array = np.array(flat_values, dtype=np.float32)
+            else:
+                input_array = np.array(input_data, dtype=np.float32)
+        elif isinstance(input_data, dict):
+            # Dictionary input - extract numeric values
+            flat_values = extract_numeric_values(input_data)
+            if flat_values:
+                input_array = np.array(flat_values, dtype=np.float32)
+            else:
+                # Try to get values as-is
+                input_array = np.array(list(input_data.values()), dtype=np.float32)
+        elif isinstance(input_data, (int, float)):
+            input_array = np.array([[float(input_data)]], dtype=np.float32)
+        else:
+            input_array = np.array([[float(input_data)]], dtype=np.float32)
+    except (ValueError, TypeError) as e:
+        raise ValueError(f"Failed to convert input to numeric array: {e}. Input: {input_data}")
+
+    # Ensure 2D array for batch processing
+    if input_array.ndim == 1:
+        input_array = input_array.reshape(1, -1)
+
+    # Ensure float32 for all frameworks
+    input_array = input_array.astype(np.float32)
+
+    if framework in ["keras", "tensorflow"]:
+        # Introspect Keras model for input shape and names
+        expected_shape = None
+        input_names = []
+
+        try:
+            # Try to get input specification from Keras model
+            if hasattr(model, "input_shape"):
+                expected_shape = model.input_shape
+            if hasattr(model, "input_names") and model.input_names:
+                input_names = model.input_names
+            elif hasattr(model, "input") and hasattr(model.input, "name"):
+                input_names = [model.input.name.split(":")[0]]
+
+            # Handle multi-input models
+            if hasattr(model, "inputs") and len(model.inputs) > 1:
+                # Multi-input model - need dict of arrays
+                model_inputs = {}
+                for inp in model.inputs:
+                    inp_name = inp.name.split(":")[0] if ":" in inp.name else inp.name
+
+                    if isinstance(data, dict) and inp_name in data:
+                        val = data[inp_name]
+                    elif isinstance(input_data, dict) and inp_name in input_data:
+                        val = input_data[inp_name]
+                    else:
+                        # Try to use input_array sliced appropriately
+                        val = input_array
+
+                    if isinstance(val, (int, float)):
+                        val = np.array([[val]], dtype=np.float32)
+                    elif isinstance(val, list):
+                        val = np.array(val, dtype=np.float32)
+                    if val.ndim == 1:
+                        val = val.reshape(1, -1)
+                    model_inputs[inp_name] = val.astype(np.float32)
+
+                prediction = model.predict(model_inputs)
+            else:
+                # Single input - check expected shape
+                if expected_shape and len(expected_shape) > 1:
+                    expected_features = expected_shape[-1]
+                    if expected_features and input_array.shape[-1] != expected_features:
+                        # Reshape or pad to match expected features
+                        if input_array.size >= expected_features:
+                            input_array = input_array.flatten()[:expected_features].reshape(1, -1)
+                        else:
+                            # Pad with zeros if not enough features
+                            padded = np.zeros((1, expected_features), dtype=np.float32)
+                            padded[0, : input_array.size] = input_array.flatten()
+                            input_array = padded
+
+                prediction = model.predict(input_array)
+        except Exception as e:
+            # Fallback to direct prediction
+            logger.warning(f"Model introspection failed, using direct input: {e}")
+            prediction = model.predict(input_array)
+
+        result = prediction.tolist() if hasattr(prediction, "tolist") else prediction
+
+        # Format output nicely
+        output = {"prediction": result}
+        if hasattr(prediction, "shape"):
+            output["shape"] = list(prediction.shape)
+        if input_names:
+            output["input_names"] = input_names
+        if expected_shape:
+            output["expected_input_shape"] = [s if s else "?" for s in expected_shape]
+
+        return output
+
+    elif framework == "pytorch":
+        import torch
+
+        with torch.no_grad():
+            tensor_input = torch.tensor(input_array, dtype=torch.float32)
+            prediction = model(tensor_input)
+            result = prediction.numpy().tolist() if hasattr(prediction, "numpy") else prediction.tolist()
+        return {"prediction": result}
+
+    elif framework == "sklearn":
+        prediction = model.predict(input_array)
+        result = prediction.tolist() if hasattr(prediction, "tolist") else list(prediction)
+
+        # Try to get probability if available
+        proba = None
+        if hasattr(model, "predict_proba"):
+            with contextlib.suppress(Exception):
+                proba = model.predict_proba(input_array).tolist()
+
+        return {
+            "prediction": result,
+            "probabilities": proba,
+        }
+
+    elif framework == "onnx":
+        input_name = model.get_inputs()[0].name
+        output_name = model.get_outputs()[0].name
+        prediction = model.run([output_name], {input_name: input_array.astype(np.float32)})
+        return {"prediction": prediction[0].tolist()}
+
+    else:
+        # Try generic predict
+        if hasattr(model, "predict"):
+            prediction = model.predict(input_array)
+            result = prediction.tolist() if hasattr(prediction, "tolist") else prediction
+            return {"prediction": result}
+        elif callable(model):
+            prediction = model(input_array)
+            return {"prediction": str(prediction)}
+        else:
+            raise RuntimeError("Model does not have a predict method")
+
+
+def load_and_predict(
+    model_artifact_id: str,
+    input_data: dict,
+    cached_model: Any = None,
+    framework: str | None = None,
+) -> tuple[dict, Any]:
+    """Load a model and run prediction.
+
+    Args:
+        model_artifact_id: ID of the model artifact
+        input_data: Input data for prediction
+        cached_model: Previously loaded model to reuse
+        framework: Framework hint
+
+    Returns:
+        Tuple of (prediction_result, loaded_model)
+    """
+    from flowyml.ui.backend.dependencies import get_store
+    import time
+
+    start_time = time.time()
+    store = get_store()
+
+    # Get artifact path
+    artifacts = store.list_assets()
+    artifact = next(
+        (
+            a
+            for a in artifacts
+            if a.get("artifact_id") == model_artifact_id
+            or f"{a.get('run_id')}_{a.get('step')}_{a.get('name')}" == model_artifact_id
+        ),
+        None,
+    )
+
+    if not artifact:
+        raise ValueError(f"Model artifact not found: {model_artifact_id}")
+
+    model_path = artifact.get("path") or artifact.get("uri") or artifact.get("storage_path")
+    if not model_path:
+        raise ValueError(f"No path found for artifact: {model_artifact_id}")
+
+    # Detect or use provided framework
+    if not framework:
+        framework = (artifact.get("type") or artifact.get("asset_type") or "").lower()
+        if "keras" in framework:
+            framework = "keras"
+        elif "pytorch" in framework or "torch" in framework:
+            framework = "pytorch"
+        elif "sklearn" in framework or "scikit" in framework:
+            framework = "sklearn"
+        elif "tensorflow" in framework or "tf" in framework:
+            framework = "tensorflow"
+        else:
+            framework = _detect_framework(model_path)
+
+    # Load model if not cached
+    model = cached_model
+    if model is None:
+        model = _load_model_by_framework(model_path, framework)
+
+    # Run prediction
+    prediction = _predict_with_model(model, input_data, framework)
+
+    prediction["latency_ms"] = (time.time() - start_time) * 1000
+    prediction["framework"] = framework
+
+    return prediction, model
+
+
+def start_model_server(
+    deployment_id: str,
+    model_artifact_id: str,
+    config: ServerConfig,
+) -> ModelServer:
+    """Start a model server for the given deployment.
+
+    This loads the model and stores it in memory for fast predictions.
+
+    Args:
+        deployment_id: Unique deployment identifier
+        model_artifact_id: ID of the model artifact
+        config: Server configuration
+
+    Returns:
+        ModelServer instance
+    """
+    from flowyml.ui.backend.dependencies import get_store
+    import os
+
+    store = get_store()
+
+    # Get artifact info
+    artifacts = store.list_assets()
+    artifact = next(
+        (
+            a
+            for a in artifacts
+            if a.get("artifact_id") == model_artifact_id
+            or f"{a.get('run_id')}_{a.get('step')}_{a.get('name')}" == model_artifact_id
+        ),
+        None,
+    )
+
+    if not artifact:
+        raise ValueError(f"Model artifact not found: {model_artifact_id}")
+
+    # Get path and normalize it
+    relative_path = artifact.get("path") or artifact.get("uri") or artifact.get("storage_path")
+    if not relative_path:
+        raise ValueError(f"No path found for artifact: {model_artifact_id}")
+
+    # Container paths are relative to /app/artifacts
+    model_path = os.path.join("/app/artifacts", relative_path)
+
+    # Check if file/directory exists
+    if not os.path.exists(model_path):
+        raise ValueError(f"Model file not found at: {model_path}")
+
+    # Detect framework from type or path
+    framework = (artifact.get("type") or artifact.get("asset_type") or "").lower()
+    if "keras" in framework:
+        framework = "keras"
+    elif "pytorch" in framework or "torch" in framework:
+        framework = "pytorch"
+    elif "sklearn" in framework or "scikit" in framework:
+        framework = "sklearn"
+    elif "tensorflow" in framework or "tf" in framework:
+        framework = "tensorflow"
+    else:
+        framework = _detect_framework(model_path)
+
+    # Create server instance
+    server = ModelServer(
+        deployment_id=deployment_id,
+        model_artifact_id=model_artifact_id,
+        model_path=model_path,
+        framework=framework,
+        config=config,
+        started_at=datetime.now(),
+    )
+
+    # Load the model
+    try:
+        server.model = _load_model_by_framework(model_path, framework)
+        server.log_buffer.append(
+            {
+                "timestamp": datetime.now().isoformat(),
+                "level": "INFO",
+                "message": f"Model loaded successfully from {model_path} (framework: {framework})",
+            },
+        )
+    except Exception as e:
+        server.log_buffer.append(
+            {
+                "timestamp": datetime.now().isoformat(),
+                "level": "ERROR",
+                "message": f"Failed to load model: {str(e)}",
+            },
+        )
+        raise
+
+    # Store in registry
+    _servers[deployment_id] = server
+
+    logger.info(f"Started model server for deployment {deployment_id} on port {config.port}")
+
+    return server
+
+
+def stop_model_server(deployment_id: str) -> bool:
+    """Stop a running model server.
+
+    Args:
+        deployment_id: ID of the deployment to stop
+
+    Returns:
+        True if stopped successfully, False if not found
+    """
+    if deployment_id not in _servers:
+        return False
+
+    server = _servers[deployment_id]
+
+    # Clean up model from memory
+    if server.model is not None:
+        del server.model
+        server.model = None
+
+    server.log_buffer.append(
+        {
+            "timestamp": datetime.now().isoformat(),
+            "level": "INFO",
+            "message": "Server stopped",
+        },
+    )
+
+    # Remove from registry
+    del _servers[deployment_id]
+
+    logger.info(f"Stopped model server for deployment {deployment_id}")
+
+    return True
+
+
+def get_server(deployment_id: str) -> ModelServer | None:
+    """Get a server by deployment ID."""
+    return _servers.get(deployment_id)
+
+
+def get_server_logs(deployment_id: str, lines: int = 100) -> list[dict]:
+    """Get logs from a server.
+
+    Args:
+        deployment_id: ID of the deployment
+        lines: Number of log lines to return
+
+    Returns:
+        List of log entries
+    """
+    server = _servers.get(deployment_id)
+    if not server:
+        return []
+
+    return list(server.log_buffer)[-lines:]
+
+
+def predict(deployment_id: str, input_data: dict) -> dict:
+    """Run prediction on a deployed model.
+
+    Args:
+        deployment_id: ID of the deployment
+        input_data: Input data for prediction
+
+    Returns:
+        Prediction result
+    """
+    server = _servers.get(deployment_id)
+    if not server:
+        raise ValueError(f"Deployment not found: {deployment_id}")
+
+    if server.model is None:
+        raise RuntimeError("Model not loaded")
+
+    import time
+
+    start_time = time.time()
+
+    try:
+        result = _predict_with_model(server.model, input_data, server.framework)
+        latency = (time.time() - start_time) * 1000
+
+        server.log_buffer.append(
+            {
+                "timestamp": datetime.now().isoformat(),
+                "level": "INFO",
+                "message": f"Prediction completed in {latency:.2f}ms",
+            },
+        )
+
+        result["latency_ms"] = latency
+        return result
+
+    except Exception as e:
+        server.log_buffer.append(
+            {
+                "timestamp": datetime.now().isoformat(),
+                "level": "ERROR",
+                "message": f"Prediction failed: {str(e)}",
+            },
+        )
+        raise
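
Taken together, the new module gives the backend a small in-process serving lifecycle: start_model_server() resolves the artifact through the backend store and caches the loaded model in the module-level _servers registry, predict() runs requests against that cached model and logs latency, and stop_model_server() releases it. A minimal usage sketch follows; the deployment ID, artifact ID, and token are placeholders, and it assumes the backend store can resolve the artifact under /app/artifacts, as start_model_server() requires:

from flowyml.serving.model_server import (
    ServerConfig,
    start_model_server,
    predict,
    get_server_logs,
    stop_model_server,
)

# Placeholder IDs and token; start_model_server() looks the artifact up in
# the backend store and expects the file to exist under /app/artifacts.
config = ServerConfig(port=8080, api_token="changeme", rate_limit=50)
server = start_model_server("deploy-001", "run42_train_model", config)

# Input coercion in _predict_with_model accepts {"input": ...}, {"data": ...},
# {"X": ...}, or a bare list/dict of numeric values.
result = predict("deploy-001", {"input": [5.1, 3.5, 1.4, 0.2]})
print(result["prediction"], f'{result["latency_ms"]:.2f}ms')

for entry in get_server_logs("deploy-001", lines=10):
    print(entry["timestamp"], entry["level"], entry["message"])

stop_model_server("deploy-001")
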
flowyml/stacks/__init__.py
CHANGED

@@ -13,6 +13,59 @@ from flowyml.stacks.components import (
     ContainerRegistry,
 )
 from flowyml.stacks.registry import StackRegistry, get_registry, get_active_stack, set_active_stack
+from flowyml.stacks.plugins import (
+    get_component_registry,
+    register_component,
+    load_component,
+)
+
+
+# ZenML integration - lazy imports to avoid errors when ZenML is not installed
+# NOTE: These functions are deprecated. Use native FlowyML plugins instead.
+# See: https://docs.flowyml.ai/plugins/native-plugins/
+def get_zenml_bridge():
+    """Get a ZenML bridge for importing ZenML components.
+
+    .. deprecated::
+        ZenML integration is deprecated. Use native FlowyML plugins instead.
+        See the Native Plugins documentation for the recommended approach.
+    """
+    import warnings
+
+    warnings.warn(
+        "get_zenml_bridge() is deprecated. Use native FlowyML plugins instead. "
+        "See: https://docs.flowyml.ai/plugins/native-plugins/",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    from flowyml.stacks.zenml_bridge import ZenMLBridge
+
+    return ZenMLBridge()
+
+
+def import_all_zenml():
+    """Import all components from all installed ZenML integrations.
+
+    .. deprecated::
+        ZenML integration is deprecated. Use native FlowyML plugins instead.
+        See the Native Plugins documentation for the recommended approach.
+
+    Example:
+        >>> from flowyml.stacks import import_all_zenml
+        >>> components = import_all_zenml()
+    """
+    import warnings
+
+    warnings.warn(
+        "import_all_zenml() is deprecated. Use native FlowyML plugins instead. "
+        "See: https://docs.flowyml.ai/plugins/native-plugins/",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    from flowyml.stacks.zenml_bridge import import_all_zenml as _import_all
+
+    return _import_all()
+
 
 __all__ = [
     "Stack",

@@ -39,4 +92,11 @@ __all__ = [
     "get_registry",
     "get_active_stack",
     "set_active_stack",
+    # Plugin system
+    "get_component_registry",
+    "register_component",
+    "load_component",
+    # ZenML integration
+    "get_zenml_bridge",
+    "import_all_zenml",
 ]