ob-metaflow-extensions 1.1.166rc5__py2.py3-none-any.whl → 1.1.167__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

@@ -331,7 +331,6 @@ CLIS_DESC = [
331
331
  ("nvct", ".nvct.nvct_cli.cli"),
332
332
  ("fast-bakery", ".fast_bakery.fast_bakery_cli.cli"),
333
333
  ("snowpark", ".snowpark.snowpark_cli.cli"),
334
- ("app", ".apps.app_cli.cli"),
335
334
  ]
336
335
  STEP_DECORATORS_DESC = [
337
336
  ("nvidia", ".nvcf.nvcf_decorator.NvcfDecorator"),
@@ -345,6 +344,7 @@ STEP_DECORATORS_DESC = [
345
344
  ("gpu_profile", ".profilers.gpu_profile_decorator.GPUProfileDecorator"),
346
345
  ("nim", ".nim.nim_decorator.NimDecorator"),
347
346
  ("ollama", ".ollama.OllamaDecorator"),
347
+ ("vllm", ".vllm.VLLMDecorator"),
348
348
  ("app_deploy", ".apps.deploy_decorator.WorkstationAppDeployDecorator"),
349
349
  ]
350
350
 
@@ -0,0 +1,177 @@
1
+ from metaflow.decorators import StepDecorator
2
+ from metaflow import current
3
+ import functools
4
+ import os
5
+ import threading
6
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
7
+ from metaflow.metaflow_config import from_conf
8
+
9
+ from .vllm_manager import VLLMManager
10
+ from .status_card import VLLMStatusCard, CardDecoratorInjector
11
+
12
# Submodule names this module exposes for promotion; presumably consumed by
# Metaflow's extension loader (NOTE(review): confirm against the loader).
__mf_promote_submodules__ = [
    "plugins.vllm",
]
13
+
14
+
15
class VLLMDecorator(StepDecorator, CardDecoratorInjector):
    """
    This decorator is used to run vllm APIs as Metaflow task sidecars.

    User code call
    --------------
    @vllm(
        model="...",
        ...
    )

    Valid backend options
    ---------------------
    - 'local': Run as a separate process on the local task machine.

    Valid model options
    -------------------
    Any HuggingFace model identifier, e.g. 'meta-llama/Llama-3.2-1B'

    NOTE: vLLM's OpenAI-compatible server serves ONE model per server instance.
    If you need multiple models, you must create multiple @vllm decorators.

    Parameters
    ----------
    model: str
        HuggingFace model identifier to be served by vLLM.
    backend: str
        Determines where and how to run the vLLM process.
    debug: bool
        Whether to turn on verbose debugging logs.
    stream_logs_to_card: bool
        Whether to copy the vLLM server's output into the status card.
    card_refresh_interval: int
        Seconds between refreshes of the `vllm_status` card.
    engine_args: dict
        Extra options forwarded to the vLLM server. Each key becomes a
        `--key-name` command-line flag, e.g.
        `engine_args={"tensor_parallel_size": 2}`.
    """

    name = "vllm"
    defaults = {
        "model": None,
        "backend": "local",
        "debug": False,
        "stream_logs_to_card": False,
        "card_refresh_interval": 10,
        "engine_args": {},
    }

    def step_init(
        self, flow, graph, step_name, decorators, environment, flow_datastore, logger
    ):
        """Validate the decorator attributes and attach the `vllm_status` card."""
        super().step_init(
            flow, graph, step_name, decorators, environment, flow_datastore, logger
        )

        # Validate that a model is specified
        if not self.attributes["model"]:
            raise ValueError(
                f"@vllm decorator on step '{step_name}' requires a 'model' parameter. "
                f"Example: @vllm(model='meta-llama/Llama-3.2-1B')"
            )

        # Attach the vllm status card
        self.attach_card_decorator(
            flow,
            step_name,
            "vllm_status",
            "blank",
            refresh_interval=self.attributes["card_refresh_interval"],
        )

    def task_decorate(
        self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
    ):
        """Wrap `step_func` so a vLLM server runs for the duration of the step.

        The wrapper starts the card monitor thread and the vLLM server, then
        runs the user's step. Afterwards it always shuts both down and renders
        the card a final time, including when startup or the step failed.
        """

        @functools.wraps(step_func)
        def vllm_wrapper():
            self.vllm_manager = None
            self.status_card = None
            self.card_monitor_thread = None
            # An Event, rather than a flag attribute on the Thread, lets the
            # monitor wake from its wait as soon as shutdown begins.
            stop_event = threading.Event()

            try:
                self._start_vllm(stop_event)
                if self.status_card:
                    self.status_card.add_event("info", "Starting user step function")
                step_func()
                if self.status_card:
                    self.status_card.add_event(
                        "success", "User step function completed successfully"
                    )
            finally:
                # Also runs when startup fails, so the monitor thread never
                # outlives the step.
                self._stop_vllm(stop_event)

        return vllm_wrapper

    def _start_vllm(self, stop_event):
        """Create the status card, start its monitor thread, then start vLLM.

        Any initialization error is recorded as a card event, printed, and
        re-raised.
        """
        try:
            self.status_card = VLLMStatusCard(
                refresh_interval=self.attributes["card_refresh_interval"]
            )
            self.card_monitor_thread = threading.Thread(
                target=self._monitor_card, args=(stop_event,), daemon=True
            )
            self.card_monitor_thread.start()
            self.vllm_manager = VLLMManager(
                model=self.attributes["model"],
                backend=self.attributes["backend"],
                debug=self.attributes["debug"],
                status_card=self.status_card,
                stream_logs_to_card=self.attributes["stream_logs_to_card"],
                **self.attributes["engine_args"],
            )
            if self.attributes["debug"]:
                print("[@vllm] VLLMManager initialized.")
        except Exception as e:
            if self.status_card:
                # Shown by the final card render performed in _stop_vllm.
                self.status_card.add_event(
                    "error", f"Initialization failed: {str(e)}"
                )
            print(f"[@vllm] Error initializing VLLMManager: {e}")
            raise

    def _monitor_card(self, stop_event):
        """Monitor thread body: re-render the card every refresh interval.

        Exits when `stop_event` is set or after the first rendering error.
        """
        try:
            self.status_card.on_startup(current.card["vllm_status"])
            while not stop_event.is_set():
                try:
                    self.status_card.on_update(current.card["vllm_status"], None)
                except Exception as e:
                    if self.attributes["debug"]:
                        print(f"[@vllm] Card monitoring error: {e}")
                    break
                # Returns early if shutdown is requested during the wait.
                stop_event.wait(self.attributes["card_refresh_interval"])
        except Exception as e:
            if self.attributes["debug"]:
                print(f"[@vllm] Card monitor thread error: {e}")
            try:
                self.status_card.on_error(current.card["vllm_status"], str(e))
            except Exception as card_error:
                # Best effort: the card itself may be what failed.
                if self.attributes["debug"]:
                    print(f"[@vllm] Card error rendering failed: {card_error}")

    def _stop_vllm(self, stop_event):
        """Terminate the vLLM server, stop the monitor, and render the card once more."""
        import time

        try:
            if self.vllm_manager:
                self.vllm_manager.terminate_models()
        finally:
            # Stop the monitor before the final render so the two never
            # render the card concurrently.
            stop_event.set()
            if self.card_monitor_thread:
                self.card_monitor_thread.join(timeout=5)
                if self.attributes["debug"]:
                    print("[@vllm] Card monitoring thread stopped.")

            if self.card_monitor_thread and self.status_card:
                try:
                    self.status_card.on_update(current.card["vllm_status"], None)
                except Exception as e:
                    if self.attributes["debug"]:
                        print(f"[@vllm] Final card update error: {e}")
                # Pause kept from the original implementation; presumably it
                # gives the card time to publish its last refresh.
                time.sleep(2)
@@ -0,0 +1 @@
1
# Suffix string for vLLM-related identifiers (its consumers are not visible in
# this module).
VLLM_SUFFIX: str = "mf.vllm"
@@ -0,0 +1 @@
1
+ from metaflow.exception import MetaflowException
@@ -0,0 +1,352 @@
1
+ from metaflow.cards import Markdown, Table, VegaChart
2
+ from metaflow.metaflow_current import current
3
+ from datetime import datetime
4
+ import threading
5
+ import time
6
+
7
+
8
+ from metaflow.exception import MetaflowException
9
+ from collections import defaultdict
10
+
11
+
12
class CardDecoratorInjector:
    """
    Mixin for injecting `@card` decorators from other first-class Metaflow decorators.

    A StepDecorator that inherits this mixin calls `attach_card_decorator`
    from its `step_init`. The call adds a `@card` with the requested id and
    type to the step, unless the step already has a card with that id.
    """

    # step_name -> {card_id -> bool}. This is a class attribute, so the cache
    # is shared by every user of the mixin. True means the card was injected;
    # False means the user had already declared it.
    _first_time_init = defaultdict(dict)

    @classmethod
    def _get_first_time_init_cached_value(cls, step_name, card_id):
        """Return the cached decision for `(step_name, card_id)`, or None if not decided yet."""
        return cls._first_time_init.get(step_name, {}).get(card_id, None)

    @classmethod
    def _set_first_time_init_cached_value(cls, step_name, card_id, value):
        """Cache the injection decision for `(step_name, card_id)`."""
        cls._first_time_init[step_name][card_id] = value

    def _card_deco_already_attached(self, step, card_id):
        """Return True if `step` already carries a `@card` whose id is `card_id`."""
        for decorator in step.decorators:
            if decorator.name != "card":
                continue
            existing_id = decorator.attributes.get("id")
            if existing_id and existing_id == card_id:
                return True
        return False

    def _get_step(self, flow, step_name):
        """Return the step of `flow` named `step_name`, or None if there is none."""
        return next((step for step in flow if step.name == step_name), None)

    def _first_time_init_check(self, step_dag_node, card_id):
        """Return True if a card with `card_id` still has to be attached to the step."""
        return not self._card_deco_already_attached(step_dag_node, card_id)

    def attach_card_decorator(
        self,
        flow,
        step_name,
        card_id,
        card_type,
        refresh_interval=5,
    ):
        """
        Attach `@card(type=card_type, id=card_id, refresh_interval=...)` to a step.

        Call this from the `step_init` of a StepDecorator that uses this class
        as a mixin. The decision is cached per `(step_name, card_id)`, so
        repeated calls never attach duplicates. Nothing is attached when the
        user already declared a card with the same id.

        Raises
        ------
        MetaflowException
            If `card_id` or `card_type` is empty, or if `step_name` is not a
            step of `flow`.
        """
        from metaflow import decorators as _decorators

        if not all([card_id, card_type]):
            raise MetaflowException(
                "`card_id` and `card_type` must both be provided to "
                "`CardDecoratorInjector.attach_card_decorator`"
            )

        # Already decided for this step/card (class-level cache): nothing to do.
        if self._get_first_time_init_cached_value(step_name, card_id) is not None:
            return

        step_dag_node = self._get_step(flow, step_name)
        if step_dag_node is None:
            raise MetaflowException(
                "Cannot attach card '%s': step '%s' was not found in the flow"
                % (card_id, step_name)
            )

        if self._first_time_init_check(step_dag_node, card_id):
            self._set_first_time_init_cached_value(step_name, card_id, True)
            _decorators._attach_decorators_to_step(
                step_dag_node,
                [
                    "card:type=%s,id=%s,refresh_interval=%s"
                    % (card_type, card_id, str(refresh_interval))
                ],
            )
        else:
            self._set_first_time_init_cached_value(step_name, card_id, False)
78
+
79
+
80
class CardRefresher:
    """
    Interface for objects that render and periodically refresh a Metaflow card.

    Subclasses set `CARD_ID` and implement every hook below. Each hook raises
    NotImplementedError naming itself.
    """

    # Id of the card this refresher renders; subclasses override it.
    CARD_ID = None

    def on_startup(self, current_card):
        """Render the initial card content."""
        raise NotImplementedError("on_startup method must be implemented")

    def on_error(self, current_card, error_message):
        """Render `error_message` on the card."""
        raise NotImplementedError("on_error method must be implemented")

    def on_update(self, current_card, data_object):
        """Re-render the card with fresh data."""
        raise NotImplementedError("on_update method must be implemented")

    def sqlite_fetch_func(self, conn):
        """Fetch the data object used for updates from `conn`."""
        raise NotImplementedError("sqlite_fetch_func must be implemented")
95
+
96
+
97
class VLLMStatusCard(CardRefresher):
    """
    Real-time status card for vLLM system monitoring.
    Shows server health, model status, performance metrics, recent events and,
    optionally, server logs.

    Instances are meant to be composed, not inherited. A step decorator
    creates one and passes it to the vLLM manager, which records state through
    `update_status`, `add_event` and `add_log_line`. The decorator then
    periodically calls `on_update` with the card object to render into.
    """

    CARD_ID = "vllm_status"

    # Emoji shown next to each event type in the "Recent Events" section.
    _EVENT_EMOJI = {
        "info": "ℹ️",
        "success": "✅",
        "warning": "⚠️",
        "error": "❌",
    }

    def __init__(self, refresh_interval=10):
        self.refresh_interval = refresh_interval
        self.status_data = {
            "server": {
                "status": "Starting",
                "uptime_start": None,
                "last_health_check": None,
                "health_status": "Unknown",
                "models": [],
            },
            "models": {},  # model_name -> {status, load_time, etc}
            "performance": {
                "install_time": None,
                "server_startup_time": None,
                "total_initialization_time": None,
            },
            "versions": {
                "vllm": "Detecting...",
            },
            "events": [],  # Recent events log
            "logs": [],
        }
        # Guards status_data; writers may run on other threads than the renderer.
        self._lock = threading.Lock()
        self._already_rendered = False

    def update_status(self, category, data):
        """Merge `data` into `status_data[category]` under the lock.

        Unknown categories are ignored.
        """
        with self._lock:
            if category in self.status_data:
                self.status_data[category].update(data)

    def add_log_line(self, log_line):
        """Append a log line, keeping only the most recent 20."""
        with self._lock:
            self.status_data["logs"].append(log_line)
            # Keep only last 20 lines
            self.status_data["logs"] = self.status_data["logs"][-20:]

    def add_event(self, event_type, message, timestamp=None):
        """Prepend an event (newest first), keeping only the 10 most recent."""
        if timestamp is None:
            timestamp = datetime.now()

        with self._lock:
            self.status_data["events"].insert(
                0,
                {
                    "type": event_type,  # 'info', 'warning', 'error', 'success'
                    "message": message,
                    "timestamp": timestamp,
                },
            )
            # Keep only last 10 events
            self.status_data["events"] = self.status_data["events"][:10]

    def get_circuit_breaker_emoji(self, state):
        """Get status emoji for circuit breaker state"""
        emoji_map = {"CLOSED": "🟢", "OPEN": "🔴", "HALF_OPEN": "🟡"}
        return emoji_map.get(state, "⚪")

    def get_uptime_string(self, start_time):
        """Format the time elapsed since `start_time` as e.g. '1h 2m 3s'."""
        if not start_time:
            return "Not started"

        uptime = datetime.now() - start_time
        hours, remainder = divmod(int(uptime.total_seconds()), 3600)
        minutes, seconds = divmod(remainder, 60)

        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        elif minutes > 0:
            return f"{minutes}m {seconds}s"
        else:
            return f"{seconds}s"

    def on_startup(self, current_card):
        """Initialize the card when monitoring starts"""
        current_card.append(Markdown("# 🚀 `@vllm` Status Dashboard"))
        current_card.append(Markdown("_Initializing vLLM system..._"))
        current_card.refresh()

    def render_card_fresh(self, current_card, data):
        """Clear the card and render every section from the `data` snapshot."""
        self._already_rendered = True
        current_card.clear()

        current_card.append(Markdown("# 🚀 `@vllm` Status Dashboard"))

        versions = data.get("versions", {})
        vllm_version = versions.get("vllm", "Unknown")
        current_card.append(Markdown(f"**vLLM Version:** `{vllm_version}`"))

        current_card.append(
            Markdown(f"_Last updated: {datetime.now().strftime('%H:%M:%S')}_")
        )

        server_data = data["server"]
        uptime = self.get_uptime_string(server_data.get("uptime_start"))
        server_status = server_data.get("status", "Unknown")
        model = server_data.get("model", "Unknown")

        # Determine status emoji
        if server_status == "Running":
            status_emoji = "🟢"
            model_emoji = "✅"
        elif server_status == "Failed":
            status_emoji = "🔴"
            model_emoji = "❌"
        elif server_status == "Starting":
            status_emoji = "🟡"
            model_emoji = "⏳"
        else:  # Stopped, etc.
            status_emoji = "⚫"
            model_emoji = "⏹️"

        # Main status section
        current_card.append(
            Markdown(f"## {status_emoji} Server Status: {server_status}")
        )

        if server_status == "Running" and uptime:
            current_card.append(Markdown(f"**Uptime:** {uptime}"))

        # Model information - only show detailed status if server is running
        if server_status == "Running":
            current_card.append(Markdown(f"## {model_emoji} Model: `{model}`"))

            # Show model-specific status if available
            models_data = data.get("models", {})
            if models_data and model in models_data:
                model_info = models_data[model]
                model_status = model_info.get("status", "Unknown")
                load_time = model_info.get("load_time")
                location = model_info.get("location")

                current_card.append(Markdown(f"**Status:** {model_status}"))
                if location:
                    current_card.append(Markdown(f"**Location:** `{location}`"))
                if load_time and isinstance(load_time, (int, float)):
                    current_card.append(Markdown(f"**Load Time:** {load_time:.1f}s"))
        elif model != "Unknown":
            current_card.append(
                Markdown(f"## {model_emoji} Model: `{model}` (Server Stopped)")
            )

        # Simplified monitoring note
        current_card.append(
            Markdown(
                "## 🔧 Monitoring\n**Advanced Features:** Disabled (Circuit Breaker, Request Interception)"
            )
        )

        # Performance metrics
        perf_data = data["performance"]
        if any(v is not None for v in perf_data.values()):
            current_card.append(Markdown("## ⚡ Performance"))

            init_metrics = []
            shutdown_metrics = []

            for metric, value in perf_data.items():
                if value is not None:
                    display_value = (
                        f"{value:.1f}s" if isinstance(value, (int, float)) else value
                    )
                    metric_display = metric.replace("_", " ").title()

                    if "shutdown" in metric.lower():
                        shutdown_metrics.append([metric_display, display_value])
                    elif metric in [
                        "install_time",
                        "server_startup_time",
                        "total_initialization_time",
                    ]:
                        init_metrics.append([metric_display, display_value])

            if init_metrics:
                current_card.append(Markdown("### Initialization"))
                current_card.append(Table(init_metrics, headers=["Metric", "Duration"]))

            if shutdown_metrics:
                current_card.append(Markdown("### Shutdown"))
                current_card.append(
                    Table(shutdown_metrics, headers=["Metric", "Value"])
                )

        # Recent events (stored newest first)
        events = data.get("events", [])
        if events:
            current_card.append(Markdown("## 📝 Recent Events"))
            for event in events[:5]:  # Show last 5 events
                event_type = event.get("type", "info")
                message = event.get("message", "")
                timestamp = event.get("timestamp", datetime.now())

                emoji = self._EVENT_EMOJI.get(event_type, "ℹ️")

                time_str = (
                    timestamp.strftime("%H:%M:%S")
                    if isinstance(timestamp, datetime)
                    else str(timestamp)
                )
                current_card.append(Markdown(f"- {emoji} `{time_str}` {message}"))

        # Server Logs
        logs = data.get("logs", [])
        if logs:
            current_card.append(Markdown("## 📜 Server Logs"))
            # The logs are appended, so they are in chronological order.
            log_content = "\n".join(logs)
            current_card.append(Markdown(f"```\n{log_content}\n```"))

        current_card.refresh()

    def on_error(self, current_card, error_message):
        """Render an error card, but only if the full dashboard was never rendered."""
        if not self._already_rendered:
            current_card.clear()
            current_card.append(Markdown("# 🚀 `@vllm` Status Dashboard"))
            current_card.append(Markdown(f"## ❌ Error: {str(error_message)}"))
            current_card.refresh()

    def on_update(self, current_card, data_object):
        """Re-render the whole card from a snapshot of `status_data`.

        `data_object` is unused; the data comes from `self.status_data`.
        """
        import copy

        # Deep-copy while holding the lock. A shallow copy would share the
        # nested dicts/lists that other threads mutate via update_status(),
        # so rendering could see e.g. the performance dict change size
        # mid-iteration.
        with self._lock:
            snapshot = copy.deepcopy(self.status_data)

        # Incremental updates could be implemented here; for now the whole
        # card is re-rendered on every update.
        self.render_card_fresh(current_card, snapshot)

    def sqlite_fetch_func(self, conn):
        """Required by CardRefresher (which needs a refactor), but we use in-memory data instead"""
        with self._lock:
            return {"status": self.status_data}
@@ -0,0 +1,471 @@
1
+ import subprocess
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ import time
4
+ import socket
5
+ import sys
6
+ import os
7
+ import functools
8
+ import json
9
+ import requests
10
+ from enum import Enum
11
+ import threading
12
+ from datetime import datetime
13
+
14
+ from .constants import VLLM_SUFFIX
15
+
16
+
17
class ProcessStatus:
    """String constants for the lifecycle state of a managed subprocess."""

    RUNNING, FAILED, SUCCESSFUL = "RUNNING", "FAILED", "SUCCESSFUL"
21
+
22
+
23
class VLLMManager:
    """
    A process manager for vLLM runtimes.
    Implements interface @vllm(model=..., ...) to provide a local backend.
    It wraps the vLLM OpenAI-compatible API server to make it easier to profile vLLM use on Outerbounds.

    NOTE: vLLM's OpenAI-compatible server serves ONE model per server instance.
    If you need multiple models, you must create multiple server instances.

    Constructing a manager starts the server subprocess and blocks until the
    port accepts connections and `GET /v1/models` answers 200. If the
    server cannot be started, the half-started process is terminated and
    RuntimeError is raised. Call `terminate_models` to stop a running server.

    Example usage:
        from vllm import LLM
        llm = LLM(model="meta-llama/Llama-3.2-1B")
        llm.generate("Hello, world!")

    Or via OpenAI-compatible API:
        import openai
        client = openai.OpenAI(
            base_url="http://localhost:8000/v1",
            api_key="token-abc123"
        )
        response = client.chat.completions.create(
            model="meta-llama/Llama-3.2-1B",
            messages=[{"role": "user", "content": "Hello"}]
        )
    """

    def __init__(
        self,
        model,
        backend="local",
        debug=False,
        status_card=None,
        port=8000,
        host="127.0.0.1",
        stream_logs_to_card=False,
        **vllm_args,
    ):
        import collections

        # Validate that only a single model is provided
        if isinstance(model, list):
            if len(model) != 1:
                raise ValueError(
                    f"vLLM server can only serve one model per instance. "
                    f"Got {len(model)} models: {model}. "
                    f"Please specify a single model or create multiple @vllm decorators."
                )
            self.model = model[0]
        else:
            self.model = model

        self.processes = {}
        self.debug = debug
        self.stream_logs_to_card = stream_logs_to_card
        self.stats = {}
        self.port = port
        self.host = host
        self.vllm_url = f"http://{host}:{port}"
        self.status_card = status_card
        self.initialization_start = time.time()
        self.server_process = None
        self.vllm_args = vllm_args
        # Last stderr lines of the server. They are kept so a startup failure
        # can be reported even when output is neither printed nor streamed.
        self._stderr_tail = collections.deque(maxlen=50)
        self._output_threads = []

        if backend != "local":
            raise ValueError(
                "VLLMManager only supports the 'local' backend at this time."
            )

        self._log_event("info", "Starting vLLM initialization")
        self._update_server_status("Initializing")

        self._timeit(self._install_vllm, "install_vllm")
        self._timeit(self._launch_vllm_server, "launch_server")
        self._collect_version_info()

        total_init_time = time.time() - self.initialization_start
        self._update_performance("total_initialization_time", total_init_time)
        self._log_event(
            "success", f"vLLM initialization completed in {total_init_time:.1f}s"
        )

    def _log_event(self, event_type, message):
        """Record an event on the status card (if any) and print it in debug mode."""
        if self.status_card:
            self.status_card.add_event(event_type, message)
        if self.debug:
            print(f"[@vllm] {event_type.upper()}: {message}")

    def _update_server_status(self, status, **kwargs):
        """Set the card's server status, plus any extra server fields."""
        if self.status_card:
            update_data = {"status": status}
            update_data.update(kwargs)
            self.status_card.update_status("server", update_data)

    def _update_model_status(self, model_name, **kwargs):
        """Merge `kwargs` into the card's entry for `model_name`."""
        if self.status_card:
            models = self.status_card.status_data.get("models", {})
            # Merge on a copy and publish it through update_status(), which
            # takes the card's lock, instead of mutating the card's shared dict
            # in place from this thread.
            entry = dict(models.get(model_name, {}))
            entry.update(kwargs)
            self.status_card.update_status("models", {model_name: entry})

    def _update_performance(self, metric, value):
        """Record one performance metric on the status card."""
        if self.status_card:
            self.status_card.update_status("performance", {metric: value})

    def _timeit(self, f, name):
        """Run `f`, store its duration in `self.stats[name]`, and report it."""
        t0 = time.time()
        f()
        tf = time.time()
        duration = tf - t0
        self.stats[name] = {"process_runtime": duration}

        if name == "install_vllm":
            self._update_performance("install_time", duration)
        elif name == "launch_server":
            self._update_performance("server_startup_time", duration)

    def _stream_output(self, stream, prefix, tail=None):
        """Read `stream` line by line until EOF, then close it.

        Each stripped line goes to the status card when `stream_logs_to_card`
        is set; otherwise it is printed when `debug` is set. When `tail` is
        given, every line is also appended to it.
        """
        for line in iter(stream.readline, ""):
            line = line.strip()
            if tail is not None:
                tail.append(line)
            if self.stream_logs_to_card and self.status_card:
                self.status_card.add_log_line(f"[{prefix}] {line}")
            elif self.debug:
                print(f"[{prefix}] {line}")
        stream.close()

    def _is_port_open(self, host, port, timeout=1):
        """Return True if a TCP connection to `host:port` succeeds within `timeout` s."""
        try:
            # create_connection resolves hostnames and handles IPv6 as well.
            with socket.create_connection((host, port), timeout=timeout):
                return True
        except OSError:
            return False

    def _install_vllm(self):
        """Check that vLLM is importable. It is never installed automatically.

        Raises ImportError when vLLM is missing, so the user adds it to their
        own environment.
        """
        self._log_event("info", "Checking for existing vLLM installation")
        try:
            import vllm
        except ImportError:
            # Lowercase, like every other event type this class emits.
            self._log_event(
                "error", "vLLM not installed. Please add it to your environment."
            )
            if self.debug:
                print(
                    "[@vllm] vLLM not found. The user is responsible for installation."
                )
            raise
        self._log_event("success", f"vLLM {vllm.__version__} is already installed")
        if self.debug:
            print(f"[@vllm] vLLM {vllm.__version__} is already installed.")

    def _hf_cache_candidates(self):
        """Return the directories that may hold the HuggingFace hub cache.

        The order is `HF_HUB_CACHE`, then `$HF_HOME/hub` (the hub's default
        location under HF_HOME), then `$HF_HOME` itself.
        """
        candidates = []
        hub_cache = os.environ.get("HF_HUB_CACHE")
        if hub_cache:
            candidates.append(hub_cache)
        hf_home = os.environ.get("HF_HOME")
        if hf_home:
            candidates.append(os.path.join(hf_home, "hub"))
            # Also accept models stored directly under HF_HOME.
            candidates.append(hf_home)
        return candidates

    def _check_model_cache(self):
        """Report on the card whether the model is already in a local HF cache."""
        candidates = self._hf_cache_candidates()
        if not candidates:
            self._log_event(
                "warning",
                "HF_HOME environment variable not set. vLLM will use default cache location and may re-download.",
            )
            return

        model_path_id = f"models--{self.model.replace('/', '--')}"
        for cache_dir in candidates:
            model_cache_path = os.path.join(cache_dir, model_path_id)
            if os.path.exists(model_cache_path):
                self._log_event("info", f"Found cached model at: {model_cache_path}")
                self._update_model_status(
                    self.model, status="Found in cache", location=model_cache_path
                )
                return

        self._log_event(
            "warning",
            f"Cached model {model_path_id} not found in {', '.join(candidates)}. "
            f"vLLM will attempt to download it.",
        )
        self._update_model_status(self.model, status="Downloading")

    def _build_launch_command(self):
        """Return the argv list that starts vLLM's OpenAI-compatible server.

        Each `vllm_args` entry becomes `--key-name`. Booleans become bare
        flags when True and are dropped when False. None values are dropped.
        Everything else is passed as `str(value)`.
        """
        cmd = [
            sys.executable,
            "-m",
            "vllm.entrypoints.openai.api_server",
            "--model",
            self.model,
            "--host",
            self.host,
            "--port",
            str(self.port),
        ]

        vllm_args_copy = self.vllm_args.copy()
        if self.debug or self.stream_logs_to_card:
            # Note: This is an undocumented argument for the vLLM OpenAI server entrypoint.
            vllm_args_copy.setdefault("uvicorn_log_level", "debug")

        for key, value in vllm_args_copy.items():
            arg_name = f"--{key.replace('_', '-')}"
            if isinstance(value, bool):
                if value:
                    cmd.append(arg_name)
            elif value is not None:
                cmd.extend([arg_name, str(value)])
        return cmd

    def _start_output_threads(self, process):
        """Start daemon threads that drain the server's stdout and stderr.

        Both pipes are always drained. An unread PIPE eventually fills its OS
        buffer, and the server then blocks on its next write.
        """
        for stream, prefix, tail in (
            (process.stdout, "@vllm-server-out", None),
            (process.stderr, "@vllm-server-err", self._stderr_tail),
        ):
            thread = threading.Thread(
                target=self._stream_output, args=(stream, prefix, tail), daemon=True
            )
            thread.start()
            self._output_threads.append(thread)

    def _describe_exit(self, returncode):
        """Build the error details for a server that exited during startup."""
        if self.stream_logs_to_card:
            return f"Return code: {returncode}. See card for logs."
        if self.debug:
            return f"Return code: {returncode}. See logs from @vllm-server-err for details."
        # Output only went to the tail buffer. Let the drain threads reach EOF
        # so the buffer holds the process's final lines.
        for thread in self._output_threads:
            thread.join(timeout=5)
        stderr = "\n".join(self._stderr_tail)
        return f"Return code: {returncode}, stderr: {stderr}"

    def _mark_failed(self, process, error_details):
        """Record a startup failure in `self.processes` and on the status card."""
        entry = self.processes.get(process.pid)
        if entry is not None:
            entry["properties"]["error_details"] = error_details
            entry["status"] = ProcessStatus.FAILED
        self._update_server_status("Failed", error_details=error_details)

    def _wait_for_server(self, process, max_retries=240):
        """Poll until the server's port accepts connections.

        Raises RuntimeError if the process exits first, or if the port is
        still closed after `max_retries` polls.
        """
        wait_start = time.time()
        retries = 0
        while (
            not self._is_port_open(self.host, self.port, timeout=2)
            and retries < max_retries
        ):
            if retries == 0:
                print("[@vllm] Waiting for server to be ready...")
            elif retries % 10 == 0:
                print(
                    f"[@vllm] Still waiting for server... ({retries}/{max_retries})"
                )

            returncode = process.poll()
            if returncode is not None:
                error_details = self._describe_exit(returncode)
                self._mark_failed(process, error_details)
                self._log_event(
                    "error", f"vLLM server failed to start: {error_details}"
                )
                raise RuntimeError(f"vLLM server failed to start: {error_details}")

            time.sleep(2)
            retries += 1

        if not self._is_port_open(self.host, self.port, timeout=2):
            # A poll can take up to the 2s connect timeout plus the 2s sleep,
            # so report the measured wait rather than max_retries * 2.
            waited = time.time() - wait_start
            error_details = (
                f"vLLM server did not start listening on {self.host}:{self.port} "
                f"after {waited:.0f}s"
            )
            self._mark_failed(process, error_details)
            self._log_event("error", f"Server startup timeout: {error_details}")
            raise RuntimeError(f"vLLM server failed to start: {error_details}")

    def _kill_process(self, process, timeout=10):
        """Terminate `process`, killing it if it has not exited after `timeout` s.

        Returns "graceful" or "force_kill".
        """
        process.terminate()
        try:
            process.wait(timeout=timeout)
            if self.debug:
                print("[@vllm] Server terminated gracefully")
            return "graceful"
        except subprocess.TimeoutExpired:
            self._log_event(
                "warning",
                "vLLM server did not terminate gracefully, killing...",
            )
            if self.debug:
                print("[@vllm] Server did not terminate, killing...")
            process.kill()
            process.wait()
            return "force_kill"

    def _launch_vllm_server(self):
        """Start the vLLM server subprocess and wait until it is healthy.

        Raises ValueError if no model is set. Raises RuntimeError if the
        server cannot be started, after terminating the server process.
        """
        # Validate first: the cache lookup below needs the model name.
        if not self.model:
            raise ValueError("At least one model must be specified for @vllm.")

        self._update_server_status("Starting")
        self._log_event("info", f"Starting vLLM server with model: {self.model}")
        self._check_model_cache()

        # Bound before `try`, so the handler below can always test it even
        # when Popen itself fails.
        process = None
        try:
            if self.debug:
                print(
                    f"[@vllm] Starting vLLM OpenAI-compatible server for model: {self.model}"
                )

            cmd = self._build_launch_command()
            # For debugging, log the exact command being run to the status card
            command_str = " ".join(cmd)
            self._log_event("info", f"Launch Command: `{command_str}`")
            if self.debug:
                print(f"[@vllm] Launching vLLM with command: {command_str}")

            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,  # Line-buffered
            )
            self.server_process = process
            self.processes[process.pid] = {
                "p": process,
                "properties": {
                    "type": "vllm-server",
                    "model": self.model,
                    "error_details": None,
                },
                "status": ProcessStatus.RUNNING,
            }
            self._start_output_threads(process)

            if self.debug:
                print(f"[@vllm] Started vLLM server process with PID {process.pid}")

            self._wait_for_server(process)

            if not self._verify_server_health():
                error_details = "vLLM server started but failed health check"
                self._mark_failed(process, error_details)
                self._log_event("error", error_details)
                raise RuntimeError(error_details)

            self._update_server_status(
                "Running", uptime_start=datetime.now(), model=self.model
            )
            self._log_event("success", "vLLM server is ready and listening")

            self._update_model_status(self.model, status="Ready")

            if self.debug:
                print("[@vllm] Server is ready.")

        except Exception as e:
            if process is not None:
                if process.pid in self.processes:
                    self.processes[process.pid]["status"] = ProcessStatus.FAILED
                    self.processes[process.pid]["properties"]["error_details"] = str(e)
                # The constructor is failing, so no caller ever gets this
                # manager to call terminate_models(). Stop the server here.
                try:
                    self._kill_process(process)
                except Exception as kill_error:
                    if self.debug:
                        print(
                            f"[@vllm] Warning: could not stop failed server: {kill_error}"
                        )
            self._update_server_status("Failed", error_details=str(e))
            self._log_event("error", f"Error starting vLLM server: {str(e)}")
            raise RuntimeError(f"Error starting vLLM server: {e}") from e

    def _verify_server_health(self):
        """Return True if `GET /v1/models` answers 200 within 10 seconds."""
        try:
            response = requests.get(f"{self.vllm_url}/v1/models", timeout=10)
            if response.status_code == 200:
                if self.debug:
                    models_data = response.json()
                    available_models = [
                        m.get("id", "unknown") for m in models_data.get("data", [])
                    ]
                    print(
                        f"[@vllm] Health check OK. Available models: {available_models}"
                    )
                return True
            else:
                if self.debug:
                    print(
                        f"[@vllm] Health check failed with status {response.status_code}"
                    )
                return False
        except Exception as e:
            if self.debug:
                print(f"[@vllm] Health check exception: {e}")
            return False

    def _collect_version_info(self):
        """Record the installed vLLM version on the status card."""
        version_info = {}
        try:
            import vllm

            version_info["vllm"] = getattr(vllm, "__version__", "Unknown")
        except ImportError:
            version_info["vllm"] = "Not installed"
        except Exception as e:
            version_info["vllm"] = "Error detecting"
            if self.debug:
                print(f"[@vllm] Error getting vLLM version: {e}")

        if self.status_card:
            self.status_card.update_status("versions", version_info)
            self._log_event(
                "info", f"vLLM version: {version_info.get('vllm', 'Unknown')}"
            )

    def terminate_models(self):
        """Stop the server process and record shutdown metrics on the card.

        Shutdown errors are reported on the card and logged, not raised.
        """
        shutdown_start_time = time.time()
        self._log_event("info", "Starting vLLM shutdown sequence")
        if self.debug:
            print("[@vllm] Shutting down vLLM server...")

        server_shutdown_cause = "graceful"

        if self.server_process:
            pid = self.server_process.pid
            try:
                self._update_server_status("Stopping")
                self._log_event("info", "Stopping vLLM server")
                self._update_model_status(self.model, status="Stopping")

                server_shutdown_cause = self._kill_process(self.server_process)

                if pid in self.processes:
                    self.processes[pid]["status"] = ProcessStatus.SUCCESSFUL

                self._update_server_status("Stopped")
                # Merging an empty dict into the card's models would change
                # nothing, so record the final model state explicitly.
                self._update_model_status(self.model, status="Stopped")

                self._log_event(
                    "success", f"vLLM server stopped ({server_shutdown_cause})"
                )

            except Exception as e:
                server_shutdown_cause = "failed"
                if pid in self.processes:
                    self.processes[pid]["status"] = ProcessStatus.FAILED
                    self.processes[pid]["properties"]["error_details"] = str(e)
                self._update_server_status("Failed to stop")
                self._log_event("error", f"vLLM server shutdown error: {str(e)}")
                if self.debug:
                    print(f"[@vllm] Warning: Error terminating vLLM server: {e}")

        total_shutdown_time = time.time() - shutdown_start_time
        self._update_performance("total_shutdown_time", total_shutdown_time)
        self._update_performance("shutdown_cause", server_shutdown_cause)

        self._log_event(
            "success", f"vLLM shutdown completed in {total_shutdown_time:.1f}s"
        )
        if self.debug:
            print("[@vllm] vLLM server shutdown complete.")
@@ -75,4 +75,3 @@ from .. import profilers
75
75
  from ..plugins.snowflake import Snowflake
76
76
  from ..plugins.checkpoint_datastores import nebius_checkpoints, coreweave_checkpoints
77
77
  from ..plugins.aws import assume_role
78
- from . import ob_internal
@@ -0,0 +1 @@
1
+ __mf_promote_submodules__ = ["plugins.vllm"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ob-metaflow-extensions
3
- Version: 1.1.166rc5
3
+ Version: 1.1.167
4
4
  Summary: Outerbounds Platform Extensions for Metaflow
5
5
  Author: Outerbounds, Inc.
6
6
  License: Commercial
@@ -1,11 +1,10 @@
1
1
  metaflow_extensions/outerbounds/__init__.py,sha256=Gb8u06s9ClQsA_vzxmkCzuMnigPy7kKcDnLfb7eB-64,514
2
2
  metaflow_extensions/outerbounds/remote_config.py,sha256=pEFJuKDYs98eoB_-ryPjVi9b_c4gpHMdBHE14ltoxIU,4672
3
3
  metaflow_extensions/outerbounds/config/__init__.py,sha256=JsQGRuGFz28fQWjUvxUgR8EKBLGRdLUIk_buPLJplJY,1225
4
- metaflow_extensions/outerbounds/plugins/__init__.py,sha256=HB298DqOlmM96-j_FjloDSgt1TZO9I-K2eJ7kevpHVQ,14092
4
+ metaflow_extensions/outerbounds/plugins/__init__.py,sha256=gQIy3VaDg5e8tbx5q5TflddxhS6N3B7TaCWZEtnqJbo,14095
5
5
  metaflow_extensions/outerbounds/plugins/auth_server.py,sha256=_Q9_2EL0Xy77bCRphkwT1aSu8gQXRDOH-Z-RxTUO8N4,2202
6
6
  metaflow_extensions/outerbounds/plugins/perimeters.py,sha256=QXh3SFP7GQbS-RAIxUOPbhPzQ7KDFVxZkTdKqFKgXjI,2697
7
7
  metaflow_extensions/outerbounds/plugins/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py,sha256=erfKCC6zKuwax0ye9j-tqIkAVdCg7MJfbRGhMhViSzU,575
9
8
  metaflow_extensions/outerbounds/plugins/apps/app_utils.py,sha256=sw9whU17lAzlD2K2kEDNjlk1Ib-2xE2UNhJkmzD8Qv8,8543
10
9
  metaflow_extensions/outerbounds/plugins/apps/consts.py,sha256=iHsyqbUg9k-rgswCs1Jxf5QZIxR1V-peCDRjgr9kdBM,177
11
10
  metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py,sha256=VkmiMdNYHhNdt-Qm9AVv7aE2LWFsIFEc16YcOYjwF6Q,8568
@@ -21,7 +20,6 @@ metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py,sha256
21
20
  metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py,sha256=_WzoOROFjoFa8TzsMNFp-r_1Zz7NUp-5ljn_kKlczXA,4534
22
21
  metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py,sha256=zgqDLFewCeF5jqh-hUNKmC_OAjld09ln0bb8Lkeqapc,4659
23
22
  metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py,sha256=ShE5omFBr83wkvEhL_ptRFvDNMs6wefg4BjaafQjTcM,3602
25
23
  metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py,sha256=Tl520HdBteg-aDOM7mnnJJpdDCZc49BmFFmLUc_vTi8,15018
26
24
  metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py,sha256=PE81ZB54OAMXkMGSB7JqgvgMg7N9kvoVclrWL-6jc2U,5626
27
25
  metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py,sha256=kqFyu2bJSnc9_9aYfBpz5xK6L6luWFZK_NMuh8f1eVk,1494
@@ -69,18 +67,23 @@ metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py,sha256=aQphxX6j
69
67
  metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py,sha256=AI_kcm1hZV3JRxJkookcH6twiGnAYjk9Dx-MeoYz60Y,8511
70
68
  metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py,sha256=9lUM4Cqi5RjrHBRfG6AQMRz8-R96eZC8Ih0KD2lv22Y,1858
71
69
  metaflow_extensions/outerbounds/plugins/torchtune/__init__.py,sha256=TOXNeyhcgd8VxplXO_oEuryFEsbk0tikn5GL0-44SU8,5853
70
+ metaflow_extensions/outerbounds/plugins/vllm/__init__.py,sha256=O04DPVoEdCZhPbvdldaE4ztoAxJNXU-ExosBCqe43v8,6463
71
+ metaflow_extensions/outerbounds/plugins/vllm/constants.py,sha256=ODX_uM5iYrzpVltsAdSf9Jo0DAOMiZ3647DcKdCnlS0,24
72
+ metaflow_extensions/outerbounds/plugins/vllm/exceptions.py,sha256=8m65k2L17zXgSkgU299DWqxr1wGUMsZgSJw0hBRizJ0,49
73
+ metaflow_extensions/outerbounds/plugins/vllm/status_card.py,sha256=ofTuBkhZgJq8cs-KwUOc85TB4hP9FT5qxagVEomY8jA,12931
74
+ metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py,sha256=E6SCBoanNNsFjm-IGGGf4VMVe1mueDiRWw0ZvCXizDQ,18496
72
75
  metaflow_extensions/outerbounds/profilers/__init__.py,sha256=wa_jhnCBr82TBxoS0e8b6_6sLyZX0fdHicuGJZNTqKw,29
73
76
  metaflow_extensions/outerbounds/profilers/gpu.py,sha256=3Er8uKQzfm_082uadg4yn_D4Y-iSCgzUfFmguYxZsz4,27485
74
77
  metaflow_extensions/outerbounds/toplevel/__init__.py,sha256=qWUJSv_r5hXJ7jV_On4nEasKIfUCm6_UjkjXWA_A1Ts,90
75
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,sha256=_fKWv_-O1k5Nk5A1q05Ioh-PSsFXGL-jiAt7zfl8pIE,2999
76
- metaflow_extensions/outerbounds/toplevel/ob_internal.py,sha256=53xM6d_UYT3uGFFA59UzxN23H5QMO5_F39pALpmGy04,51
78
+ metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,sha256=FS0ZKQJsJKlw9PgtLqVV4kH9o7_qwYcIVtyu2Kqwa_U,2973
77
79
  metaflow_extensions/outerbounds/toplevel/plugins/azure/__init__.py,sha256=WUuhz2YQfI4fz7nIcipwwWq781eaoHEk7n4GAn1npDg,63
78
80
  metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py,sha256=BbZiaH3uILlEZ6ntBLKeNyqn3If8nIXZFq_Apd7Dhco,70
79
81
  metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8m7rgF4xgWBZFuY3GDP5n1T0ktjRpGJLHA,69
80
82
  metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py,sha256=GRSz2zwqkvlmFS6bcfYD_CX6CMko9DHQokMaH1iBshA,47
81
83
  metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py,sha256=LptpH-ziXHrednMYUjIaosS1SXD3sOtF_9_eRqd8SJw,50
82
84
  metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py,sha256=uTVkdSk3xZ7hEKYfdlyVteWj5KeDwaM1hU9WT-_YKfI,50
83
- ob_metaflow_extensions-1.1.166rc5.dist-info/METADATA,sha256=EuJI-hYrzPPZYZcyQhz8GP0PHTOpJ10pM1SQdH2vFPo,524
84
- ob_metaflow_extensions-1.1.166rc5.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
85
- ob_metaflow_extensions-1.1.166rc5.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
86
- ob_metaflow_extensions-1.1.166rc5.dist-info/RECORD,,
85
+ metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py,sha256=ekcgD3KVydf-a0xMI60P4uy6ePkSEoFHiGnDq1JM940,45
86
+ ob_metaflow_extensions-1.1.167.dist-info/METADATA,sha256=0WnYu65sqUQZp6jEd6aJrLMBEx6Nm1VIB7lFsfmQXaU,521
87
+ ob_metaflow_extensions-1.1.167.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
88
+ ob_metaflow_extensions-1.1.167.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
89
+ ob_metaflow_extensions-1.1.167.dist-info/RECORD,,
@@ -1,26 +0,0 @@
1
- from metaflow._vendor import click
2
-
3
- OUTERBOUNDS_APP_CLI_AVAILABLE = True
4
- try:
5
- import outerbounds.apps.app_cli as ob_apps_cli
6
- except ImportError:
7
- OUTERBOUNDS_APP_CLI_AVAILABLE = False
8
-
9
-
10
- if not OUTERBOUNDS_APP_CLI_AVAILABLE:
11
-
12
- @click.group()
13
- def _cli():
14
- pass
15
-
16
- @_cli.group(help="Dummy Group to append to CLI for Safety")
17
- def app():
18
- pass
19
-
20
- @app.command(help="Dummy Command to append to CLI for Safety")
21
- def cannot_deploy():
22
- raise Exception("Outerbounds App CLI not available")
23
-
24
- cli = _cli
25
- else:
26
- cli = ob_apps_cli.cli
@@ -1,110 +0,0 @@
1
- import threading
2
- import time
3
- import sys
4
- from typing import Dict, Optional, Any, Callable
5
- from functools import partial
6
- from metaflow.exception import MetaflowException
7
- from metaflow.metaflow_config import FAST_BAKERY_URL
8
-
9
- from .fast_bakery import FastBakery, FastBakeryApiResponse, FastBakeryException
10
- from .docker_environment import cache_request
11
-
12
- BAKERY_METAFILE = ".imagebakery-cache"
13
-
14
-
15
- class BakerException(MetaflowException):
16
- headline = "Ran into an error while baking image"
17
-
18
- def __init__(self, msg):
19
- super(BakerException, self).__init__(msg)
20
-
21
-
22
- def bake_image(
23
- cache_file_path: str,
24
- ref: Optional[str] = None,
25
- python: Optional[str] = None,
26
- pypi_packages: Optional[Dict[str, str]] = None,
27
- conda_packages: Optional[Dict[str, str]] = None,
28
- base_image: Optional[str] = None,
29
- logger: Optional[Callable[[str], Any]] = None,
30
- ) -> FastBakeryApiResponse:
31
- """
32
- Bakes a Docker image with the specified dependencies.
33
-
34
- Args:
35
- cache_file_path: Path to the cache file
36
- ref: Reference identifier for this bake (for logging purposes)
37
- python: Python version to use
38
- pypi_packages: Dictionary of PyPI packages and versions
39
- conda_packages: Dictionary of Conda packages and versions
40
- base_image: Base Docker image to use
41
- logger: Optional logger function to output progress
42
-
43
- Returns:
44
- FastBakeryApiResponse: The response from the bakery service
45
-
46
- Raises:
47
- BakerException: If the baking process fails
48
- """
49
- # Default logger if none provided
50
- if logger is None:
51
- logger = partial(print, file=sys.stderr)
52
-
53
- # Thread lock for logging
54
- logger_lock = threading.Lock()
55
- images_baked = 0
56
-
57
- @cache_request(cache_file_path)
58
- def _cached_bake(
59
- ref=None,
60
- python=None,
61
- pypi_packages=None,
62
- conda_packages=None,
63
- base_image=None,
64
- ):
65
- try:
66
- bakery = FastBakery(url=FAST_BAKERY_URL)
67
- bakery._reset_payload()
68
- bakery.python_version(python)
69
- bakery.pypi_packages(pypi_packages)
70
- bakery.conda_packages(conda_packages)
71
- bakery.base_image(base_image)
72
- # bakery.ignore_cache()
73
-
74
- with logger_lock:
75
- logger(f"🍳 Baking [{ref}] ...")
76
- logger(f" 🐍 Python: {python}")
77
-
78
- if pypi_packages:
79
- logger(f" 📦 PyPI packages:")
80
- for package, version in pypi_packages.items():
81
- logger(f" 🔧 {package}: {version}")
82
-
83
- if conda_packages:
84
- logger(f" 📦 Conda packages:")
85
- for package, version in conda_packages.items():
86
- logger(f" 🔧 {package}: {version}")
87
-
88
- logger(f" 🏗️ Base image: {base_image}")
89
-
90
- start_time = time.time()
91
- res = bakery.bake()
92
- # TODO: Get actual bake time from bakery
93
- bake_time = time.time() - start_time
94
-
95
- with logger_lock:
96
- logger(f"🏁 Baked [{ref}] in {bake_time:.2f} seconds!")
97
- nonlocal images_baked
98
- images_baked += 1
99
- return res
100
- except FastBakeryException as ex:
101
- raise BakerException(f"Bake [{ref}] failed: {str(ex)}")
102
-
103
- # Call the cached bake function with the provided parameters
104
- return _cached_bake(
105
- ref=ref,
106
- python=python,
107
- pypi_packages=pypi_packages,
108
- conda_packages=conda_packages,
109
- base_image=base_image,
110
- )
@@ -1 +0,0 @@
1
- from ..plugins.fast_bakery.baker import bake_image