ob-metaflow-extensions 1.1.142__py2.py3-none-any.whl → 1.4.33__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metaflow_extensions/outerbounds/__init__.py +1 -1
- metaflow_extensions/outerbounds/plugins/__init__.py +26 -5
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +78 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +17 -3
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +1 -6
- metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +171 -16
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1710 -114
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +44 -4
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/remote_config.py +27 -3
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +87 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.4.33.dist-info/RECORD +134 -0
- metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
- ob_metaflow_extensions-1.1.142.dist-info/RECORD +0 -64
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/top_level.txt +0 -0
metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py

```diff
@@ -7,9 +7,10 @@ from metaflow.metaflow_config import SERVICE_URL
 from metaflow.metaflow_config_funcs import init_config
 from typing import Dict
 from os import environ
-
+import sys
 import json
 import requests
+import random
 import time
 
 
```
```diff
@@ -75,7 +76,9 @@ def get_snowflake_token(user: str = "", role: str = "", integration: str = "") -
     }
     json_payload = json.dumps(payload)
     headers = provisioner.get_service_auth_header()
-    response =
+    response = _api_server_get(
+        snowflake_token_url, data=json_payload, headers=headers, conn_error_retries=5
+    )
     response.raise_for_status()
     return response.json()["token"]
 
```
```diff
@@ -150,6 +153,39 @@ def connect(user: str = "", role: str = "", integration: str = "", **kwargs):
     )
 
 
+def _api_server_get(*args, conn_error_retries=2, **kwargs):
+    """
+    There are two categories of errors that we need to handle when dealing with any API server.
+    1. HTTP errors. These are errors that are returned from the API server.
+        - How to handle retries for this case will be application specific.
+    2. Errors when the API server may not be reachable (DNS resolution / network issues)
+        - In this scenario, we know that something external to the API server is going wrong causing the issue.
+        - Failing prematurely in this case might not be the best course of action since critical user jobs might crash on intermittent issues.
+        - So in this case, we can just plainly retry the request.
+
+    This function handles the second case. It's a simple wrapper to handle the retry logic for connection errors.
+    If this function is provided a `conn_error_retries` of 5, then the last retry will have waited 32 seconds.
+    Generally this is a safe enough number of retries after which we can assume that something is really broken. Until then,
+    there can be intermittent issues that would resolve themselves if we retry gracefully.
+    """
+    _num_retries = 0
+    noise = random.uniform(-0.5, 0.5)
+    while _num_retries < conn_error_retries:
+        try:
+            return requests.get(*args, **kwargs)
+        except requests.exceptions.ConnectionError:
+            if _num_retries <= conn_error_retries - 1:
+                # Exponential backoff with 2^(_num_retries+1) seconds
+                time.sleep((2 ** (_num_retries + 1)) + noise)
+                _num_retries += 1
+            else:
+                print(
+                    "[@snowflake] Failed to connect to the API server. ",
+                    file=sys.stderr,
+                )
+                raise
+
+
 class Snowflake:
     def __init__(
         self, user: str = "", role: str = "", integration: str = "", **kwargs
```
```diff
@@ -273,7 +309,9 @@ class SnowflakeIntegrationProvisioner:
         retryable_status_codes = [409]
         json_payload = json.dumps(payload)
         for attempt in range(2):  # 0 = initial attempt, 1-2 = retries
-            response =
+            response = _api_server_get(
+                url, data=json_payload, headers=request_headers, conn_error_retries=5
+            )
             if response.status_code not in retryable_status_codes:
                 break
 
```
```diff
@@ -281,7 +319,9 @@ class SnowflakeIntegrationProvisioner:
             sleep_time = 0.5 * (attempt + 1)
             time.sleep(sleep_time)
 
-        response =
+        response = _api_server_get(
+            url, data=json_payload, headers=request_headers, conn_error_retries=5
+        )
         self._handle_error_response(response)
         return response.json()
 
```
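The `_api_server_get` wrapper used in the hunks above retries only connection errors, with exponential backoff plus a small random jitter. For reference, a minimal sketch of the schedule it follows when `conn_error_retries=5` (jitter of ±0.5 s omitted); this is illustrative only and not part of the package:

```python
# Illustrative sketch: sleep durations implied by _api_server_get above for
# conn_error_retries=5, ignoring the random.uniform(-0.5, 0.5) jitter.
waits = [2 ** (attempt + 1) for attempt in range(5)]
print(waits)       # [2, 4, 8, 16, 32] -> the final retry waits 32 seconds
print(sum(waits))  # ~62 seconds of total back-off before the error is re-raised
```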
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py

```diff
@@ -27,9 +27,12 @@ class SnowparkClient(object):
         except (NameError, ImportError, ModuleNotFoundError):
             raise SnowflakeException(
                 "Could not import module 'snowflake'.\n\nInstall Snowflake "
-                "Python
-                "
-                "
+                "Python packages first:\n"
+                " snowflake==1.8.0\n"
+                " snowflake-connector-python==3.18.0\n"
+                " snowflake-snowpark-python==1.40.0\n\n"
+                "You can install them by executing:\n"
+                "%s -m pip install snowflake==1.8.0 snowflake-connector-python==3.18.0 snowflake-snowpark-python==1.40.0\n"
                 "or equivalent through your favorite Python package manager."
                 % sys.executable
             )
```
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py

```diff
@@ -42,10 +42,13 @@ class Snowflake(object):
             return session
         except (NameError, ImportError, ModuleNotFoundError):
             raise SnowflakeException(
-                "Could not import module 'snowflake'.\n\
-                "
-                "
-                "
+                "Could not import module 'snowflake'.\n\n"
+                "Install required Snowflake packages using the @pypi decorator:\n"
+                "@pypi(packages={\n"
+                " 'snowflake': '1.8.0',\n"
+                " 'snowflake-connector-python': '3.18.0',\n"
+                " 'snowflake-snowpark-python': '1.40.0'\n"
+                "})\n"
             )
 
 
```
```diff
@@ -143,9 +146,12 @@ class SnowparkDecorator(StepDecorator):
         except (NameError, ImportError, ModuleNotFoundError):
             raise SnowflakeException(
                 "Could not import module 'snowflake'.\n\nInstall Snowflake "
-                "Python
-                "
-                "
+                "Python packages first:\n"
+                " snowflake==1.8.0\n"
+                " snowflake-connector-python==3.18.0\n"
+                " snowflake-snowpark-python==1.40.0\n\n"
+                "You can install them by executing:\n"
+                "%s -m pip install snowflake==1.8.0 snowflake-connector-python==3.18.0 snowflake-snowpark-python==1.40.0\n"
                 "or equivalent through your favorite Python package manager."
                 % sys.executable
             )
```
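The new error messages spell out a `@pypi` configuration for the Snowflake packages; a hypothetical step using those pins might look like the sketch below. The flow and step names are invented, and the version pins simply mirror the message above.

```python
from metaflow import FlowSpec, step, pypi


class SnowparkQueryFlow(FlowSpec):
    # Hypothetical example: package pins copied from the error message above.
    @pypi(
        packages={
            "snowflake": "1.8.0",
            "snowflake-connector-python": "3.18.0",
            "snowflake-snowpark-python": "1.40.0",
        }
    )
    @step
    def start(self):
        import snowflake.snowpark  # importable inside the step once @pypi resolves

        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    SnowparkQueryFlow()
```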
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py

```diff
@@ -199,11 +199,17 @@ class RunningJob(object):
 
     @property
    def status(self):
-
+        status_list = self.status_obj()
+        if not status_list:
+            return "UNKNOWN"
+        return status_list[0].get("status", "UNKNOWN")
 
     @property
     def message(self):
-
+        status_list = self.status_obj()
+        if not status_list:
+            return None
+        return status_list[0].get("message")
 
     @property
     def is_waiting(self):
```
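In other words, an empty status payload now degrades to a sentinel value instead of raising on index access. A standalone sketch of the equivalent logic, with a made-up payload, just to show the behavior:

```python
def first_status(status_list):
    # Mirrors the new RunningJob.status guard: empty payloads map to "UNKNOWN".
    if not status_list:
        return "UNKNOWN"
    return status_list[0].get("status", "UNKNOWN")


assert first_status([]) == "UNKNOWN"
assert first_status([{"status": "RUNNING", "message": "executing"}]) == "RUNNING"
```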
metaflow_extensions/outerbounds/plugins/torchtune/__init__.py (new file)

```diff
@@ -0,0 +1,163 @@
+from queue import Queue, Empty
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, List, Dict
+import subprocess
+import shutil
+import sys
+from metaflow import current
+
+__mf_promote_submodules__ = ["plugins.torchtune"]
+
+
+class TorchTune:
+    def __init__(
+        self,
+        use_multi_node_config: bool = False,
+        config_overrides: Optional[Dict] = None,
+    ):
+        """
+        Initialize the Tune launcher.
+
+        :param use_multi_node_config: If True, attempt to build a distributed configuration
+            from current.torch.torchrun_args.
+        :param config_overrides: Optional dictionary of config overrides for tune run.
+        """
+        self.multi_node_config = {}
+        if use_multi_node_config:
+            if getattr(current, "torch", None):
+                print(
+                    "[Metaflow Tune] Since @torchrun is used, multi-node config can be used to launch the job."
+                )
+                # For distributed torchtune launches, we use similar parameters as torchrun.
+                # (You might need to adjust the keys according to your environment.)
+                self.multi_node_config = {
+                    "nnodes": current.torch.torchrun_args["nnodes"],
+                    "master_addr": current.torch.torchrun_args["master_addr"],
+                    "master_port": int(current.torch.torchrun_args["master_port"]),
+                    "node_rank": current.torch.torchrun_args["node_rank"],
+                    "nproc_per_node": current.torch.torchrun_args["nproc_per_node"],
+                    "num_processes": current.torch.torchrun_args["nproc_per_node"]
+                    * current.torch.torchrun_args["nnodes"],
+                }
+                if config_overrides:
+                    self.multi_node_config.update(config_overrides)
+                print(
+                    f"[Metaflow Tune] Discovered multi-node config for torchrun: {self.multi_node_config}"
+                )
+            else:
+                print(
+                    "[Metaflow Tune] Since @torchrun is not used, default multi-node config cannot be used to launch the job."
+                )
+
+    def run(
+        self,
+        recipe: str,
+        config_dict: Dict,
+        additional_cli_options: Optional[List[str]] = None,
+    ):
+        """
+        Launch the torchtune job via its CLI.
+
+        :param recipe: The path to the recipe (or name of the recipe) to run.
+        :param config_dict: Optional dictionary that will be dumped to a YAML file and passed via --config.
+        :param additional_cli_options: Optional list of additional CLI options.
+        :raises: subprocess.CalledProcessError if the subprocess returns a nonzero exit code.
+        """
+        import yaml
+        import tempfile
+        import os
+
+        _temp_dir = tempfile.mkdtemp()
+        try:
+            config_path = os.path.join(_temp_dir, "config.yaml")
+            with open(config_path, "w") as f:
+                yaml.dump(config_dict, f)
+
+            additional_options = (
+                additional_cli_options if additional_cli_options else []
+            )
+
+            # Build the command. Here we use "tune run" as the base command.
+            cmd = ["tune", "run"]
+
+            # If distributed configuration is present, add torchrun–style flags.
+            if self.multi_node_config:
+                cmd.extend(
+                    [
+                        "--nnodes",
+                        str(self.multi_node_config.get("nnodes")),
+                        "--nproc-per-node",
+                        str(self.multi_node_config.get("nproc_per_node")),
+                        # "--rdzv_conf", f"rdzv_endpoint={self.multi_node_config.get('master_addr')}:{self.multi_node_config.get('master_port')}"
+                        "--rdzv-backend",
+                        "c10d",
+                        "--rdzv-endpoint",
+                        f"{self.multi_node_config.get('master_addr')}:{self.multi_node_config.get('master_port')}",
+                        "--rdzv-id",
+                        "1234567890",
+                        "--node-rank",
+                        str(self.multi_node_config.get("node_rank")),
+                        # TODO: should there be a masterip/port here ?
+                    ]
+                )
+
+            cmd.extend(additional_options)
+
+            cmd.append(recipe)
+            # If a recipe configuration was provided, pass it via the --config flag.
+            cmd.extend(["--config", config_path])
+
+            # Append any additional CLI options.
+
+            # Launch the subprocess.
+            print(f"[Metaflow tune] {' '.join(cmd)}")
+            process = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                universal_newlines=True,
+            )
+
+            # Stream the output in real-time.
+            for out_line, err_line in read_popen_pipes(process):
+                print(out_line, end="", flush=True)
+                print(err_line, end="", file=sys.stderr, flush=True)
+
+            process.wait()
+            if process.returncode != 0:
+                raise subprocess.CalledProcessError(process.returncode, cmd)
+        finally:
+            shutil.rmtree(_temp_dir)
+
+
+def enqueue_output(file, queue):
+    for line in iter(file.readline, ""):
+        queue.put(line)
+    file.close()
+
+
+def read_popen_pipes(p):
+
+    with ThreadPoolExecutor(2) as pool:
+        q_stdout, q_stderr = Queue(), Queue()
+
+        pool.submit(enqueue_output, p.stdout, q_stdout)
+        pool.submit(enqueue_output, p.stderr, q_stderr)
+
+        while True:
+
+            if p.poll() is not None and q_stdout.empty() and q_stderr.empty():
+                break
+
+            out_line = err_line = ""
+
+            try:
+                out_line = q_stdout.get_nowait()
+            except Empty:
+                pass
+            try:
+                err_line = q_stderr.get_nowait()
+            except Empty:
+                pass
+
+            yield (out_line, err_line)
```
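As a rough usage sketch of this launcher inside a Metaflow step: the import path, recipe name, and config keys below are assumptions for illustration, not taken from the package.

```python
from metaflow import FlowSpec, step


class TorchTuneFlow(FlowSpec):
    @step
    def start(self):
        # Hypothetical import path; the module promotes "plugins.torchtune",
        # so the class may be surfaced elsewhere in the metaflow namespace.
        from metaflow.plugins.torchtune import TorchTune

        # run() writes config_dict to a temporary YAML file and invokes
        # `tune run <recipe> --config <file>`; recipe/config values are made up.
        TorchTune(use_multi_node_config=False).run(
            recipe="full_finetune_single_device",
            config_dict={"output_dir": "/tmp/finetune"},
        )
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    TorchTuneFlow()
```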
metaflow_extensions/outerbounds/plugins/vllm/__init__.py (new file)

```diff
@@ -0,0 +1,255 @@
+from metaflow.decorators import StepDecorator
+from metaflow import current
+import functools
+from enum import Enum
+import threading
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
+from metaflow.metaflow_config import from_conf
+
+from .vllm_manager import VLLMOpenAIManager, VLLMPyManager
+from .status_card import VLLMStatusCard, CardDecoratorInjector
+
+__mf_promote_submodules__ = ["plugins.vllm"]
+
+
+### The following classes are used to store the vLLM information in the current environment.
+# Then, Metaflow users can access the vLLM information through the current environment.
+class OpenAIAPIInfo:
+    def __init__(self, local_endpoint, local_api_key):
+        self.local_endpoint = local_endpoint
+        self.local_api_key = local_api_key
+
+
+class VLLM:
+    def __init__(self, llm):
+        self.llm = llm
+
+
+class VLLMDecorator(StepDecorator, CardDecoratorInjector):
+    """
+    This decorator is used to run vllm APIs as Metaflow task sidecars.
+
+    User code call
+    --------------
+    @vllm(
+        model="...",
+        ...
+    )
+
+    Valid backend options
+    ---------------------
+    - 'local': Run as a separate process on the local task machine.
+
+    Valid model options
+    -------------------
+    Any HuggingFace model identifier, e.g. 'meta-llama/Llama-3.2-1B'
+
+    NOTE: vLLM's OpenAI-compatible server serves ONE model per server instance.
+    If you need multiple models, you must create multiple @vllm decorators.
+
+    Parameters
+    ----------
+    model: str
+        HuggingFace model identifier to be served by vLLM.
+    backend: str
+        Determines where and how to run the vLLM process.
+    openai_api_server: bool
+        Whether to use OpenAI-compatible API server mode (subprocess) instead of native engine.
+        Default is False (uses native engine).
+        Set to True for backward compatibility with existing code.
+    debug: bool
+        Whether to turn on verbose debugging logs.
+    card_refresh_interval: int
+        Interval in seconds for refreshing the vLLM status card.
+        Only used when openai_api_server=True.
+    max_retries: int
+        Maximum number of retries checking for vLLM server startup.
+        Only used when openai_api_server=True.
+    retry_alert_frequency: int
+        Frequency of alert logs for vLLM server startup retries.
+        Only used when openai_api_server=True.
+    engine_args : dict
+        Additional keyword arguments to pass to the vLLM engine.
+        For example, `tensor_parallel_size=2`.
+    """
+
+    name = "vllm"
+    defaults = {
+        "model": None,
+        "backend": "local",
+        "openai_api_server": False,  # Default to native engine
+        "debug": False,
+        "stream_logs_to_card": False,
+        "card_refresh_interval": 10,
+        "max_retries": 60,
+        "retry_alert_frequency": 5,
+        "engine_args": {},
+    }
+
+    def step_init(
+        self, flow, graph, step_name, decorators, environment, flow_datastore, logger
+    ):
+        super().step_init(
+            flow, graph, step_name, decorators, environment, flow_datastore, logger
+        )
+
+        # Validate that a model is specified
+        if not self.attributes["model"]:
+            raise ValueError(
+                f"@vllm decorator on step '{step_name}' requires a 'model' parameter. "
+                f"Example: @vllm(model='meta-llama/Llama-3.2-1B')"
+            )
+
+        # Attach the vllm status card only for API server mode
+        if self.attributes["openai_api_server"]:
+            self.attach_card_decorator(
+                flow,
+                step_name,
+                "vllm_status",
+                "blank",
+                refresh_interval=self.attributes["card_refresh_interval"],
+            )
+
+    def task_decorate(
+        self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
+    ):
+        @functools.wraps(step_func)
+        def vllm_wrapper():
+            # FIXME: Kind of ugly branch. Causing branching elsewhere.
+            # Other possible code paths:
+            # - OpenAI batch API
+            # - Embedding
+            # - Special types of models
+            if self.attributes["openai_api_server"]:
+                # API Server mode (existing functionality)
+                self._run_api_server_mode(step_func)
+            else:
+                # Native engine mode (new functionality)
+                self._run_native_engine_mode(step_func)
+
+        return vllm_wrapper
+
+    def _run_api_server_mode(self, step_func):
+        """Run vLLM in API server mode (subprocess, existing functionality)"""
+        self.vllm_manager = None
+        self.status_card = None
+        self.card_monitor_thread = None
+
+        try:
+            self.status_card = VLLMStatusCard(
+                refresh_interval=self.attributes["card_refresh_interval"]
+            )
+
+            def monitor_card():
+                try:
+                    self.status_card.on_startup(current.card["vllm_status"])
+
+                    while not getattr(self.card_monitor_thread, "_stop_event", False):
+                        try:
+                            self.status_card.on_update(
+                                current.card["vllm_status"], None
+                            )
+                            import time
+
+                            time.sleep(self.attributes["card_refresh_interval"])
+                        except Exception as e:
+                            if self.attributes["debug"]:
+                                print(f"[@vllm] Card monitoring error: {e}")
+                            break
+                except Exception as e:
+                    if self.attributes["debug"]:
+                        print(f"[@vllm] Card monitor thread error: {e}")
+                    self.status_card.on_error(current.card["vllm_status"], str(e))
+
+            self.card_monitor_thread = threading.Thread(
+                target=monitor_card, daemon=True
+            )
+            self.card_monitor_thread._stop_event = False
+            self.card_monitor_thread.start()
+            self.vllm_manager = VLLMOpenAIManager(
+                model=self.attributes["model"],
+                backend=self.attributes["backend"],
+                debug=self.attributes["debug"],
+                status_card=self.status_card,
+                max_retries=self.attributes["max_retries"],
+                retry_alert_frequency=self.attributes["retry_alert_frequency"],
+                stream_logs_to_card=self.attributes["stream_logs_to_card"],
+                **self.attributes["engine_args"],
+            )
+            current._update_env(
+                dict(
+                    vllm=OpenAIAPIInfo(
+                        local_endpoint=f"http://127.0.0.1:{self.vllm_manager.port}/v1",
+                        local_api_key="token123",
+                    )
+                )
+            )
+
+            if self.attributes["debug"]:
+                print("[@vllm] API server mode initialized.")
+
+        except Exception as e:
+            if self.status_card:
+                self.status_card.add_event("error", f"Initialization failed: {str(e)}")
+                try:
+                    self.status_card.on_error(current.card["vllm_status"], str(e))
+                except:
+                    pass
+            print(f"[@vllm] Error initializing API server mode: {e}")
+            raise
+
+        try:
+            if self.status_card:
+                self.status_card.add_event("info", "Starting user step function")
+            step_func()
+            if self.status_card:
+                self.status_card.add_event(
+                    "success", "User step function completed successfully"
+                )
+        finally:
+            if self.vllm_manager:
+                self.vllm_manager.terminate_models()
+
+            if self.card_monitor_thread and self.status_card:
+                import time
+
+                try:
+                    self.status_card.on_update(current.card["vllm_status"], None)
+                except Exception as e:
+                    if self.attributes["debug"]:
+                        print(f"[@vllm] Final card update error: {e}")
+                time.sleep(2)
+
+            if self.card_monitor_thread:
+                self.card_monitor_thread._stop_event = True
+                self.card_monitor_thread.join(timeout=5)
+                if self.attributes["debug"]:
+                    print("[@vllm] Card monitoring thread stopped.")
+
+    def _run_native_engine_mode(self, step_func):
+        """Run vLLM in native engine mode (direct LLM API access)"""
+        self.vllm = None
+
+        try:
+            if self.attributes["debug"]:
+                print("[@vllm] Initializing native engine mode")
+
+            self.vllm = VLLMPyManager(
+                model=self.attributes["model"],
+                debug=self.attributes["debug"],
+                **self.attributes["engine_args"],
+            )
+            current._update_env(dict(vllm=VLLM(llm=self.vllm.engine)))
+
+            if self.attributes["debug"]:
+                print("[@vllm] Native engine mode initialized.")
+
+        except Exception as e:
+            print(f"[@vllm] Error initializing native engine mode: {e}")
+            raise
+
+        try:
+            step_func()
+        finally:
+            if self.vllm:
+                self.vllm.terminate_engine()
```
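A hypothetical flow exercising both modes of the decorator is sketched below. The `from metaflow import ... vllm` import path, the prompt text, and the `openai` client usage are assumptions; `current.vllm.llm`, `current.vllm.local_endpoint`, and `current.vllm.local_api_key` come from the decorator code above.

```python
from metaflow import FlowSpec, step, current, vllm  # import path for @vllm assumed


class VLLMExampleFlow(FlowSpec):
    # Native engine mode (default): current.vllm.llm is the in-process vLLM engine.
    @vllm(model="meta-llama/Llama-3.2-1B")
    @step
    def start(self):
        outputs = current.vllm.llm.generate(["Hello, my name is"])
        print(outputs[0].outputs[0].text)
        self.next(self.serve)

    # OpenAI-compatible server mode: talk to the sidecar over HTTP.
    @vllm(model="meta-llama/Llama-3.2-1B", openai_api_server=True)
    @step
    def serve(self):
        from openai import OpenAI  # assumes the openai package is installed

        client = OpenAI(
            base_url=current.vllm.local_endpoint,
            api_key=current.vllm.local_api_key,
        )
        resp = client.completions.create(
            model="meta-llama/Llama-3.2-1B",
            prompt="Hello, my name is",
        )
        print(resp.choices[0].text)
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    VLLMExampleFlow()
```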
metaflow_extensions/outerbounds/plugins/vllm/constants.py (new file)

```diff
@@ -0,0 +1 @@
+VLLM_SUFFIX = "mf.vllm"
```
metaflow_extensions/outerbounds/plugins/vllm/exceptions.py (new file)

```diff
@@ -0,0 +1 @@
+from metaflow.exception import MetaflowException
```