ob-metaflow-extensions 1.1.142__py2.py3-none-any.whl → 1.4.33__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. metaflow_extensions/outerbounds/__init__.py +1 -1
  2. metaflow_extensions/outerbounds/plugins/__init__.py +26 -5
  3. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  4. metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
  5. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
  6. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  7. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  33. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  34. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  35. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +78 -0
  36. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  37. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  38. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  39. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  40. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  41. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +17 -3
  42. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
  43. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
  44. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  45. metaflow_extensions/outerbounds/plugins/nim/card.py +1 -6
  46. metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
  47. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
  48. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  49. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
  50. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
  51. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
  52. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  53. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  54. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  55. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  56. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  57. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  58. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  59. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +171 -16
  60. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  61. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  62. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1710 -114
  63. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  64. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  65. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  66. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  67. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  68. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  69. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  70. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  71. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  72. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  73. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  74. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
  75. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +44 -4
  76. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
  77. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
  78. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
  79. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  80. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  81. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  82. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  83. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  84. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  85. metaflow_extensions/outerbounds/remote_config.py +27 -3
  86. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +87 -2
  87. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  88. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  89. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  90. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  91. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  92. {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/METADATA +2 -2
  93. ob_metaflow_extensions-1.4.33.dist-info/RECORD +134 -0
  94. metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
  95. ob_metaflow_extensions-1.1.142.dist-info/RECORD +0 -64
  96. {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/WHEEL +0 -0
  97. {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/top_level.txt +0 -0
@@ -1,47 +1,163 @@
1
- import os, sys, time, json, random, requests, sqlite3
2
- from urllib.parse import urlparse
3
- from metaflow.metaflow_config import SERVICE_URL
4
- from metaflow.metaflow_config_funcs import init_config
5
- from .utilities import get_storage_path
6
- from ..nvcf.nvcf import retry_on_status
1
+ import sys
2
+ import time
3
+ import requests
4
+ import sqlite3
5
+ from urllib3.util.retry import Retry
6
+ from requests.adapters import HTTPAdapter
7
+ from typing import Dict, Optional, Any
8
+ from .utils import get_ngc_response, get_storage_path
9
+
10
+
11
def nvcf_submit_helper(
    url: str,
    payload: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None,
    timeout: int = 30,
    max_retries: int = 300,
    backoff_factor: float = 0.3,
    request_delay: float = 1.1,
    log_callback: Optional[callable] = None,
) -> Dict[str, Any]:
    """Submit a payload to an NVCF function and poll until the result is ready.

    The payload is POSTed to ``url``. While the service answers ``202`` the
    status endpoint is polled (sleeping ``request_delay`` seconds between
    requests) until a ``200`` arrives, whose JSON body is returned.

    Args:
        url: NVCF function invocation URL.
        payload: JSON-serialisable request body.
        headers: HTTP headers; JSON accept/content-type defaults are used
            when empty or ``None``.
        timeout: Per-request timeout in seconds passed to ``requests``.
        max_retries: Total retries configured on the HTTP adapter.
        backoff_factor: Backoff factor configured on the HTTP adapter.
        request_delay: Seconds slept before the first request, after the
            initial POST, and before every poll.
        log_callback: Optional callable invoked as
            ``log_callback(response_data, duration, status_code, poll_count)``.
            It is called exactly once per invocation: with the decoded body
            on success, or with ``{}`` on failure. Exceptions it raises are
            reported as a warning and otherwise ignored.

    Returns:
        The decoded JSON body of the final ``200`` response.

    Raises:
        requests.exceptions.HTTPError: The initial or a poll response carried
            an error status.
        RuntimeError: The request was accepted (``202``) but the response had
            no ``NVCF-REQID`` header, so there is nothing to poll.
        Exception: Polling ended with a status other than ``200``; any other
            error raised by ``requests`` is propagated unchanged.
    """
    error_logged = False

    def _log_error(start_time: float, status_code: int, poll_count: int):
        # Report a failure at most once, even though an error may pass through
        # both an inner check and the outer exception handlers below.
        nonlocal error_logged
        if error_logged or not log_callback:
            return
        error_logged = True
        try:
            log_callback({}, time.time() - start_time, status_code, poll_count)
        except Exception as log_error:
            print(f"Warning: Logging callback failed: {log_error}")

    # use default headers
    if not headers:
        headers = {"accept": "application/json", "content-type": "application/json"}
        print(f"Using Default Headers: {headers}")

    # The context manager guarantees the session (and its pooled connections)
    # is closed on every exit path.
    with _build_retrying_session(max_retries, backoff_factor) as session:
        # Add artificial delay if specified
        time.sleep(request_delay)

        start_time = time.time()
        poll_count = 0
        status_code = 0

        try:
            # Make initial request
            response = session.post(
                url, json=payload, headers=headers, timeout=timeout
            )
            time.sleep(request_delay)

            # Handle initial response
            response.raise_for_status()
            request_id = response.headers.get("NVCF-REQID")
            polling_url = (
                f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/{request_id}"
            )

            print(f"Polling NVCF Request ID: {request_id}")

            status_code = response.status_code
            print(f"Initial response status: {status_code}")

            if status_code == 202 and not request_id:
                # Without a request id the polling URL would end in "None";
                # fail fast instead of retrying a URL that cannot exist.
                _log_error(start_time, status_code, poll_count)
                raise RuntimeError(
                    "NVCF accepted the request (202) but returned no "
                    "NVCF-REQID header to poll."
                )

            final_response = response

            # Continue polling while we get 202 (Accepted/Processing)
            while status_code == 202:
                poll_count += 1
                print(f"Polling attempt #{poll_count} to {polling_url}")

                # Wait before next poll
                time.sleep(request_delay)

                poll_response = session.get(
                    polling_url, headers=headers, timeout=timeout
                )
                status_code = poll_response.status_code
                print(f"Poll #{poll_count} status: {status_code}")

                try:
                    poll_response.raise_for_status()
                except requests.exceptions.HTTPError as e:
                    print(f"Poll request failed: {str(e)}")
                    poll_response.close()
                    # The outer HTTPError handler logs this failure.
                    raise

                # If status is 200, the job is complete
                if status_code == 200:
                    print("Polling complete - job finished successfully")
                    final_response = poll_response
                    break

                # Not the final response: release it before the next round.
                poll_response.close()

            # If we exited the loop without a 200 status, something went wrong
            if status_code != 200:
                print(f"Polling ended with unexpected status: {status_code}")
                _log_error(start_time, status_code, poll_count)
                raise Exception(
                    f"Unexpected status code after polling: {status_code}"
                )

            response_data = final_response.json()
            response_size = len(final_response.content)

        except requests.exceptions.HTTPError as e:
            # A Response is falsy for 4xx/5xx statuses, so it must be compared
            # against None; a plain truthiness test would always report 0.
            status_code = e.response.status_code if e.response is not None else 0
            print(f"HTTP Error: {str(e)}", file=sys.stderr)
            _log_error(start_time, status_code, poll_count)
            raise

        except Exception as e:
            # Handle other errors (connection errors, timeouts, etc.)
            print(f"Request Error: {str(e)}", file=sys.stderr)
            # status_code 0 marks a non-HTTP error (unless already logged)
            _log_error(start_time, 0, poll_count)
            raise

    # Calculate final duration and log successful requests
    duration = time.time() - start_time

    if log_callback:
        try:
            log_callback(response_data, duration, status_code, poll_count)
        except Exception as e:
            print(f"Warning: Logging callback failed: {e}")

    # Log metrics
    print(
        f"Request completed: duration={duration:.2f}s, polls={poll_count}, "
        f"status={status_code}, size={response_size} bytes"
    )

    return response_data


def _build_retrying_session(max_retries: int, backoff_factor: float):
    """Return a ``requests`` session whose adapters retry GET/POST failures."""
    retry_strategy = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=[429, 500, 502, 503, 504, 404],
        allowed_methods=["GET", "POST"],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session = requests.Session()
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session
18
151
 
19
152
 
20
153
  class NimMetadata(object):
21
154
  def __init__(self):
22
155
  self._nvcf_chat_completion_models = []
23
- self._coreweave_chat_completion_models = []
24
-
25
- conf = init_config()
26
-
27
- if "OBP_AUTH_SERVER" in conf:
28
- auth_host = conf["OBP_AUTH_SERVER"]
29
- else:
30
- auth_host = "auth." + urlparse(SERVICE_URL).hostname.split(".", 1)[1]
31
-
32
- nim_info_url = "https://" + auth_host + "/generate/nim"
156
+ ngc_response = get_ngc_response()
33
157
 
34
- if "METAFLOW_SERVICE_AUTH_KEY" in conf:
35
- headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
36
- res = requests.get(nim_info_url, headers=headers)
37
- else:
38
- headers = json.loads(os.environ.get("METAFLOW_SERVICE_HEADERS"))
39
- res = requests.get(nim_info_url, headers=headers)
40
-
41
- res.raise_for_status()
42
- self._ngc_api_key = res.json()["nvcf"]["api_key"]
158
+ self.ngc_api_key = ngc_response["nvcf"]["api_key"]
43
159
 
44
- for model in res.json()["nvcf"]["functions"]:
160
+ for model in ngc_response["nvcf"]["functions"]:
45
161
  self._nvcf_chat_completion_models.append(
46
162
  {
47
163
  "name": model["model_key"],
@@ -49,64 +165,93 @@ class NimMetadata(object):
49
165
  "version-id": model["version"],
50
166
  }
51
167
  )
52
- for model in res.json()["coreweave"]["containers"]:
53
- self._coreweave_chat_completion_models.append(
54
- {"name": model["nim_name"], "ip-address": model["ip_addr"]}
55
- )
56
168
 
57
169
  def get_nvcf_chat_completion_models(self):
58
170
  return self._nvcf_chat_completion_models
59
171
 
60
172
  def get_headers_for_nvcf_request(self):
61
- return {**COMMON_HEADERS, "Authorization": f"Bearer {self._ngc_api_key}"}
173
+ return {
174
+ "accept": "application/json",
175
+ "content-type": "application/json",
176
+ "Authorization": f"Bearer {self.ngc_api_key}",
177
+ "NVCF-POLL-SECONDS": "5",
178
+ }
62
179
 
63
180
 
64
181
  class NimManager(object):
65
- def __init__(self, models, backend, flow, step_name, monitor, queue_timeout):
66
-
182
+ def __init__(self, models, flow, step_name, monitor):
67
183
  nim_metadata = NimMetadata()
68
- if backend == "managed":
69
- nvcf_models = [
70
- m["name"] for m in nim_metadata.get_nvcf_chat_completion_models()
71
- ]
184
+ nvcf_models = [
185
+ m["name"] for m in nim_metadata.get_nvcf_chat_completion_models()
186
+ ]
187
+ self.models = {}
188
+
189
+ # Convert models to a standard format
190
+ standardized_models = []
191
+ # If models is a single string, convert it to a list with a dict
192
+ if isinstance(models, str):
193
+ standardized_models = [{"name": models}]
194
+ # If models is a list, process each item
195
+ elif isinstance(models, list):
196
+ for model_item in models:
197
+ # If the item is a string, convert it to a dict
198
+ if isinstance(model_item, str):
199
+ standardized_models.append({"name": model_item})
200
+ # If it's already a dict, use it as is
201
+ elif isinstance(model_item, dict):
202
+ standardized_models.append(model_item)
203
+ else:
204
+ raise ValueError(
205
+ f"Model specification must be a string or dictionary, got {type(model_item)}"
206
+ )
207
+ else:
208
+ raise ValueError(
209
+ f"Models must be a string or a list of strings/dictionaries, got {type(models)}"
210
+ )
72
211
 
73
- self.models = {}
74
- for m in models:
75
- if m in nvcf_models:
76
- self.models[m] = NimChatCompletion(
77
- model=m,
78
- provider="NVCF",
212
+ # Process each standardized model
213
+ for each_model_dict in standardized_models:
214
+ model_name = each_model_dict.get("name", "")
215
+ nvcf_id = each_model_dict.get("nvcf_id", "")
216
+ nvcf_version = each_model_dict.get("nvcf_version", "")
217
+
218
+ if model_name and not (nvcf_id and nvcf_version):
219
+ if model_name in nvcf_models:
220
+ self.models[model_name] = NimChatCompletion(
221
+ model=model_name,
222
+ nvcf_id=nvcf_id,
223
+ nvcf_version=nvcf_version,
79
224
  nim_metadata=nim_metadata,
80
225
  monitor=monitor,
81
- queue_timeout=queue_timeout,
82
226
  )
83
227
  else:
84
228
  raise ValueError(
85
- f"Model {m} not supported by the Outerbounds @nim offering."
229
+ f"Model {model_name} not supported by the Outerbounds @nim offering."
86
230
  f"\nYou can choose from these options: {nvcf_models}\n\n"
87
231
  "Reach out to Outerbounds if there are other models you'd like supported."
88
232
  )
89
- else:
90
- raise ValueError(
91
- f"Backend {backend} not supported by the Outerbounds @nim offering. Please reach out to Outerbounds."
92
- )
93
-
94
-
95
- class JobStatus(object):
96
- SUBMITTED = "SUBMITTED"
97
- RUNNING = "RUNNING"
98
- SUCCESSFUL = "SUCCESSFUL"
99
- FAILED = "FAILED"
233
+ elif nvcf_id and nvcf_version:
234
+ self.models[model_name] = NimChatCompletion(
235
+ model=model_name,
236
+ nvcf_id=nvcf_id,
237
+ nvcf_version=nvcf_version,
238
+ nim_metadata=nim_metadata,
239
+ monitor=monitor,
240
+ )
241
+ else:
242
+ raise ValueError(
243
+ "You must provide either a valid 'name' or a custom 'name' along with both 'nvcf_id' and 'nvcf_version'."
244
+ )
100
245
 
101
246
 
102
247
  class NimChatCompletion(object):
103
248
  def __init__(
104
249
  self,
105
- model="meta/llama3-8b-instruct",
106
- provider="NVCF",
107
- nim_metadata=None,
108
- monitor=False,
109
- queue_timeout=None,
250
+ model: str = "meta/llama3-8b-instruct",
251
+ nvcf_id: str = "",
252
+ nvcf_version: str = "",
253
+ nim_metadata: NimMetadata = None,
254
+ monitor: bool = False,
110
255
  **kwargs,
111
256
  ):
112
257
  if nim_metadata is None:
@@ -114,79 +259,70 @@ class NimChatCompletion(object):
114
259
  "NimMetadata object is required to initialize NimChatCompletion object."
115
260
  )
116
261
 
117
- self._nim_metadata = nim_metadata
118
- self.compute_provider = provider
119
- self.invocations = []
120
- self.max_request_retries = int(
121
- os.environ.get("METAFLOW_EXT_HTTP_MAX_RETRIES", "10")
122
- )
262
+ self.model_name = model
263
+ self.nim_metadata = nim_metadata
123
264
  self.monitor = monitor
265
+ all_nvcf_models = self.nim_metadata.get_nvcf_chat_completion_models()
124
266
 
125
- if self.compute_provider == "NVCF":
126
- nvcf_model_names = [
127
- m["name"] for m in self._nim_metadata.get_nvcf_chat_completion_models()
267
+ if nvcf_id and nvcf_version:
268
+ matching_models = [
269
+ m
270
+ for m in all_nvcf_models
271
+ if m["function-id"] == nvcf_id and m["version-id"] == nvcf_version
128
272
  ]
129
- self.model = model
130
- self.function_id = self._nim_metadata.get_nvcf_chat_completion_models()[
131
- nvcf_model_names.index(model)
132
- ]["function-id"]
133
- self.version_id = self._nim_metadata.get_nvcf_chat_completion_models()[
134
- nvcf_model_names.index(model)
135
- ]["version-id"]
273
+ if matching_models:
274
+ self.model = matching_models[0]
275
+ self.function_id = self.model["function-id"]
276
+ self.version_id = self.model["version-id"]
277
+ self.model_name = self.model["name"]
278
+ else:
279
+ raise ValueError(
280
+ f"Function {self.function_id} with version {self.version_id} not found on NVCF"
281
+ )
136
282
  else:
137
- raise ValueError(
138
- f"Backend compute provider {self.compute_provider} not yet supported for @nim."
139
- )
283
+ all_nvcf_model_names = [m["name"] for m in all_nvcf_models]
284
+
285
+ if self.model_name not in all_nvcf_model_names:
286
+ raise ValueError(
287
+ f"Model {self.model_name} not found in available NVCF models"
288
+ )
289
+
290
+ self.model = all_nvcf_models[all_nvcf_model_names.index(self.model_name)]
291
+ self.function_id = self.model["function-id"]
292
+ self.version_id = self.model["version-id"]
140
293
 
141
- # to know whether to set file_name
142
294
  self.first_request = True
143
295
 
144
- # TODO (Eddie) - this may make more sense in a base class.
145
- # @nim arch needs redesign if customers start using it in more creative ways.
146
- self._poll_seconds = "3600"
147
- self._queue_timeout = queue_timeout
148
- self._status = None
149
- self._result = {}
150
-
151
- @property
152
- def status(self):
153
- return self._status
154
-
155
- @property
156
- def has_failed(self):
157
- return self._status == JobStatus.FAILED
158
-
159
- @property
160
- def is_running(self):
161
- return self._status == JobStatus.SUBMITTED
162
-
163
- @property
164
- def result(self):
165
- return self._result
166
-
167
- def _log_stats(self, response, e2e_time):
168
- stats = {}
169
- if response.status_code == 200:
170
- stats["success"] = 1
171
- stats["error"] = 0
296
+ def log_stats(self, response_data, duration, status_code, poll_count):
297
+ if not self.monitor:
298
+ return
299
+
300
+ stats = {
301
+ "status_code": status_code,
302
+ "success": 1 if status_code == 200 else 0,
303
+ "error": 0 if status_code == 200 else 1,
304
+ "e2e_time": duration,
305
+ "model": self.model_name,
306
+ "poll_count": poll_count,
307
+ }
308
+
309
+ if status_code == 200 and response_data:
310
+ try:
311
+ stats["prompt_tokens"] = response_data["usage"]["prompt_tokens"]
312
+ except (KeyError, TypeError):
313
+ stats["prompt_tokens"] = None
314
+
315
+ try:
316
+ stats["completion_tokens"] = response_data["usage"]["completion_tokens"]
317
+ except (KeyError, TypeError):
318
+ stats["completion_tokens"] = None
172
319
  else:
173
- stats["success"] = 0
174
- stats["error"] = 1
175
- stats["status_code"] = response.status_code
176
- try:
177
- stats["prompt_tokens"] = response.json()["usage"]["prompt_tokens"]
178
- except KeyError:
179
320
  stats["prompt_tokens"] = None
180
- try:
181
- stats["completion_tokens"] = response.json()["usage"]["completion_tokens"]
182
- except KeyError:
183
321
  stats["completion_tokens"] = None
184
- stats["e2e_time"] = e2e_time
185
- stats["provider"] = self.compute_provider
186
- stats["model"] = self.model
187
322
 
188
323
  conn = sqlite3.connect(self.file_name)
189
324
  cursor = conn.cursor()
325
+
190
326
  try:
191
327
  cursor.execute(
192
328
  """
@@ -207,112 +343,37 @@ class NimChatCompletion(object):
207
343
  finally:
208
344
  conn.close()
209
345
 
210
- @retry_on_status(status_codes=[500], max_retries=3, delay=5)
211
- @retry_on_status(status_codes=[504])
212
346
  def __call__(self, **kwargs):
213
-
214
347
  if self.first_request:
215
- # Put here to guarantee self.file_name is set after task_id exists.
216
348
  from metaflow import current
217
349
 
218
350
  self.file_name = get_storage_path(current.task_id)
351
+ self.first_request = False
219
352
 
220
- request_data = {"model": self.model, **kwargs}
221
- request_url = f"{NVCF_SUBMIT_ENDPOINT}/{self.function_id}"
222
- retry_delay = 1
223
- attempts = 0
224
- t0 = time.time()
225
- while attempts < self.max_request_retries:
226
- try:
227
- attempts += 1
228
- response = requests.post(
229
- request_url,
230
- headers=self._nim_metadata.get_headers_for_nvcf_request(),
231
- json=request_data,
232
- )
233
- if response.status_code == 202:
234
- invocation_id = response.headers.get("NVCF-REQID")
235
- self.invocations.append(invocation_id)
236
- self._status = JobStatus.SUBMITTED
237
- elif response.status_code == 200:
238
- tf = time.time()
239
- if self.monitor:
240
- self._log_stats(response, tf - t0)
241
- self._status = JobStatus.SUCCESSFUL
242
- self._result = response.json()
243
- return self._result
244
- elif response.status_code == 400:
245
- self._status = JobStatus.FAILED
246
- msg = (
247
- "[@nim ERROR] The OpenAI-compatible returned a 400 status code. "
248
- + "Known causes include improper requests or prompts with too many tokens for the selected model. "
249
- + "Please contact Outerbounds if you need assistance resolving the issue."
250
- )
251
- print(msg, file=sys.stderr)
252
- self._result = {"ERROR": msg}
253
- return self._result
254
- except (
255
- requests.exceptions.ConnectionError,
256
- requests.exceptions.ReadTimeout,
257
- ) as e:
258
- # ConnectionErrors are generally temporary errors like DNS resolution failures,
259
- # timeouts etc.
260
- print(
261
- "received error of type {}. Retrying...".format(type(e)),
262
- e,
263
- file=sys.stderr,
264
- )
265
- time.sleep(retry_delay)
266
- retry_delay *= 2 # Double the delay for the next attempt
267
- retry_delay += random.uniform(0, 1) # Add jitter
268
- retry_delay = min(retry_delay, 10)
269
-
270
- def _poll():
271
- poll_request_url = f"{NVCF_RESULT_ENDPOINT}/{invocation_id}"
272
- attempts = 0
273
- retry_delay = 1
274
- while attempts < self.max_request_retries:
275
- try:
276
- attempts += 1
277
- poll_response = requests.get(
278
- poll_request_url,
279
- headers=self._nim_metadata.get_headers_for_nvcf_request(),
280
- )
281
- if poll_response.status_code == 200:
282
- tf = time.time()
283
- self._log_stats(response, tf - t0)
284
- self._status = JobStatus.SUCCESSFUL
285
- self._result = poll_response.json()
286
- return self._result
287
- elif poll_response.status_code == 202:
288
- self._status = JobStatus.SUBMITTED
289
- return 202
290
- elif poll_response.status_code == 400:
291
- self._status = JobStatus.FAILED
292
- msg = (
293
- "[@nim ERROR] The OpenAI-compatible API returned a 400 status code. "
294
- + "Known causes include improper requests or prompts with too many tokens for the selected model. "
295
- + "Please contact Outerbounds if you need assistance resolving the issue."
296
- )
297
- print(msg, file=sys.stderr)
298
- self._result = {"@nim ERROR": msg}
299
- return self._result
300
- except (
301
- requests.exceptions.ConnectionError,
302
- requests.exceptions.ReadTimeout,
303
- ) as e:
304
- print(
305
- "received error of type {}. Retrying...".format(type(e)),
306
- e,
307
- file=sys.stderr,
308
- )
309
- time.sleep(retry_delay)
310
- retry_delay *= 2 # Double the delay for the next attempt
311
- retry_delay += random.uniform(0, 1) # Add jitter
312
- retry_delay = min(retry_delay, 10)
313
-
314
- while True:
315
- data = _poll()
316
- if data and data != 202:
317
- return data
318
- time.sleep(NVCF_POLL_INTERVAL_SECONDS)
353
+ # Create log callback if monitoring is enabled
354
+ log_callback = self.log_stats if self.monitor else None
355
+
356
+ request_data = {"model": self.model_name, **kwargs}
357
+ request_url = (
358
+ f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/{self.function_id}"
359
+ )
360
+
361
+ try:
362
+ response_data = nvcf_submit_helper(
363
+ url=request_url,
364
+ payload=request_data,
365
+ headers=self.nim_metadata.get_headers_for_nvcf_request(),
366
+ log_callback=log_callback,
367
+ )
368
+
369
+ return response_data
370
+
371
+ except requests.exceptions.HTTPError as e:
372
+ error_msg = f"[@nim ERROR] NVCF API request failed: {str(e)}"
373
+ print(error_msg, file=sys.stderr)
374
+ raise
375
+
376
+ except Exception as e:
377
+ error_msg = f"[@nim ERROR] Unexpected error: {str(e)}"
378
+ print(error_msg, file=sys.stderr)
379
+ raise
@@ -0,0 +1,36 @@
1
+ import os
2
+ import json
3
+ import requests
4
+ from urllib.parse import urlparse
5
+ from metaflow.metaflow_config import SERVICE_URL
6
+ from metaflow.metaflow_config_funcs import init_config
7
+
8
+
9
# Directory (relative to the working directory) holding per-task monitor databases.
NIM_MONITOR_LOCAL_STORAGE_ROOT = ".nim-monitor"


def get_storage_path(task_id):
    """Return the SQLite file path used to store monitoring data for ``task_id``.

    ``task_id`` must be a string; it becomes the file stem inside
    ``NIM_MONITOR_LOCAL_STORAGE_ROOT``.
    """
    database_file = task_id + ".sqlite"
    return NIM_MONITOR_LOCAL_STORAGE_ROOT + "/" + database_file
14
+
15
+
16
def get_ngc_response(timeout=30):
    """Fetch the NIM/NGC information document from the auth server.

    The auth host is ``OBP_AUTH_SERVER`` from the Metaflow configuration when
    present; otherwise it is derived from ``SERVICE_URL`` by replacing the
    first hostname label with ``auth``. The request is authenticated with
    ``METAFLOW_SERVICE_AUTH_KEY`` from the configuration when available, and
    otherwise with the JSON-encoded headers in the ``METAFLOW_SERVICE_HEADERS``
    environment variable.

    Args:
        timeout: Seconds to wait for the auth server before giving up.

    Returns:
        The decoded JSON body of the ``/generate/nim`` response.

    Raises:
        ValueError: No auth key is configured and ``METAFLOW_SERVICE_HEADERS``
            is unset (or does not contain valid JSON).
        requests.exceptions.RequestException: The request failed, timed out,
            or returned an error status.
    """
    conf = init_config()
    if "OBP_AUTH_SERVER" in conf:
        auth_host = conf["OBP_AUTH_SERVER"]
    else:
        auth_host = "auth." + urlparse(SERVICE_URL).hostname.split(".", 1)[1]

    # NOTE: reusing the same auth_host as the one used in NimMetadata,
    # however, user should not need to use nim container to use @nvct.
    # May want to refactor this to a common endpoint.
    nim_info_url = "https://" + auth_host + "/generate/nim"

    if "METAFLOW_SERVICE_AUTH_KEY" in conf:
        headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
    else:
        raw_headers = os.environ.get("METAFLOW_SERVICE_HEADERS")
        if raw_headers is None:
            # json.loads(None) would fail with an opaque TypeError; name the
            # missing configuration instead.
            raise ValueError(
                "Cannot authenticate with the auth server: neither "
                "METAFLOW_SERVICE_AUTH_KEY nor METAFLOW_SERVICE_HEADERS is set."
            )
        headers = json.loads(raw_headers)

    # A timeout keeps an unresponsive auth server from hanging the caller.
    res = requests.get(nim_info_url, headers=headers, timeout=timeout)
    res.raise_for_status()
    return res.json()
@@ -1,3 +1,3 @@
1
- SUPPORTABLE_GPU_TYPES = ["L40", "L40S", "L40G", "H100"]
1
# GPU type names listed as supportable.
SUPPORTABLE_GPU_TYPES = [
    "L40",
    "L40S",
    "L40G",
    "H100",
    "NEBIUS_H100",
]

# GPU type used as the default.
DEFAULT_GPU_TYPE = "H100"

# Maximum GPU count, keyed by GPU type.
MAX_N_GPU_BY_TYPE = {
    "L40": 1,
    "L40S": 1,
    "L40G": 1,
    "H100": 4,
    "NEBIUS_H100": 8,
}