ob-metaflow-extensions 1.1.166rc6__py2.py3-none-any.whl → 1.1.168rc0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ob-metaflow-extensions might be problematic.

@@ -345,6 +345,7 @@ STEP_DECORATORS_DESC = [
345
345
  ("gpu_profile", ".profilers.gpu_profile_decorator.GPUProfileDecorator"),
346
346
  ("nim", ".nim.nim_decorator.NimDecorator"),
347
347
  ("ollama", ".ollama.OllamaDecorator"),
348
+ ("vllm", ".vllm.VLLMDecorator"),
348
349
  ("app_deploy", ".apps.deploy_decorator.WorkstationAppDeployDecorator"),
349
350
  ]
350
351
 
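For orientation, the entry added above registers the step decorator under the name vllm, alongside nim and ollama. A minimal, hypothetical flow using it might look like the sketch below; the model name, base URL, and placeholder API key follow the docstrings added later in this diff, the top-level `from metaflow import vllm` import is assumed to work the same way as the other registered decorators, and the `openai` client package is assumed to be available in the task environment.

from metaflow import FlowSpec, step, vllm

class HelloVLLMFlow(FlowSpec):

    @vllm(model="meta-llama/Llama-3.2-1B")  # example model from the docstrings below
    @step
    def start(self):
        # The sidecar exposes vLLM's OpenAI-compatible API on localhost:8000
        # (see vllm_manager.py later in this diff).
        import openai

        client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
        reply = client.chat.completions.create(
            model="meta-llama/Llama-3.2-1B",
            messages=[{"role": "user", "content": "Hello"}],
        )
        print(reply.choices[0].message.content)
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    HelloVLLMFlow()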
@@ -0,0 +1,177 @@
1
+ from metaflow.decorators import StepDecorator
2
+ from metaflow import current
3
+ import functools
4
+ import os
5
+ import threading
6
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
7
+ from metaflow.metaflow_config import from_conf
8
+
9
+ from .vllm_manager import VLLMManager
10
+ from .status_card import VLLMStatusCard, CardDecoratorInjector
11
+
12
+ __mf_promote_submodules__ = ["plugins.vllm"]
13
+
14
+
15
+ class VLLMDecorator(StepDecorator, CardDecoratorInjector):
16
+ """
17
+ This decorator is used to run vLLM APIs as Metaflow task sidecars.
18
+
19
+ User code call
20
+ --------------
21
+ @vllm(
22
+ model="...",
23
+ ...
24
+ )
25
+
26
+ Valid backend options
27
+ ---------------------
28
+ - 'local': Run as a separate process on the local task machine.
29
+
30
+ Valid model options
31
+ -------------------
32
+ Any HuggingFace model identifier, e.g. 'meta-llama/Llama-3.2-1B'
33
+
34
+ NOTE: vLLM's OpenAI-compatible server serves ONE model per server instance.
35
+ If you need multiple models, you must create multiple @vllm decorators.
36
+
37
+ Parameters
38
+ ----------
39
+ model: str
40
+ HuggingFace model identifier to be served by vLLM.
41
+ backend: str
42
+ Determines where and how to run the vLLM process.
43
+ debug: bool
44
+ Whether to turn on verbose debugging logs.
45
+ kwargs : Any
46
+ Any other keyword arguments are passed directly to the vLLM engine.
47
+ This allows for flexible configuration of vLLM server settings.
48
+ For example, `tensor_parallel_size=2`.
49
+ """
50
+
51
+ name = "vllm"
52
+ defaults = {
53
+ "model": None,
54
+ "backend": "local",
55
+ "debug": False,
56
+ "stream_logs_to_card": False,
57
+ "card_refresh_interval": 10,
58
+ "engine_args": {},
59
+ }
60
+
61
+ def step_init(
62
+ self, flow, graph, step_name, decorators, environment, flow_datastore, logger
63
+ ):
64
+ super().step_init(
65
+ flow, graph, step_name, decorators, environment, flow_datastore, logger
66
+ )
67
+
68
+ # Validate that a model is specified
69
+ if not self.attributes["model"]:
70
+ raise ValueError(
71
+ f"@vllm decorator on step '{step_name}' requires a 'model' parameter. "
72
+ f"Example: @vllm(model='meta-llama/Llama-3.2-1B')"
73
+ )
74
+
75
+ # Attach the vllm status card
76
+ self.attach_card_decorator(
77
+ flow,
78
+ step_name,
79
+ "vllm_status",
80
+ "blank",
81
+ refresh_interval=self.attributes["card_refresh_interval"],
82
+ )
83
+
84
+ def task_decorate(
85
+ self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
86
+ ):
87
+ @functools.wraps(step_func)
88
+ def vllm_wrapper():
89
+ self.vllm_manager = None
90
+ self.status_card = None
91
+ self.card_monitor_thread = None
92
+
93
+ try:
94
+ self.status_card = VLLMStatusCard(
95
+ refresh_interval=self.attributes["card_refresh_interval"]
96
+ )
97
+
98
+ def monitor_card():
99
+ try:
100
+ self.status_card.on_startup(current.card["vllm_status"])
101
+
102
+ while not getattr(
103
+ self.card_monitor_thread, "_stop_event", False
104
+ ):
105
+ try:
106
+ self.status_card.on_update(
107
+ current.card["vllm_status"], None
108
+ )
109
+ import time
110
+
111
+ time.sleep(self.attributes["card_refresh_interval"])
112
+ except Exception as e:
113
+ if self.attributes["debug"]:
114
+ print(f"[@vllm] Card monitoring error: {e}")
115
+ break
116
+ except Exception as e:
117
+ if self.attributes["debug"]:
118
+ print(f"[@vllm] Card monitor thread error: {e}")
119
+ self.status_card.on_error(current.card["vllm_status"], str(e))
120
+
121
+ self.card_monitor_thread = threading.Thread(
122
+ target=monitor_card, daemon=True
123
+ )
124
+ self.card_monitor_thread._stop_event = False
125
+ self.card_monitor_thread.start()
126
+ self.vllm_manager = VLLMManager(
127
+ model=self.attributes["model"],
128
+ backend=self.attributes["backend"],
129
+ debug=self.attributes["debug"],
130
+ status_card=self.status_card,
131
+ stream_logs_to_card=self.attributes["stream_logs_to_card"],
132
+ **self.attributes["engine_args"],
133
+ )
134
+ if self.attributes["debug"]:
135
+ print("[@vllm] VLLMManager initialized.")
136
+
137
+ except Exception as e:
138
+ if self.status_card:
139
+ self.status_card.add_event(
140
+ "error", f"Initialization failed: {str(e)}"
141
+ )
142
+ try:
143
+ self.status_card.on_error(current.card["vllm_status"], str(e))
144
+ except:
145
+ pass
146
+ print(f"[@vllm] Error initializing VLLMManager: {e}")
147
+ raise
148
+
149
+ try:
150
+ if self.status_card:
151
+ self.status_card.add_event("info", "Starting user step function")
152
+ step_func()
153
+ if self.status_card:
154
+ self.status_card.add_event(
155
+ "success", "User step function completed successfully"
156
+ )
157
+ finally:
158
+ if self.vllm_manager:
159
+ self.vllm_manager.terminate_models()
160
+
161
+ if self.card_monitor_thread and self.status_card:
162
+ import time
163
+
164
+ try:
165
+ self.status_card.on_update(current.card["vllm_status"], None)
166
+ except Exception as e:
167
+ if self.attributes["debug"]:
168
+ print(f"[@vllm] Final card update error: {e}")
169
+ time.sleep(2)
170
+
171
+ if self.card_monitor_thread:
172
+ self.card_monitor_thread._stop_event = True
173
+ self.card_monitor_thread.join(timeout=5)
174
+ if self.attributes["debug"]:
175
+ print("[@vllm] Card monitoring thread stopped.")
176
+
177
+ return vllm_wrapper
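The `engine_args` attribute above is forwarded by `task_decorate` to VLLMManager as keyword arguments and, as shown in vllm_manager.py further down in this diff, turned into command-line flags for vLLM's OpenAI-compatible server. A standalone sketch of that conversion, with illustrative argument values:

# Illustrative sketch of the kwargs -> CLI-flag conversion used by VLLMManager.
engine_args = {"tensor_parallel_size": 2, "enforce_eager": True, "max_model_len": None}

cmd = [
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--model", "meta-llama/Llama-3.2-1B", "--host", "127.0.0.1", "--port", "8000",
]
for key, value in engine_args.items():
    arg_name = f"--{key.replace('_', '-')}"
    if isinstance(value, bool):
        if value:                  # True booleans become bare flags
            cmd.append(arg_name)
    elif value is not None:        # None values are skipped entirely
        cmd.extend([arg_name, str(value)])

print(" ".join(cmd))
# ... --tensor-parallel-size 2 --enforce-eager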
@@ -0,0 +1 @@
1
+ VLLM_SUFFIX = "mf.vllm"
@@ -0,0 +1 @@
1
+ from metaflow.exception import MetaflowException
@@ -0,0 +1,352 @@
1
+ from metaflow.cards import Markdown, Table, VegaChart
2
+ from metaflow.metaflow_current import current
3
+ from datetime import datetime
4
+ import threading
5
+ import time
6
+
7
+
8
+ from metaflow.exception import MetaflowException
9
+ from collections import defaultdict
10
+
11
+
12
+ class CardDecoratorInjector:
13
+ """
14
+ Mixin useful for injecting @card decorators from other first-class Metaflow decorators.
15
+ """
16
+
17
+ _first_time_init = defaultdict(dict)
18
+
19
+ @classmethod
20
+ def _get_first_time_init_cached_value(cls, step_name, card_id):
21
+ return cls._first_time_init.get(step_name, {}).get(card_id, None)
22
+
23
+ @classmethod
24
+ def _set_first_time_init_cached_value(cls, step_name, card_id, value):
25
+ cls._first_time_init[step_name][card_id] = value
26
+
27
+ def _card_deco_already_attached(self, step, card_id):
28
+ for decorator in step.decorators:
29
+ if decorator.name == "card":
30
+ if decorator.attributes["id"] and card_id == decorator.attributes["id"]:
31
+ return True
32
+ return False
33
+
34
+ def _get_step(self, flow, step_name):
35
+ for step in flow:
36
+ if step.name == step_name:
37
+ return step
38
+ return None
39
+
40
+ def _first_time_init_check(self, step_dag_node, card_id):
41
+ """ """
42
+ return not self._card_deco_already_attached(step_dag_node, card_id)
43
+
44
+ def attach_card_decorator(
45
+ self,
46
+ flow,
47
+ step_name,
48
+ card_id,
49
+ card_type,
50
+ refresh_interval=5,
51
+ ):
52
+ """
53
+ Call this method from `step_init` in your StepDecorator code, since
54
+ this class is used as a mixin.
55
+ """
56
+ from metaflow import decorators as _decorators
57
+
58
+ if not all([card_id, card_type]):
59
+ raise MetaflowException(
60
+ "`INJECTED_CARD_ID` and `INJECTED_CARD_TYPE` must be set in the `CardDecoratorInjector` Mixin"
61
+ )
62
+
63
+ step_dag_node = self._get_step(flow, step_name)
64
+ if (
65
+ self._get_first_time_init_cached_value(step_name, card_id) is None
66
+ ): # First check class level setting.
67
+ if self._first_time_init_check(step_dag_node, card_id):
68
+ self._set_first_time_init_cached_value(step_name, card_id, True)
69
+ _decorators._attach_decorators_to_step(
70
+ step_dag_node,
71
+ [
72
+ "card:type=%s,id=%s,refresh_interval=%s"
73
+ % (card_type, card_id, str(refresh_interval))
74
+ ],
75
+ )
76
+ else:
77
+ self._set_first_time_init_cached_value(step_name, card_id, False)
78
+
79
+
80
+ class CardRefresher:
81
+
82
+ CARD_ID = None
83
+
84
+ def on_startup(self, current_card):
85
+ raise NotImplementedError("make_card method must be implemented")
86
+
87
+ def on_error(self, current_card, error_message):
88
+ raise NotImplementedError("error_card method must be implemented")
89
+
90
+ def on_update(self, current_card, data_object):
91
+ raise NotImplementedError("update_card method must be implemented")
92
+
93
+ def sqlite_fetch_func(self, conn):
94
+ raise NotImplementedError("sqlite_fetch_func must be implemented")
95
+
96
+
97
+ class VLLMStatusCard(CardRefresher):
98
+ """
99
+ Real-time status card for vLLM system monitoring.
100
+ Shows server health, model status, and recent events.
101
+
102
+ Intended to be instantiated by a step decorator; e.g., VLLMDecorator creates a
103
+ VLLMStatusCard and renders it from a background monitor thread.
104
+ """
105
+
106
+ CARD_ID = "vllm_status"
107
+
108
+ def __init__(self, refresh_interval=10):
109
+ self.refresh_interval = refresh_interval
110
+ self.status_data = {
111
+ "server": {
112
+ "status": "Starting",
113
+ "uptime_start": None,
114
+ "last_health_check": None,
115
+ "health_status": "Unknown",
116
+ "models": [],
117
+ },
118
+ "models": {}, # model_name -> {status, load_time, etc}
119
+ "performance": {
120
+ "install_time": None,
121
+ "server_startup_time": None,
122
+ "total_initialization_time": None,
123
+ },
124
+ "versions": {
125
+ "vllm": "Detecting...",
126
+ },
127
+ "events": [], # Recent events log
128
+ "logs": [],
129
+ }
130
+ self._lock = threading.Lock()
131
+ self._already_rendered = False
132
+
133
+ def update_status(self, category, data):
134
+ """Thread-safe method to update status data"""
135
+ with self._lock:
136
+ if category in self.status_data:
137
+ self.status_data[category].update(data)
138
+
139
+ def add_log_line(self, log_line):
140
+ """Add a log line to the logs."""
141
+ with self._lock:
142
+ self.status_data["logs"].append(log_line)
143
+ # Keep only last 20 lines
144
+ self.status_data["logs"] = self.status_data["logs"][-20:]
145
+
146
+ def add_event(self, event_type, message, timestamp=None):
147
+ """Add an event to the timeline"""
148
+ if timestamp is None:
149
+ timestamp = datetime.now()
150
+
151
+ with self._lock:
152
+ self.status_data["events"].insert(
153
+ 0,
154
+ {
155
+ "type": event_type, # 'info', 'warning', 'error', 'success'
156
+ "message": message,
157
+ "timestamp": timestamp,
158
+ },
159
+ )
160
+ # Keep only last 10 events
161
+ self.status_data["events"] = self.status_data["events"][:10]
162
+
163
+ def get_circuit_breaker_emoji(self, state):
164
+ """Get status emoji for circuit breaker state"""
165
+ emoji_map = {"CLOSED": "🟢", "OPEN": "🔴", "HALF_OPEN": "🟡"}
166
+ return emoji_map.get(state, "⚪")
167
+
168
+ def get_uptime_string(self, start_time):
169
+ """Calculate uptime string"""
170
+ if not start_time:
171
+ return "Not started"
172
+
173
+ uptime = datetime.now() - start_time
174
+ hours, remainder = divmod(int(uptime.total_seconds()), 3600)
175
+ minutes, seconds = divmod(remainder, 60)
176
+
177
+ if hours > 0:
178
+ return f"{hours}h {minutes}m {seconds}s"
179
+ elif minutes > 0:
180
+ return f"{minutes}m {seconds}s"
181
+ else:
182
+ return f"{seconds}s"
183
+
184
+ def on_startup(self, current_card):
185
+ """Initialize the card when monitoring starts"""
186
+ current_card.append(Markdown("# 🚀 `@vllm` Status Dashboard"))
187
+ current_card.append(Markdown("_Initializing vLLM system..._"))
188
+ current_card.refresh()
189
+
190
+ def render_card_fresh(self, current_card, data):
191
+ """Render the complete card with all status information"""
192
+ self._already_rendered = True
193
+ current_card.clear()
194
+
195
+ current_card.append(Markdown("# 🚀 `@vllm` Status Dashboard"))
196
+
197
+ versions = data.get("versions", {})
198
+ vllm_version = versions.get("vllm", "Unknown")
199
+ current_card.append(Markdown(f"**vLLM Version:** `{vllm_version}`"))
200
+
201
+ current_card.append(
202
+ Markdown(f"_Last updated: {datetime.now().strftime('%H:%M:%S')}_")
203
+ )
204
+
205
+ server_data = data["server"]
206
+ uptime = self.get_uptime_string(server_data.get("uptime_start"))
207
+ server_status = server_data.get("status", "Unknown")
208
+ model = server_data.get("model", "Unknown")
209
+
210
+ # Determine status emoji
211
+ if server_status == "Running":
212
+ status_emoji = "🟢"
213
+ model_emoji = "✅"
214
+ elif server_status == "Failed":
215
+ status_emoji = "🔴"
216
+ model_emoji = "❌"
217
+ elif server_status == "Starting":
218
+ status_emoji = "🟡"
219
+ model_emoji = "⏳"
220
+ else: # Stopped, etc.
221
+ status_emoji = "⚫"
222
+ model_emoji = "⏹️"
223
+
224
+ # Main status section
225
+ current_card.append(
226
+ Markdown(f"## {status_emoji} Server Status: {server_status}")
227
+ )
228
+
229
+ if server_status == "Running" and uptime:
230
+ current_card.append(Markdown(f"**Uptime:** {uptime}"))
231
+
232
+ # Model information - only show detailed status if server is running
233
+ if server_status == "Running":
234
+ current_card.append(Markdown(f"## {model_emoji} Model: `{model}`"))
235
+
236
+ # Show model-specific status if available
237
+ models_data = data.get("models", {})
238
+ if models_data and model in models_data:
239
+ model_info = models_data[model]
240
+ model_status = model_info.get("status", "Unknown")
241
+ load_time = model_info.get("load_time")
242
+ location = model_info.get("location")
243
+
244
+ current_card.append(Markdown(f"**Status:** {model_status}"))
245
+ if location:
246
+ current_card.append(Markdown(f"**Location:** `{location}`"))
247
+ if load_time and isinstance(load_time, (int, float)):
248
+ current_card.append(Markdown(f"**Load Time:** {load_time:.1f}s"))
249
+ elif model != "Unknown":
250
+ current_card.append(
251
+ Markdown(f"## {model_emoji} Model: `{model}` (Server Stopped)")
252
+ )
253
+
254
+ # Simplified monitoring note
255
+ current_card.append(
256
+ Markdown(
257
+ "## 🔧 Monitoring\n**Advanced Features:** Disabled (Circuit Breaker, Request Interception)"
258
+ )
259
+ )
260
+
261
+ # Performance metrics
262
+ perf_data = data["performance"]
263
+ if any(v is not None for v in perf_data.values()):
264
+ current_card.append(Markdown("## ⚡ Performance"))
265
+
266
+ init_metrics = []
267
+ shutdown_metrics = []
268
+
269
+ for metric, value in perf_data.items():
270
+ if value is not None:
271
+ display_value = (
272
+ f"{value:.1f}s" if isinstance(value, (int, float)) else value
273
+ )
274
+ metric_display = metric.replace("_", " ").title()
275
+
276
+ if "shutdown" in metric.lower():
277
+ shutdown_metrics.append([metric_display, display_value])
278
+ elif metric in [
279
+ "install_time",
280
+ "server_startup_time",
281
+ "total_initialization_time",
282
+ ]:
283
+ init_metrics.append([metric_display, display_value])
284
+
285
+ if init_metrics:
286
+ current_card.append(Markdown("### Initialization"))
287
+ current_card.append(Table(init_metrics, headers=["Metric", "Duration"]))
288
+
289
+ if shutdown_metrics:
290
+ current_card.append(Markdown("### Shutdown"))
291
+ current_card.append(
292
+ Table(shutdown_metrics, headers=["Metric", "Value"])
293
+ )
294
+
295
+ # Recent events
296
+ events = data.get("events", [])
297
+ if events:
298
+ current_card.append(Markdown("## 📝 Recent Events"))
299
+ for event in events[:5]: # Show last 5 events
300
+ event_type = event.get("type", "info")
301
+ message = event.get("message", "")
302
+ timestamp = event.get("timestamp", datetime.now())
303
+
304
+ emoji_map = {
305
+ "info": "ℹ️",
306
+ "success": "✅",
307
+ "warning": "⚠️",
308
+ "error": "❌",
309
+ }
310
+ emoji = emoji_map.get(event_type, "ℹ️")
311
+
312
+ time_str = (
313
+ timestamp.strftime("%H:%M:%S")
314
+ if isinstance(timestamp, datetime)
315
+ else str(timestamp)
316
+ )
317
+ current_card.append(Markdown(f"- {emoji} `{time_str}` {message}"))
318
+
319
+ # Server Logs
320
+ logs = data.get("logs", [])
321
+ if logs:
322
+ current_card.append(Markdown("## 📜 Server Logs"))
323
+ # The logs are appended, so they are in chronological order.
324
+ log_content = "\n".join(logs)
325
+ current_card.append(Markdown(f"```\n{log_content}\n```"))
326
+
327
+ current_card.refresh()
328
+
329
+ def on_error(self, current_card, error_message):
330
+ """Handle errors in card rendering"""
331
+ if not self._already_rendered:
332
+ current_card.clear()
333
+ current_card.append(Markdown("# 🚀 `@vllm` Status Dashboard"))
334
+ current_card.append(Markdown(f"## ❌ Error: {str(error_message)}"))
335
+ current_card.refresh()
336
+
337
+ def on_update(self, current_card, data_object):
338
+ """Update the card with new data"""
339
+ with self._lock:
340
+ current_data = self.status_data.copy()
341
+
342
+ if not self._already_rendered:
343
+ self.render_card_fresh(current_card, current_data)
344
+ else:
345
+ # For frequent updates, we could implement incremental updates here
346
+ # For now, just re-render the whole card
347
+ self.render_card_fresh(current_card, current_data)
348
+
349
+ def sqlite_fetch_func(self, conn):
350
+ """Required by CardRefresher (which needs a refactor), but we use in-memory data instead"""
351
+ with self._lock:
352
+ return {"status": self.status_data}
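A short sketch of how these classes are driven by the rest of this diff: VLLMManager pushes events, status updates, and log lines into a VLLMStatusCard instance, and the decorator's monitor thread renders it through on_startup/on_update. The import path follows the RECORD listing at the end of this diff; the values below are illustrative.

from metaflow_extensions.outerbounds.plugins.vllm.status_card import VLLMStatusCard

card = VLLMStatusCard(refresh_interval=10)
card.add_event("info", "Starting vLLM initialization")        # appears under "Recent Events"
card.update_status("server", {"status": "Starting"})          # drives the status emoji
card.update_status("performance", {"install_time": 0.4})      # rendered in the Performance table
card.add_log_line("[@vllm-server-out] Uvicorn running on http://127.0.0.1:8000")

# Inside a task, the decorator's monitor thread renders the card roughly like this:
#   card.on_startup(current.card["vllm_status"])
#   card.on_update(current.card["vllm_status"], None)  # re-renders from card.status_data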
@@ -0,0 +1,471 @@
1
+ import subprocess
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ import time
4
+ import socket
5
+ import sys
6
+ import os
7
+ import functools
8
+ import json
9
+ import requests
10
+ from enum import Enum
11
+ import threading
12
+ from datetime import datetime
13
+
14
+ from .constants import VLLM_SUFFIX
15
+
16
+
17
+ class ProcessStatus:
18
+ RUNNING = "RUNNING"
19
+ FAILED = "FAILED"
20
+ SUCCESSFUL = "SUCCESSFUL"
21
+
22
+
23
+ class VLLMManager:
24
+ """
25
+ A process manager for vLLM runtimes.
26
+ Implements interface @vllm(model=..., ...) to provide a local backend.
27
+ It wraps the vLLM OpenAI-compatible API server to make it easier to profile vLLM use on Outerbounds.
28
+
29
+ NOTE: vLLM's OpenAI-compatible server serves ONE model per server instance.
30
+ If you need multiple models, you must create multiple server instances.
31
+
32
+ Example usage:
33
+ from vllm import LLM
34
+ llm = LLM(model="meta-llama/Llama-3.2-1B")
35
+ llm.generate("Hello, world!")
36
+
37
+ Or via OpenAI-compatible API:
38
+ import openai
39
+ client = openai.OpenAI(
40
+ base_url="http://localhost:8000/v1",
41
+ api_key="token-abc123"
42
+ )
43
+ response = client.chat.completions.create(
44
+ model="meta-llama/Llama-3.2-1B",
45
+ messages=[{"role": "user", "content": "Hello"}]
46
+ )
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ model,
52
+ backend="local",
53
+ debug=False,
54
+ status_card=None,
55
+ port=8000,
56
+ host="127.0.0.1",
57
+ stream_logs_to_card=False,
58
+ **vllm_args,
59
+ ):
60
+ # Validate that only a single model is provided
61
+ if isinstance(model, list):
62
+ if len(model) != 1:
63
+ raise ValueError(
64
+ f"vLLM server can only serve one model per instance. "
65
+ f"Got {len(model)} models: {model}. "
66
+ f"Please specify a single model or create multiple @vllm decorators."
67
+ )
68
+ self.model = model[0]
69
+ else:
70
+ self.model = model
71
+
72
+ self.processes = {}
73
+ self.debug = debug
74
+ self.stream_logs_to_card = stream_logs_to_card
75
+ self.stats = {}
76
+ self.port = port
77
+ self.host = host
78
+ self.vllm_url = f"http://{host}:{port}"
79
+ self.status_card = status_card
80
+ self.initialization_start = time.time()
81
+ self.server_process = None
82
+ self.vllm_args = vllm_args
83
+
84
+ if backend != "local":
85
+ raise ValueError(
86
+ "VLLMManager only supports the 'local' backend at this time."
87
+ )
88
+
89
+ self._log_event("info", "Starting vLLM initialization")
90
+ self._update_server_status("Initializing")
91
+
92
+ self._timeit(self._install_vllm, "install_vllm")
93
+ self._timeit(self._launch_vllm_server, "launch_server")
94
+ self._collect_version_info()
95
+
96
+ total_init_time = time.time() - self.initialization_start
97
+ self._update_performance("total_initialization_time", total_init_time)
98
+ self._log_event(
99
+ "success", f"vLLM initialization completed in {total_init_time:.1f}s"
100
+ )
101
+
102
+ def _log_event(self, event_type, message):
103
+ if self.status_card:
104
+ self.status_card.add_event(event_type, message)
105
+ if self.debug:
106
+ print(f"[@vllm] {event_type.upper()}: {message}")
107
+
108
+ def _update_server_status(self, status, **kwargs):
109
+ if self.status_card:
110
+ update_data = {"status": status}
111
+ update_data.update(kwargs)
112
+ self.status_card.update_status("server", update_data)
113
+
114
+ def _update_model_status(self, model_name, **kwargs):
115
+ if self.status_card:
116
+ current_models = self.status_card.status_data.get("models", {})
117
+ if model_name not in current_models:
118
+ current_models[model_name] = {}
119
+ current_models[model_name].update(kwargs)
120
+ self.status_card.update_status("models", current_models)
121
+
122
+ def _update_performance(self, metric, value):
123
+ if self.status_card:
124
+ self.status_card.update_status("performance", {metric: value})
125
+
126
+ def _timeit(self, f, name):
127
+ t0 = time.time()
128
+ f()
129
+ tf = time.time()
130
+ duration = tf - t0
131
+ self.stats[name] = {"process_runtime": duration}
132
+
133
+ if name == "install_vllm":
134
+ self._update_performance("install_time", duration)
135
+ elif name == "launch_server":
136
+ self._update_performance("server_startup_time", duration)
137
+
138
+ def _stream_output(self, stream, prefix):
139
+ """Reads and logs output from a stream."""
140
+ for line in iter(stream.readline, ""):
141
+ if line:
142
+ line = line.strip()
143
+ if self.stream_logs_to_card and self.status_card:
144
+ self.status_card.add_log_line(f"[{prefix}] {line}")
145
+ elif self.debug:
146
+ print(f"[{prefix}] {line}")
147
+ stream.close()
148
+
149
+ def _is_port_open(self, host, port, timeout=1):
150
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
151
+ sock.settimeout(timeout)
152
+ try:
153
+ sock.connect((host, port))
154
+ return True
155
+ except socket.error:
156
+ return False
157
+
158
+ def _install_vllm(self):
159
+ self._log_event("info", "Checking for existing vLLM installation")
160
+ try:
161
+ import vllm
162
+
163
+ self._log_event("success", f"vLLM {vllm.__version__} is already installed")
164
+ if self.debug:
165
+ print(f"[@vllm] vLLM {vllm.__version__} is already installed.")
166
+ return
167
+ except ImportError as e:
168
+ self._log_event(
169
+ "Error", "vLLM not installed. Please add it to your environment."
170
+ )
171
+ if self.debug:
172
+ print(
173
+ "[@vllm] vLLM not found. The user is responsible for installation."
174
+ )
175
+ raise e
176
+ # We are not installing it automatically to respect user's environment management.
177
+
178
+ def _launch_vllm_server(self):
179
+ self._update_server_status("Starting")
180
+ self._log_event("info", f"Starting vLLM server with model: {self.model}")
181
+
182
+ # Check if the model is cached
183
+ hf_home = os.environ.get("HF_HOME")
184
+ if hf_home:
185
+ # Construct the expected cache path for the model
186
+ model_path_id = f"models--{self.model.replace('/', '--')}"
187
+ model_cache_path = os.path.join(hf_home, model_path_id)
188
+ if os.path.exists(model_cache_path):
189
+ self._log_event("info", f"Found cached model at: {model_cache_path}")
190
+ self._update_model_status(
191
+ self.model, status="Found in cache", location=model_cache_path
192
+ )
193
+ else:
194
+ self._log_event(
195
+ "warning",
196
+ f"Cached model not found at {model_cache_path}. vLLM will attempt to download it.",
197
+ )
198
+ self._update_model_status(self.model, status="Downloading")
199
+ else:
200
+ self._log_event(
201
+ "warning",
202
+ "HF_HOME environment variable not set. vLLM will use default cache location and may re-download.",
203
+ )
204
+
205
+ if not self.model:
206
+ raise ValueError("At least one model must be specified for @vllm.")
207
+
208
+ try:
209
+ if self.debug:
210
+ print(
211
+ f"[@vllm] Starting vLLM OpenAI-compatible server for model: {self.model}"
212
+ )
213
+
214
+ cmd = [
215
+ sys.executable,
216
+ "-m",
217
+ "vllm.entrypoints.openai.api_server",
218
+ "--model",
219
+ self.model,
220
+ "--host",
221
+ self.host,
222
+ "--port",
223
+ str(self.port),
224
+ ]
225
+
226
+ vllm_args_copy = self.vllm_args.copy()
227
+ if self.debug or self.stream_logs_to_card:
228
+ # Note: This is an undocumented argument for the vLLM OpenAI server entrypoint.
229
+ vllm_args_copy.setdefault("uvicorn_log_level", "debug")
230
+
231
+ for key, value in vllm_args_copy.items():
232
+ arg_name = f"--{key.replace('_', '-')}"
233
+ if isinstance(value, bool):
234
+ if value:
235
+ cmd.append(arg_name)
236
+ elif value is not None:
237
+ cmd.append(arg_name)
238
+ cmd.append(str(value))
239
+
240
+ # For debugging, log the exact command being run to the status card
241
+ command_str = " ".join(cmd)
242
+ self._log_event("info", f"Launch Command: `{command_str}`")
243
+ if self.debug:
244
+ print(f"[@vllm] Launching vLLM with command: {command_str}")
245
+
246
+ process = subprocess.Popen(
247
+ cmd,
248
+ stdout=subprocess.PIPE,
249
+ stderr=subprocess.PIPE,
250
+ text=True,
251
+ bufsize=1, # Line-buffered
252
+ )
253
+
254
+ # Threads to stream subprocess output
255
+ if self.debug or self.stream_logs_to_card:
256
+ stdout_thread = threading.Thread(
257
+ target=self._stream_output,
258
+ args=(process.stdout, "@vllm-server-out"),
259
+ )
260
+ stderr_thread = threading.Thread(
261
+ target=self._stream_output,
262
+ args=(process.stderr, "@vllm-server-err"),
263
+ )
264
+ stdout_thread.daemon = True
265
+ stderr_thread.daemon = True
266
+ stdout_thread.start()
267
+ stderr_thread.start()
268
+
269
+ self.server_process = process
270
+ self.processes[process.pid] = {
271
+ "p": process,
272
+ "properties": {
273
+ "type": "vllm-server",
274
+ "model": self.model,
275
+ "error_details": None,
276
+ },
277
+ "status": ProcessStatus.RUNNING,
278
+ }
279
+
280
+ if self.debug:
281
+ print(f"[@vllm] Started vLLM server process with PID {process.pid}")
282
+
283
+ retries = 0
284
+ max_retries = 240
285
+ while (
286
+ not self._is_port_open(self.host, self.port, timeout=2)
287
+ and retries < max_retries
288
+ ):
289
+ if retries == 0:
290
+ print("[@vllm] Waiting for server to be ready...")
291
+ elif retries % 10 == 0:
292
+ print(
293
+ f"[@vllm] Still waiting for server... ({retries}/{max_retries})"
294
+ )
295
+
296
+ returncode = process.poll()
297
+ if returncode is not None:
298
+ if self.debug or self.stream_logs_to_card:
299
+ # Threads are handling output, can't use communicate.
300
+ # The error has already been printed to the log by the thread.
301
+ if self.stream_logs_to_card:
302
+ details_msg = "See card for logs."
303
+ else:
304
+ details_msg = "See logs from @vllm-server-err for details."
305
+ error_details = f"Return code: {returncode}. {details_msg}"
306
+ else:
307
+ # No threads, so we can and should use communicate to get stderr.
308
+ stdout, stderr = process.communicate()
309
+ error_details = f"Return code: {returncode}, stderr: {stderr}"
310
+
311
+ self.processes[process.pid]["properties"][
312
+ "error_details"
313
+ ] = error_details
314
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
315
+ self._update_server_status("Failed", error_details=error_details)
316
+ self._log_event(
317
+ "error", f"vLLM server failed to start: {error_details}"
318
+ )
319
+ raise RuntimeError(f"vLLM server failed to start: {error_details}")
320
+
321
+ time.sleep(2)
322
+ retries += 1
323
+
324
+ if not self._is_port_open(self.host, self.port, timeout=2):
325
+ error_details = f"vLLM server did not start listening on {self.host}:{self.port} after {max_retries*2}s"
326
+ self.processes[process.pid]["properties"][
327
+ "error_details"
328
+ ] = error_details
329
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
330
+ self._update_server_status("Failed", error_details=error_details)
331
+ self._log_event("error", f"Server startup timeout: {error_details}")
332
+ raise RuntimeError(f"vLLM server failed to start: {error_details}")
333
+
334
+ if not self._verify_server_health():
335
+ error_details = "vLLM server started but failed health check"
336
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
337
+ self._update_server_status("Failed", error_details=error_details)
338
+ self._log_event("error", error_details)
339
+ raise RuntimeError(error_details)
340
+
341
+ self._update_server_status(
342
+ "Running", uptime_start=datetime.now(), model=self.model
343
+ )
344
+ self._log_event("success", "vLLM server is ready and listening")
345
+
346
+ self._update_model_status(self.model, status="Ready")
347
+
348
+ if self.debug:
349
+ print("[@vllm] Server is ready.")
350
+
351
+ except Exception as e:
352
+ if process and process.pid in self.processes:
353
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
354
+ self.processes[process.pid]["properties"]["error_details"] = str(e)
355
+ self._update_server_status("Failed", error_details=str(e))
356
+ self._log_event("error", f"Error starting vLLM server: {str(e)}")
357
+ raise RuntimeError(f"Error starting vLLM server: {e}") from e
358
+
359
+ def _verify_server_health(self):
360
+ try:
361
+ response = requests.get(f"{self.vllm_url}/v1/models", timeout=10)
362
+ if response.status_code == 200:
363
+ if self.debug:
364
+ models_data = response.json()
365
+ available_models = [
366
+ m.get("id", "unknown") for m in models_data.get("data", [])
367
+ ]
368
+ print(
369
+ f"[@vllm] Health check OK. Available models: {available_models}"
370
+ )
371
+ return True
372
+ else:
373
+ if self.debug:
374
+ print(
375
+ f"[@vllm] Health check failed with status {response.status_code}"
376
+ )
377
+ return False
378
+ except Exception as e:
379
+ if self.debug:
380
+ print(f"[@vllm] Health check exception: {e}")
381
+ return False
382
+
383
+ def _collect_version_info(self):
384
+ version_info = {}
385
+ try:
386
+ import vllm
387
+
388
+ version_info["vllm"] = getattr(vllm, "__version__", "Unknown")
389
+ except ImportError:
390
+ version_info["vllm"] = "Not installed"
391
+ except Exception as e:
392
+ version_info["vllm"] = "Error detecting"
393
+ if self.debug:
394
+ print(f"[@vllm] Error getting vLLM version: {e}")
395
+
396
+ if self.status_card:
397
+ self.status_card.update_status("versions", version_info)
398
+ self._log_event(
399
+ "info", f"vLLM version: {version_info.get('vllm', 'Unknown')}"
400
+ )
401
+
402
+ def terminate_models(self):
403
+ shutdown_start_time = time.time()
404
+ self._log_event("info", "Starting vLLM shutdown sequence")
405
+ if self.debug:
406
+ print("[@vllm] Shutting down vLLM server...")
407
+
408
+ server_shutdown_cause = "graceful"
409
+
410
+ if self.server_process:
411
+ try:
412
+ self._update_server_status("Stopping")
413
+ self._log_event("info", "Stopping vLLM server")
414
+
415
+ # Clear model status since server is shutting down
416
+ self._update_model_status(self.model, status="Stopping")
417
+
418
+ self.server_process.terminate()
419
+ try:
420
+ self.server_process.wait(timeout=10)
421
+ if self.debug:
422
+ print("[@vllm] Server terminated gracefully")
423
+ except subprocess.TimeoutExpired:
424
+ server_shutdown_cause = "force_kill"
425
+ self._log_event(
426
+ "warning",
427
+ "vLLM server did not terminate gracefully, killing...",
428
+ )
429
+ if self.debug:
430
+ print("[@vllm] Server did not terminate, killing...")
431
+ self.server_process.kill()
432
+ self.server_process.wait()
433
+
434
+ if self.server_process.pid in self.processes:
435
+ self.processes[self.server_process.pid][
436
+ "status"
437
+ ] = ProcessStatus.SUCCESSFUL
438
+
439
+ self._update_server_status("Stopped")
440
+ if self.status_card:
441
+ self.status_card.update_status("models", {})
442
+
443
+ self._log_event(
444
+ "success", f"vLLM server stopped ({server_shutdown_cause})"
445
+ )
446
+
447
+ except Exception as e:
448
+ server_shutdown_cause = "failed"
449
+ if self.server_process.pid in self.processes:
450
+ self.processes[self.server_process.pid][
451
+ "status"
452
+ ] = ProcessStatus.FAILED
453
+ self.processes[self.server_process.pid]["properties"][
454
+ "error_details"
455
+ ] = str(e)
456
+ self._update_server_status("Failed to stop")
457
+ if self.status_card:
458
+ self.status_card.update_status("models", {})
459
+ self._log_event("error", f"vLLM server shutdown error: {str(e)}")
460
+ if self.debug:
461
+ print(f"[@vllm] Warning: Error terminating vLLM server: {e}")
462
+
463
+ total_shutdown_time = time.time() - shutdown_start_time
464
+ self._update_performance("total_shutdown_time", total_shutdown_time)
465
+ self._update_performance("shutdown_cause", server_shutdown_cause)
466
+
467
+ self._log_event(
468
+ "success", f"vLLM shutdown completed in {total_shutdown_time:.1f}s"
469
+ )
470
+ if self.debug:
471
+ print("[@vllm] vLLM server shutdown complete.")
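For reference, user code that needs to confirm the sidecar is reachable (for example, from a process other than the step itself) can reuse the same two checks the manager performs: a TCP probe of the port and a GET on the OpenAI-compatible /v1/models endpoint. A minimal sketch, assuming the default 127.0.0.1:8000 address used above; note that the decorator itself already blocks until these checks pass before running the user step.

import socket
import time

import requests

def vllm_server_ready(host="127.0.0.1", port=8000):
    # TCP probe, mirroring VLLMManager._is_port_open.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1)
        try:
            sock.connect((host, port))
        except socket.error:
            return False
    # Health check, mirroring VLLMManager._verify_server_health: the
    # OpenAI-compatible server lists its single served model under /v1/models.
    try:
        resp = requests.get(f"http://{host}:{port}/v1/models", timeout=10)
        return resp.status_code == 200
    except requests.RequestException:
        return False

for _ in range(30):                # poll for up to ~1 minute
    if vllm_server_ready():
        print("vLLM server is up")
        break
    time.sleep(2)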
@@ -0,0 +1 @@
1
+ __mf_promote_submodules__ = ["plugins.vllm"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ob-metaflow-extensions
3
- Version: 1.1.166rc6
3
+ Version: 1.1.168rc0
4
4
  Summary: Outerbounds Platform Extensions for Metaflow
5
5
  Author: Outerbounds, Inc.
6
6
  License: Commercial
@@ -1,7 +1,7 @@
1
1
  metaflow_extensions/outerbounds/__init__.py,sha256=Gb8u06s9ClQsA_vzxmkCzuMnigPy7kKcDnLfb7eB-64,514
2
2
  metaflow_extensions/outerbounds/remote_config.py,sha256=pEFJuKDYs98eoB_-ryPjVi9b_c4gpHMdBHE14ltoxIU,4672
3
3
  metaflow_extensions/outerbounds/config/__init__.py,sha256=JsQGRuGFz28fQWjUvxUgR8EKBLGRdLUIk_buPLJplJY,1225
4
- metaflow_extensions/outerbounds/plugins/__init__.py,sha256=HB298DqOlmM96-j_FjloDSgt1TZO9I-K2eJ7kevpHVQ,14092
4
+ metaflow_extensions/outerbounds/plugins/__init__.py,sha256=8NoHkOVUbd7YWK1S7GjAe7FzB_dTj-uPIZfla7-t-jE,14129
5
5
  metaflow_extensions/outerbounds/plugins/auth_server.py,sha256=_Q9_2EL0Xy77bCRphkwT1aSu8gQXRDOH-Z-RxTUO8N4,2202
6
6
  metaflow_extensions/outerbounds/plugins/perimeters.py,sha256=QXh3SFP7GQbS-RAIxUOPbhPzQ7KDFVxZkTdKqFKgXjI,2697
7
7
  metaflow_extensions/outerbounds/plugins/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -69,6 +69,11 @@ metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py,sha256=aQphxX6j
69
69
  metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py,sha256=AI_kcm1hZV3JRxJkookcH6twiGnAYjk9Dx-MeoYz60Y,8511
70
70
  metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py,sha256=9lUM4Cqi5RjrHBRfG6AQMRz8-R96eZC8Ih0KD2lv22Y,1858
71
71
  metaflow_extensions/outerbounds/plugins/torchtune/__init__.py,sha256=TOXNeyhcgd8VxplXO_oEuryFEsbk0tikn5GL0-44SU8,5853
72
+ metaflow_extensions/outerbounds/plugins/vllm/__init__.py,sha256=O04DPVoEdCZhPbvdldaE4ztoAxJNXU-ExosBCqe43v8,6463
73
+ metaflow_extensions/outerbounds/plugins/vllm/constants.py,sha256=ODX_uM5iYrzpVltsAdSf9Jo0DAOMiZ3647DcKdCnlS0,24
74
+ metaflow_extensions/outerbounds/plugins/vllm/exceptions.py,sha256=8m65k2L17zXgSkgU299DWqxr1wGUMsZgSJw0hBRizJ0,49
75
+ metaflow_extensions/outerbounds/plugins/vllm/status_card.py,sha256=ofTuBkhZgJq8cs-KwUOc85TB4hP9FT5qxagVEomY8jA,12931
76
+ metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py,sha256=E6SCBoanNNsFjm-IGGGf4VMVe1mueDiRWw0ZvCXizDQ,18496
72
77
  metaflow_extensions/outerbounds/profilers/__init__.py,sha256=wa_jhnCBr82TBxoS0e8b6_6sLyZX0fdHicuGJZNTqKw,29
73
78
  metaflow_extensions/outerbounds/profilers/gpu.py,sha256=3Er8uKQzfm_082uadg4yn_D4Y-iSCgzUfFmguYxZsz4,27485
74
79
  metaflow_extensions/outerbounds/toplevel/__init__.py,sha256=qWUJSv_r5hXJ7jV_On4nEasKIfUCm6_UjkjXWA_A1Ts,90
@@ -80,7 +85,8 @@ metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py,sha256=5
80
85
  metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py,sha256=GRSz2zwqkvlmFS6bcfYD_CX6CMko9DHQokMaH1iBshA,47
81
86
  metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py,sha256=LptpH-ziXHrednMYUjIaosS1SXD3sOtF_9_eRqd8SJw,50
82
87
  metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py,sha256=uTVkdSk3xZ7hEKYfdlyVteWj5KeDwaM1hU9WT-_YKfI,50
83
- ob_metaflow_extensions-1.1.166rc6.dist-info/METADATA,sha256=cuMhU0M1sJrUq2KaZu7BeeugszGQbe8bwEACJj4LytE,524
84
- ob_metaflow_extensions-1.1.166rc6.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
85
- ob_metaflow_extensions-1.1.166rc6.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
86
- ob_metaflow_extensions-1.1.166rc6.dist-info/RECORD,,
88
+ metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py,sha256=ekcgD3KVydf-a0xMI60P4uy6ePkSEoFHiGnDq1JM940,45
89
+ ob_metaflow_extensions-1.1.168rc0.dist-info/METADATA,sha256=X0qE3DgcUJ_xUa-sJRHRysv3NiAQiUHx2vCBiL8TlQM,524
90
+ ob_metaflow_extensions-1.1.168rc0.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
91
+ ob_metaflow_extensions-1.1.168rc0.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
92
+ ob_metaflow_extensions-1.1.168rc0.dist-info/RECORD,,