ob-metaflow-extensions 1.1.142__py2.py3-none-any.whl → 1.4.33__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metaflow_extensions/outerbounds/__init__.py +1 -1
- metaflow_extensions/outerbounds/plugins/__init__.py +26 -5
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +78 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +17 -3
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +1 -6
- metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +171 -16
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1710 -114
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +44 -4
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/remote_config.py +27 -3
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +87 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.4.33.dist-info/RECORD +134 -0
- metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
- ob_metaflow_extensions-1.1.142.dist-info/RECORD +0 -64
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/top_level.txt +0 -0
|
@@ -1,47 +1,163 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
from .
|
|
6
|
-
from
|
|
1
|
+
import sys
|
|
2
|
+
import time
|
|
3
|
+
import requests
|
|
4
|
+
import sqlite3
|
|
5
|
+
from urllib3.util.retry import Retry
|
|
6
|
+
from requests.adapters import HTTPAdapter
|
|
7
|
+
from typing import Dict, Optional, Any
|
|
8
|
+
from .utils import get_ngc_response, get_storage_path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def nvcf_submit_helper(
    url: str,
    payload: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None,
    timeout: int = 30,
    max_retries: int = 300,
    backoff_factor: float = 0.3,
    request_delay: float = 1.1,
    log_callback: Optional[callable] = None,
) -> Dict[str, Any]:
    """Submit a payload to an NVCF endpoint and poll until the job finishes.

    The payload is POSTed to ``url``. While the service answers 202, the
    status endpoint is polled with the request id taken from the
    ``NVCF-REQID`` response header until a 200 is returned.

    Args:
        url: Endpoint the payload is POSTed to.
        payload: JSON-serialisable request body.
        headers: Request headers; a JSON accept/content-type pair is used
            when none are given.
        timeout: Per-request timeout in seconds.
        max_retries: Total retries granted to the HTTP adapter.
        backoff_factor: Backoff factor handed to the retry strategy.
        request_delay: Seconds slept before the POST, after it, and before
            every poll.
        log_callback: Optional callable invoked as
            ``log_callback(response_data, duration, status_code, poll_count)``.
            On failure it receives ``{}`` as ``response_data`` and is invoked
            at most once per call.

    Returns:
        The decoded JSON body of the final (200) response.

    Raises:
        requests.exceptions.HTTPError: If the POST or a poll returns an
            error status.
        Exception: If polling ends on a status other than 200, if a 202 is
            returned without a request id, or on any transport error.
    """
    error_logged = False

    def _log_error(start_time: float, status_code: int, poll_count: int):
        # Failures propagate through several handlers below; the flag makes
        # sure the callback sees each failed call exactly once, with the
        # status code recorded closest to the failure.
        nonlocal error_logged
        if error_logged or not log_callback:
            return
        error_logged = True
        try:
            log_callback({}, time.time() - start_time, status_code, poll_count)
        except Exception as log_error:
            print(f"Warning: Logging callback failed: {log_error}")

    # use default headers
    if not headers:
        headers = {"accept": "application/json", "content-type": "application/json"}
        print(f"Using Default Headers: {headers}")

    # Configure retry strategy shared by http:// and https://
    retry_strategy = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=[429, 500, 502, 503, 504, 404],
        allowed_methods=["GET", "POST"],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)

    # The context manager closes the session (and its pooled connections)
    # on every exit path.
    with requests.Session() as session:
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # Add artificial delay if specified
        time.sleep(request_delay)

        start_time = time.time()
        poll_count = 0
        status_code = 0
        response_data = {}

        try:
            # Make initial request
            response = session.post(
                url, json=payload, headers=headers, timeout=timeout
            )
            time.sleep(request_delay)

            # Handle initial response
            response.raise_for_status()
            request_id = response.headers.get("NVCF-REQID")
            polling_url = (
                f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/{request_id}"
            )

            print(f"Polling NVCF Request ID: {request_id}")

            # Initial response status
            status_code = response.status_code
            print(f"Initial response status: {status_code}")

            # Without a request id the status URL would end in "None" and the
            # retry policy would keep retrying the resulting 404s.
            if status_code == 202 and not request_id:
                _log_error(start_time, status_code, poll_count)
                raise RuntimeError(
                    "NVCF returned 202 without an NVCF-REQID header; cannot poll for the result"
                )

            final_response = response

            # Continue polling while we get 202 (Accepted/Processing)
            while status_code == 202:
                poll_count += 1
                print(f"Polling attempt #{poll_count} to {polling_url}")

                # Wait before next poll
                time.sleep(request_delay)

                poll_response = session.get(
                    polling_url, headers=headers, timeout=timeout
                )
                status_code = poll_response.status_code
                print(f"Poll #{poll_count} status: {status_code}")

                try:
                    poll_response.raise_for_status()
                except requests.exceptions.HTTPError as e:
                    print(f"Poll request failed: {str(e)}")
                    poll_response.close()
                    _log_error(start_time, poll_response.status_code, poll_count)
                    raise

                if status_code == 200:
                    print("Polling complete - job finished successfully")
                    final_response = poll_response
                    break

                # Release the connection before looping again
                if status_code == 202:
                    poll_response.close()

            # If we exited the loop without a 200 status, something went wrong
            if status_code != 200:
                print(f"Polling ended with unexpected status: {status_code}")
                _log_error(start_time, status_code, poll_count)
                raise Exception(
                    f"Unexpected status code after polling: {status_code}"
                )

            response_data = final_response.json()

        except requests.exceptions.HTTPError as e:
            # A Response is falsy for 4xx/5xx statuses, so test against None
            # rather than relying on truthiness.
            status_code = e.response.status_code if e.response is not None else 0
            print(f"HTTP Error: {str(e)}", file=sys.stderr)
            _log_error(start_time, status_code, poll_count)
            raise

        except Exception as e:
            # Connection errors, timeouts, etc. Status 0 marks a non-HTTP
            # error unless a more specific status was already logged.
            print(f"Request Error: {str(e)}", file=sys.stderr)
            _log_error(start_time, 0, poll_count)
            raise

        # Calculate final duration and log successful requests
        duration = time.time() - start_time

        if log_callback:
            try:
                log_callback(response_data, duration, status_code, poll_count)
            except Exception as e:
                print(f"Warning: Logging callback failed: {e}")

        # Log metrics
        print(
            f"Request completed: duration={duration:.2f}s, polls={poll_count}, "
            f"status={status_code}, size={len(final_response.content)} bytes"
        )

        return response_data
|
|
18
151
|
|
|
19
152
|
|
|
20
153
|
class NimMetadata(object):
|
|
21
154
|
def __init__(self):
|
|
22
155
|
self._nvcf_chat_completion_models = []
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
conf = init_config()
|
|
26
|
-
|
|
27
|
-
if "OBP_AUTH_SERVER" in conf:
|
|
28
|
-
auth_host = conf["OBP_AUTH_SERVER"]
|
|
29
|
-
else:
|
|
30
|
-
auth_host = "auth." + urlparse(SERVICE_URL).hostname.split(".", 1)[1]
|
|
31
|
-
|
|
32
|
-
nim_info_url = "https://" + auth_host + "/generate/nim"
|
|
156
|
+
ngc_response = get_ngc_response()
|
|
33
157
|
|
|
34
|
-
|
|
35
|
-
headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
|
|
36
|
-
res = requests.get(nim_info_url, headers=headers)
|
|
37
|
-
else:
|
|
38
|
-
headers = json.loads(os.environ.get("METAFLOW_SERVICE_HEADERS"))
|
|
39
|
-
res = requests.get(nim_info_url, headers=headers)
|
|
40
|
-
|
|
41
|
-
res.raise_for_status()
|
|
42
|
-
self._ngc_api_key = res.json()["nvcf"]["api_key"]
|
|
158
|
+
self.ngc_api_key = ngc_response["nvcf"]["api_key"]
|
|
43
159
|
|
|
44
|
-
for model in
|
|
160
|
+
for model in ngc_response["nvcf"]["functions"]:
|
|
45
161
|
self._nvcf_chat_completion_models.append(
|
|
46
162
|
{
|
|
47
163
|
"name": model["model_key"],
|
|
@@ -49,64 +165,93 @@ class NimMetadata(object):
|
|
|
49
165
|
"version-id": model["version"],
|
|
50
166
|
}
|
|
51
167
|
)
|
|
52
|
-
for model in res.json()["coreweave"]["containers"]:
|
|
53
|
-
self._coreweave_chat_completion_models.append(
|
|
54
|
-
{"name": model["nim_name"], "ip-address": model["ip_addr"]}
|
|
55
|
-
)
|
|
56
168
|
|
|
57
169
|
def get_nvcf_chat_completion_models(self):
|
|
58
170
|
return self._nvcf_chat_completion_models
|
|
59
171
|
|
|
60
172
|
def get_headers_for_nvcf_request(self):
|
|
61
|
-
return {
|
|
173
|
+
return {
|
|
174
|
+
"accept": "application/json",
|
|
175
|
+
"content-type": "application/json",
|
|
176
|
+
"Authorization": f"Bearer {self.ngc_api_key}",
|
|
177
|
+
"NVCF-POLL-SECONDS": "5",
|
|
178
|
+
}
|
|
62
179
|
|
|
63
180
|
|
|
64
181
|
class NimManager(object):
    """Builds one NimChatCompletion client per requested model.

    After construction, ``self.models`` maps a model name to its
    ``NimChatCompletion`` instance.
    """

    def __init__(self, models, flow, step_name, monitor):
        """Resolve the requested models against the NVCF catalogue.

        Args:
            models: A model name, or a list whose items are model names or
                dicts with the keys ``name``, ``nvcf_id`` and ``nvcf_version``.
            flow: Accepted for interface compatibility; not used here.
            step_name: Accepted for interface compatibility; not used here.
            monitor: Forwarded to every ``NimChatCompletion``.

        Raises:
            ValueError: If ``models`` has an unsupported shape, a named model
                is not in the catalogue, or a spec has neither a name nor
                both ``nvcf_id`` and ``nvcf_version``.
        """
        nim_metadata = NimMetadata()
        nvcf_models = [
            m["name"] for m in nim_metadata.get_nvcf_chat_completion_models()
        ]
        self.models = {}

        for each_model_dict in self._standardize_models(models):
            model_name = each_model_dict.get("name", "")
            nvcf_id = each_model_dict.get("nvcf_id", "")
            nvcf_version = each_model_dict.get("nvcf_version", "")

            if nvcf_id and nvcf_version:
                # An explicit function id/version pair is resolved by
                # NimChatCompletion itself; the name is not validated here.
                pass
            elif not model_name:
                raise ValueError(
                    "You must provide either a valid 'name' or a custom 'name' along with both 'nvcf_id' and 'nvcf_version'."
                )
            elif model_name not in nvcf_models:
                raise ValueError(
                    f"Model {model_name} not supported by the Outerbounds @nim offering."
                    f"\nYou can choose from these options: {nvcf_models}\n\n"
                    "Reach out to Outerbounds if there are other models you'd like supported."
                )

            completion = NimChatCompletion(
                model=model_name,
                nvcf_id=nvcf_id,
                nvcf_version=nvcf_version,
                nim_metadata=nim_metadata,
                monitor=monitor,
            )
            # A spec given only by id/version has no name; fall back to the
            # name the client resolved so entries are not all stored under "".
            self.models[model_name or completion.model_name] = completion

    @staticmethod
    def _standardize_models(models):
        # Normalise the accepted input shapes into a list of spec dicts.
        if isinstance(models, str):
            return [{"name": models}]
        if not isinstance(models, list):
            raise ValueError(
                f"Models must be a string or a list of strings/dictionaries, got {type(models)}"
            )

        standardized_models = []
        for model_item in models:
            if isinstance(model_item, str):
                standardized_models.append({"name": model_item})
            elif isinstance(model_item, dict):
                standardized_models.append(model_item)
            else:
                raise ValueError(
                    f"Model specification must be a string or dictionary, got {type(model_item)}"
                )
        return standardized_models
|
|
100
245
|
|
|
101
246
|
|
|
102
247
|
class NimChatCompletion(object):
|
|
103
248
|
def __init__(
|
|
104
249
|
self,
|
|
105
|
-
model="meta/llama3-8b-instruct",
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
250
|
+
model: str = "meta/llama3-8b-instruct",
|
|
251
|
+
nvcf_id: str = "",
|
|
252
|
+
nvcf_version: str = "",
|
|
253
|
+
nim_metadata: NimMetadata = None,
|
|
254
|
+
monitor: bool = False,
|
|
110
255
|
**kwargs,
|
|
111
256
|
):
|
|
112
257
|
if nim_metadata is None:
|
|
@@ -114,79 +259,70 @@ class NimChatCompletion(object):
|
|
|
114
259
|
"NimMetadata object is required to initialize NimChatCompletion object."
|
|
115
260
|
)
|
|
116
261
|
|
|
117
|
-
self.
|
|
118
|
-
self.
|
|
119
|
-
self.invocations = []
|
|
120
|
-
self.max_request_retries = int(
|
|
121
|
-
os.environ.get("METAFLOW_EXT_HTTP_MAX_RETRIES", "10")
|
|
122
|
-
)
|
|
262
|
+
self.model_name = model
|
|
263
|
+
self.nim_metadata = nim_metadata
|
|
123
264
|
self.monitor = monitor
|
|
265
|
+
all_nvcf_models = self.nim_metadata.get_nvcf_chat_completion_models()
|
|
124
266
|
|
|
125
|
-
if
|
|
126
|
-
|
|
127
|
-
m
|
|
267
|
+
if nvcf_id and nvcf_version:
|
|
268
|
+
matching_models = [
|
|
269
|
+
m
|
|
270
|
+
for m in all_nvcf_models
|
|
271
|
+
if m["function-id"] == nvcf_id and m["version-id"] == nvcf_version
|
|
128
272
|
]
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
273
|
+
if matching_models:
|
|
274
|
+
self.model = matching_models[0]
|
|
275
|
+
self.function_id = self.model["function-id"]
|
|
276
|
+
self.version_id = self.model["version-id"]
|
|
277
|
+
self.model_name = self.model["name"]
|
|
278
|
+
else:
|
|
279
|
+
raise ValueError(
|
|
280
|
+
f"Function {self.function_id} with version {self.version_id} not found on NVCF"
|
|
281
|
+
)
|
|
136
282
|
else:
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
283
|
+
all_nvcf_model_names = [m["name"] for m in all_nvcf_models]
|
|
284
|
+
|
|
285
|
+
if self.model_name not in all_nvcf_model_names:
|
|
286
|
+
raise ValueError(
|
|
287
|
+
f"Model {self.model_name} not found in available NVCF models"
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
self.model = all_nvcf_models[all_nvcf_model_names.index(self.model_name)]
|
|
291
|
+
self.function_id = self.model["function-id"]
|
|
292
|
+
self.version_id = self.model["version-id"]
|
|
140
293
|
|
|
141
|
-
# to know whether to set file_name
|
|
142
294
|
self.first_request = True
|
|
143
295
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def _log_stats(self, response, e2e_time):
|
|
168
|
-
stats = {}
|
|
169
|
-
if response.status_code == 200:
|
|
170
|
-
stats["success"] = 1
|
|
171
|
-
stats["error"] = 0
|
|
296
|
+
def log_stats(self, response_data, duration, status_code, poll_count):
|
|
297
|
+
if not self.monitor:
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
stats = {
|
|
301
|
+
"status_code": status_code,
|
|
302
|
+
"success": 1 if status_code == 200 else 0,
|
|
303
|
+
"error": 0 if status_code == 200 else 1,
|
|
304
|
+
"e2e_time": duration,
|
|
305
|
+
"model": self.model_name,
|
|
306
|
+
"poll_count": poll_count,
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
if status_code == 200 and response_data:
|
|
310
|
+
try:
|
|
311
|
+
stats["prompt_tokens"] = response_data["usage"]["prompt_tokens"]
|
|
312
|
+
except (KeyError, TypeError):
|
|
313
|
+
stats["prompt_tokens"] = None
|
|
314
|
+
|
|
315
|
+
try:
|
|
316
|
+
stats["completion_tokens"] = response_data["usage"]["completion_tokens"]
|
|
317
|
+
except (KeyError, TypeError):
|
|
318
|
+
stats["completion_tokens"] = None
|
|
172
319
|
else:
|
|
173
|
-
stats["success"] = 0
|
|
174
|
-
stats["error"] = 1
|
|
175
|
-
stats["status_code"] = response.status_code
|
|
176
|
-
try:
|
|
177
|
-
stats["prompt_tokens"] = response.json()["usage"]["prompt_tokens"]
|
|
178
|
-
except KeyError:
|
|
179
320
|
stats["prompt_tokens"] = None
|
|
180
|
-
try:
|
|
181
|
-
stats["completion_tokens"] = response.json()["usage"]["completion_tokens"]
|
|
182
|
-
except KeyError:
|
|
183
321
|
stats["completion_tokens"] = None
|
|
184
|
-
stats["e2e_time"] = e2e_time
|
|
185
|
-
stats["provider"] = self.compute_provider
|
|
186
|
-
stats["model"] = self.model
|
|
187
322
|
|
|
188
323
|
conn = sqlite3.connect(self.file_name)
|
|
189
324
|
cursor = conn.cursor()
|
|
325
|
+
|
|
190
326
|
try:
|
|
191
327
|
cursor.execute(
|
|
192
328
|
"""
|
|
@@ -207,112 +343,37 @@ class NimChatCompletion(object):
|
|
|
207
343
|
finally:
|
|
208
344
|
conn.close()
|
|
209
345
|
|
|
210
|
-
@retry_on_status(status_codes=[500], max_retries=3, delay=5)
|
|
211
|
-
@retry_on_status(status_codes=[504])
|
|
212
346
|
def __call__(self, **kwargs):
    """Send a request for this model to NVCF and return the decoded response.

    Keyword arguments are merged into the request body next to the
    ``model`` field. On the first call the monitoring file path is derived
    from the current task id. Errors are reported on stderr and re-raised.
    """
    if self.first_request:
        from metaflow import current

        self.file_name = get_storage_path(current.task_id)
        self.first_request = False

    # Stats are only reported when monitoring is enabled
    stats_hook = None
    if self.monitor:
        stats_hook = self.log_stats

    body = {"model": self.model_name}
    body.update(kwargs)
    endpoint = f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/{self.function_id}"

    try:
        return nvcf_submit_helper(
            url=endpoint,
            payload=body,
            headers=self.nim_metadata.get_headers_for_nvcf_request(),
            log_callback=stats_hook,
        )
    except requests.exceptions.HTTPError as http_exc:
        print(
            f"[@nim ERROR] NVCF API request failed: {str(http_exc)}",
            file=sys.stderr,
        )
        raise
    except Exception as exc:
        print(f"[@nim ERROR] Unexpected error: {str(exc)}", file=sys.stderr)
        raise
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import requests
|
|
4
|
+
from urllib.parse import urlparse
|
|
5
|
+
from metaflow.metaflow_config import SERVICE_URL
|
|
6
|
+
from metaflow.metaflow_config_funcs import init_config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
NIM_MONITOR_LOCAL_STORAGE_ROOT = ".nim-monitor"


def get_storage_path(task_id):
    """Return the path of the monitoring SQLite file for ``task_id``.

    The path is ``<NIM_MONITOR_LOCAL_STORAGE_ROOT>/<task_id>.sqlite``.
    ``task_id`` is formatted with ``str()``, so non-string ids are accepted.
    """
    return f"{NIM_MONITOR_LOCAL_STORAGE_ROOT}/{task_id}.sqlite"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_ngc_response():
    """Fetch the NIM configuration document from the auth server.

    The host comes from ``OBP_AUTH_SERVER`` in the Metaflow config when
    present; otherwise it is derived from ``SERVICE_URL`` by replacing the
    first hostname label with ``auth``. The request is authenticated with
    ``METAFLOW_SERVICE_AUTH_KEY`` from the config, or else with the JSON
    headers stored in the ``METAFLOW_SERVICE_HEADERS`` environment variable.

    Returns:
        The decoded JSON body of the ``/generate/nim`` response.

    Raises:
        ValueError: If neither credential source is available, or the
            headers variable does not hold valid JSON.
        requests.exceptions.HTTPError: If the server returns an error status.
    """
    conf = init_config()
    if "OBP_AUTH_SERVER" in conf:
        auth_host = conf["OBP_AUTH_SERVER"]
    else:
        auth_host = "auth." + urlparse(SERVICE_URL).hostname.split(".", 1)[1]

    # NOTE: reusing the same auth_host as the one used in NimMetadata,
    # however, user should not need to use nim container to use @nvct.
    # May want to refactor this to a common endpoint.
    nim_info_url = "https://" + auth_host + "/generate/nim"

    if "METAFLOW_SERVICE_AUTH_KEY" in conf:
        headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
    else:
        raw_headers = os.environ.get("METAFLOW_SERVICE_HEADERS")
        if raw_headers is None:
            # json.loads(None) would fail with an unhelpful TypeError.
            raise ValueError(
                "Neither METAFLOW_SERVICE_AUTH_KEY nor METAFLOW_SERVICE_HEADERS is set; "
                "cannot authenticate against " + nim_info_url
            )
        headers = json.loads(raw_headers)

    # A timeout keeps an unresponsive auth server from blocking indefinitely.
    res = requests.get(nim_info_url, headers=headers, timeout=30)
    res.raise_for_status()
    return res.json()
|
|
@@ -1,3 +1,3 @@
|
|
|
1
|
-
SUPPORTABLE_GPU_TYPES = ["L40", "L40S", "L40G", "H100"]
|
|
1
|
+
SUPPORTABLE_GPU_TYPES = ["L40", "L40S", "L40G", "H100", "NEBIUS_H100"]
|
|
2
2
|
DEFAULT_GPU_TYPE = "H100"
|
|
3
|
-
MAX_N_GPU_BY_TYPE = {"L40": 1, "L40S": 1, "L40G": 1, "H100": 4}
|
|
3
|
+
MAX_N_GPU_BY_TYPE = {"L40": 1, "L40S": 1, "L40G": 1, "H100": 4, "NEBIUS_H100": 8}
|