ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (128) hide show
  1. metaflow_extensions/outerbounds/__init__.py +1 -7
  2. metaflow_extensions/outerbounds/config/__init__.py +35 -0
  3. metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
  4. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  6. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  7. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  35. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  36. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  37. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  38. metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
  39. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  40. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  41. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  42. metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
  43. metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
  44. metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
  45. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
  46. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  47. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  48. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  49. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  50. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  51. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  52. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
  53. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
  54. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  55. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  56. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
  57. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  58. metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
  59. metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
  60. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
  61. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  62. metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  63. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
  64. metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
  65. metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
  66. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
  67. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
  68. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
  69. metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
  70. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  71. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  72. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  73. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  74. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  75. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  76. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  77. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  78. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  79. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  80. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  81. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  82. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  83. metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
  84. metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
  85. metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
  86. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  87. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  88. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  89. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  90. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  91. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  92. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  93. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  94. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  95. metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
  96. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
  97. metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
  98. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
  99. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  100. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
  101. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
  102. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
  103. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
  104. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  105. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
  106. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  107. metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
  108. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  109. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  110. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  111. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  112. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  113. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  114. metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
  115. metaflow_extensions/outerbounds/remote_config.py +53 -16
  116. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
  117. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  118. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  119. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  120. metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
  121. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  122. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  123. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  124. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  125. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  126. ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
  127. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  128. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,101 @@
1
+ import os
2
+ import time
3
+ from metaflow import current
4
+ from .utils import get_storage_path, NIM_MONITOR_LOCAL_STORAGE_ROOT
5
+ from .nim_manager import NimManager
6
+ from metaflow.decorators import StepDecorator
7
+ from .card import NimMetricsRefresher
8
+ from ..card_utilities.injector import CardDecoratorInjector
9
+ from ..card_utilities.async_cards import AsyncPeriodicRefresher
10
+
11
+
12
class NimDecorator(StepDecorator, CardDecoratorInjector):
    """
    Step decorator that publishes a `NimManager` as `current.nim`.

    Attributes
    ----------
    models : list, default []
        Model specifications forwarded unchanged to `NimManager`.
    monitor : bool, default True
        When set, a card is attached to the step and a task-local SQLite
        file with a `metrics` table is created; its path is handed to the
        `AsyncPeriodicRefresher` that runs alongside the step function.
    persist_db : bool, default False
        When unset, the task's SQLite file is removed after the step.
    """

    name = "nim"

    defaults = {
        "models": [],
        "monitor": True,
        "persist_db": False,
    }

    # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
    # to understand where these functions are invoked in the lifecycle of a
    # Metaflow flow.
    def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
        """Attach the metrics card (if monitoring) and register `current.nim`."""
        if self.attributes["monitor"]:
            self.attach_card_decorator(
                flow,
                step,
                NimMetricsRefresher.CARD_ID,
                "blank",
                refresh_interval=4.0,
            )

        current._update_env(
            {
                "nim": NimManager(
                    models=self.attributes["models"],
                    flow=flow,
                    step_name=step,
                    monitor=self.attributes["monitor"],
                )
            }
        )

    def task_decorate(
        self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
    ):
        """
        Wrap the step function so metrics are refreshed while it runs.

        Returns `step_func` untouched when monitoring is disabled.
        """
        if not self.attributes["monitor"]:
            return step_func

        import sqlite3

        file_path = get_storage_path(current.task_id)
        # Drop any database already present at this path before recreating it.
        if os.path.exists(file_path):
            os.remove(file_path)
        os.makedirs(NIM_MONITOR_LOCAL_STORAGE_ROOT, exist_ok=True)

        # The connection is only needed to create the schema; close it
        # explicitly instead of leaving the handle to the garbage collector.
        conn = sqlite3.connect(file_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE metrics (
                    error INTEGER,
                    success INTEGER,
                    status_code INTEGER,
                    prompt_tokens INTEGER,
                    completion_tokens INTEGER,
                    e2e_time NUMERIC,
                    model TEXT
                )
                """
            )
            conn.commit()
        finally:
            conn.close()

        def _wrapped_step_func(*args, **kwargs):
            async_refresher_metrics = AsyncPeriodicRefresher(
                NimMetricsRefresher(),
                updater_interval=4.0,
                collector_interval=2.0,
                file_name=file_path,
            )
            try:
                async_refresher_metrics.start()
                return step_func(*args, **kwargs)
            finally:
                time.sleep(5.0)  # buffer for the last update to synchronize
                async_refresher_metrics.stop()

        return _wrapped_step_func

    def task_post_step(
        self, step_name, flow, graph, retry_count, max_user_code_retries
    ):
        """Remove this task's metrics database unless `persist_db` is set."""
        if self.attributes["persist_db"]:
            return

        import shutil

        file_path = get_storage_path(current.task_id)
        if os.path.exists(file_path):
            os.remove(file_path)

        # The storage root only exists when some task ran with monitor=True
        # (and it may already have been removed), so listing it
        # unconditionally would raise FileNotFoundError.
        try:
            leftovers = os.listdir(NIM_MONITOR_LOCAL_STORAGE_ROOT)
        except FileNotFoundError:
            return
        # if this task is the last one, delete the whole enchilada.
        if not leftovers:
            shutil.rmtree(NIM_MONITOR_LOCAL_STORAGE_ROOT, ignore_errors=True)
@@ -0,0 +1,379 @@
1
+ import sys
2
+ import time
3
+ import requests
4
+ import sqlite3
5
+ from urllib3.util.retry import Retry
6
+ from requests.adapters import HTTPAdapter
7
+ from typing import Dict, Optional, Any
8
+ from .utils import get_ngc_response, get_storage_path
9
+
10
+
11
def nvcf_submit_helper(
    url: str,
    payload: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None,
    timeout: int = 30,
    max_retries: int = 300,
    backoff_factor: float = 0.3,
    request_delay: float = 1.1,
    log_callback: Optional[callable] = None,
) -> Dict[str, Any]:
    """
    POST `payload` to an NVCF endpoint and poll until the job completes.

    While the service answers 202, the status URL built from the
    `NVCF-REQID` response header is polled every `request_delay` seconds
    until a 200 arrives.

    Parameters
    ----------
    url : str
        Endpoint that receives the initial POST.
    payload : Dict[str, Any]
        JSON body of the POST.
    headers : Dict[str, str], optional
        Request headers; JSON accept/content-type headers are used when empty.
    timeout : int
        Per-request timeout passed to `requests`.
    max_retries, backoff_factor
        Passed to the urllib3 `Retry` strategy mounted on the session.
    request_delay : float
        Seconds slept before the POST, after it, and before every poll.
    log_callback : callable, optional
        Called as `log_callback(response_data, duration, status_code,
        poll_count)`; on failure `response_data` is `{}`. Exceptions raised
        by the callback are printed and ignored.

    Returns
    -------
    Dict[str, Any]
        Decoded JSON body of the final (200) response.

    Raises
    ------
    requests.exceptions.HTTPError
        If the POST or a poll returns an error status.
    Exception
        If polling ends with a non-200 status, a 202 arrives without an
        `NVCF-REQID` header, or any other error occurs (re-raised as is).
    """
    error_reported = False

    def _log_error(start_time: float, status_code: int, poll_count: int):
        # Report a failure at most once per call. The inner handlers and the
        # outer `except` clauses both see the same exception, and reporting
        # it twice would double-count the failure in the callback.
        nonlocal error_reported
        if not log_callback or error_reported:
            return
        error_reported = True
        end_time = time.time()
        try:
            log_callback({}, end_time - start_time, status_code, poll_count)
        except Exception as log_error:
            print(f"Warning: Logging callback failed: {log_error}")

    # use default headers
    if not headers:
        headers = {"accept": "application/json", "content-type": "application/json"}
        print(f"Using Default Headers: {headers}")

    # Configure session with retry strategy
    session = requests.Session()
    status_forcelist = [429, 500, 502, 503, 504, 404]
    retry_strategy = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
        allowed_methods=["GET", "POST"],
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session.mount("http://", adapter)
    session.mount("https://", adapter)

    # Add artificial delay if specified
    time.sleep(request_delay)

    start_time = time.time()
    poll_count = 0
    status_code = 0
    response_data = {}

    try:
        # Make initial request
        response = session.post(url, json=payload, headers=headers, timeout=timeout)
        time.sleep(request_delay)

        # Handle initial response
        response.raise_for_status()
        request_id = response.headers.get("NVCF-REQID")
        polling_url = f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/{request_id}"

        print(f"Polling NVCF Request ID: {request_id}")

        # Initial response status
        status_code = response.status_code
        print(f"Initial response status: {status_code}")

        # Create a variable to store the final response
        final_response = response

        # Without a request id the polling URL would end in "None"; fail
        # fast instead of polling (and retrying) a bogus status URL.
        if status_code == 202 and not request_id:
            _log_error(start_time, status_code, poll_count)
            raise Exception(
                "NVCF accepted the request (202) but returned no NVCF-REQID header to poll."
            )

        # Continue polling while we get 202 (Accepted/Processing)
        while status_code == 202:
            poll_count += 1
            print(f"Polling attempt #{poll_count} to {polling_url}")

            # Wait before next poll
            time.sleep(request_delay)

            # Make a new poll request
            poll_response = session.get(polling_url, headers=headers, timeout=timeout)
            status_code = poll_response.status_code
            print(f"Poll #{poll_count} status: {status_code}")

            # Check for errors
            try:
                poll_response.raise_for_status()
            except requests.exceptions.HTTPError as e:
                print(f"Poll request failed: {str(e)}")
                poll_response.close()
                # Log the error before re-raising
                _log_error(start_time, poll_response.status_code, poll_count)
                raise

            # If status is 200, the job is complete
            if status_code == 200:
                print("Polling complete - job finished successfully")
                # Update our final response to be this poll response
                final_response = poll_response
                break

            # Close this poll response if we're going to loop again
            if status_code == 202:
                poll_response.close()

        # If we exited the loop without a 200 status, something went wrong
        if status_code != 200:
            print(f"Polling ended with unexpected status: {status_code}")
            # Log the error before raising
            _log_error(start_time, status_code, poll_count)
            raise Exception(f"Unexpected status code after polling: {status_code}")

        # Get the response data for logging
        response_data = final_response.json()

    except requests.exceptions.HTTPError as e:
        # Handle HTTP errors (4xx, 5xx status codes).
        # A Response is falsy for 4xx/5xx, so test against None explicitly;
        # a truthiness check would always report status 0 here.
        status_code = e.response.status_code if e.response is not None else 0
        print(f"HTTP Error: {str(e)}", file=sys.stderr)
        # Log the error
        _log_error(start_time, status_code, poll_count)
        raise

    except Exception as e:
        # Handle other errors (connection errors, timeouts, etc.)
        print(f"Request Error: {str(e)}", file=sys.stderr)
        # Log the error with status_code 0 to indicate non-HTTP error
        _log_error(start_time, 0, poll_count)
        raise

    finally:
        # Response bodies are fully read by this point, so the pooled
        # connections can be released.
        session.close()

    # Calculate final duration and log successful requests
    end_time = time.time()
    duration = end_time - start_time

    # Call the logging callback if provided
    if log_callback:
        try:
            log_callback(response_data, duration, status_code, poll_count)
        except Exception as e:
            print(f"Warning: Logging callback failed: {e}")

    # Log metrics
    print(
        f"Request completed: duration={duration:.2f}s, polls={poll_count}, "
        f"status={status_code}, size={len(final_response.content)} bytes"
    )

    return response_data
151
+
152
+
153
class NimMetadata(object):
    """
    Holds the NGC API key and the NVCF function catalogue.

    Both are read from the `"nvcf"` section of the document returned by
    `get_ngc_response()` when the object is created.
    """

    def __init__(self):
        nvcf_section = get_ngc_response()["nvcf"]

        self.ngc_api_key = nvcf_section["api_key"]
        # One entry per NVCF function: display name plus function/version ids.
        self._nvcf_chat_completion_models = [
            {
                "name": function["model_key"],
                "function-id": function["id"],
                "version-id": function["version"],
            }
            for function in nvcf_section["functions"]
        ]

    def get_nvcf_chat_completion_models(self):
        """Return the list of model dicts built at construction time."""
        return self._nvcf_chat_completion_models

    def get_headers_for_nvcf_request(self):
        """Return the HTTP headers (including the bearer token) for NVCF calls."""
        bearer = f"Bearer {self.ngc_api_key}"
        return {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": bearer,
            "NVCF-POLL-SECONDS": "5",
        }
179
+
180
+
181
class NimManager(object):
    """
    Builds one `NimChatCompletion` per requested model, stored in `self.models`.

    Parameters
    ----------
    models : str or list
        A model name, or a list whose items are model names or dicts with
        the keys `name`, `nvcf_id` and `nvcf_version`.
    flow, step_name
        Accepted for interface compatibility; not used by this class.
    monitor : bool
        Forwarded to every `NimChatCompletion`.

    Raises
    ------
    ValueError
        If `models` or one of its items has an unsupported type, if a named
        model is not in the NVCF catalogue, or if an item has neither a name
        nor both `nvcf_id` and `nvcf_version`.
    """

    def __init__(self, models, flow, step_name, monitor):
        nim_metadata = NimMetadata()
        nvcf_models = [
            m["name"] for m in nim_metadata.get_nvcf_chat_completion_models()
        ]
        self.models = {}

        for each_model_dict in self._standardize_models(models):
            model_name = each_model_dict.get("name", "")
            nvcf_id = each_model_dict.get("nvcf_id", "")
            nvcf_version = each_model_dict.get("nvcf_version", "")
            has_function_ref = bool(nvcf_id and nvcf_version)

            if not model_name and not has_function_ref:
                raise ValueError(
                    "You must provide either a valid 'name' or a custom 'name' along with both 'nvcf_id' and 'nvcf_version'."
                )
            if not has_function_ref and model_name not in nvcf_models:
                raise ValueError(
                    f"Model {model_name} not supported by the Outerbounds @nim offering."
                    f"\nYou can choose from these options: {nvcf_models}\n\n"
                    "Reach out to Outerbounds if there are other models you'd like supported."
                )

            completion = NimChatCompletion(
                model=model_name,
                nvcf_id=nvcf_id,
                nvcf_version=nvcf_version,
                nim_metadata=nim_metadata,
                monitor=monitor,
            )
            # A model given only by nvcf_id/nvcf_version has no name of its
            # own; register it under the name resolved from the catalogue
            # rather than under the empty string.
            self.models[model_name or completion.model_name] = completion

    @staticmethod
    def _standardize_models(models):
        """Normalize `models` into a list of model-specification dicts."""
        # If models is a single string, convert it to a list with a dict
        if isinstance(models, str):
            return [{"name": models}]
        if not isinstance(models, list):
            raise ValueError(
                f"Models must be a string or a list of strings/dictionaries, got {type(models)}"
            )

        standardized_models = []
        for model_item in models:
            if isinstance(model_item, str):
                standardized_models.append({"name": model_item})
            elif isinstance(model_item, dict):
                standardized_models.append(model_item)
            else:
                raise ValueError(
                    f"Model specification must be a string or dictionary, got {type(model_item)}"
                )
        return standardized_models
245
+
246
+
247
class NimChatCompletion(object):
    """
    Callable client for a single NVCF chat-completion function.

    Parameters
    ----------
    model : str
        Model name; used for lookup when no function reference is given.
    nvcf_id, nvcf_version : str
        When both are set, the function is looked up by id and version and
        `model_name` is replaced by the catalogue name of the match.
    nim_metadata : NimMetadata
        Source of the function catalogue and request headers. Required.
    monitor : bool
        When set, every request is recorded through `log_stats`.

    Raises
    ------
    ValueError
        If `nim_metadata` is missing or the requested model/function is not
        present in the catalogue.
    """

    def __init__(
        self,
        model: str = "meta/llama3-8b-instruct",
        nvcf_id: str = "",
        nvcf_version: str = "",
        nim_metadata: Optional["NimMetadata"] = None,
        monitor: bool = False,
        **kwargs,
    ):
        if nim_metadata is None:
            raise ValueError(
                "NimMetadata object is required to initialize NimChatCompletion object."
            )

        self.model_name = model
        self.nim_metadata = nim_metadata
        self.monitor = monitor
        all_nvcf_models = self.nim_metadata.get_nvcf_chat_completion_models()

        if nvcf_id and nvcf_version:
            matching_models = [
                m
                for m in all_nvcf_models
                if m["function-id"] == nvcf_id and m["version-id"] == nvcf_version
            ]
            if not matching_models:
                # Format the requested ids: self.function_id/self.version_id
                # are not assigned yet on this path.
                raise ValueError(
                    f"Function {nvcf_id} with version {nvcf_version} not found on NVCF"
                )
            self.model = matching_models[0]
            self.model_name = self.model["name"]
        else:
            all_nvcf_model_names = [m["name"] for m in all_nvcf_models]

            if self.model_name not in all_nvcf_model_names:
                raise ValueError(
                    f"Model {self.model_name} not found in available NVCF models"
                )

            self.model = all_nvcf_models[all_nvcf_model_names.index(self.model_name)]

        self.function_id = self.model["function-id"]
        self.version_id = self.model["version-id"]

        self.first_request = True

    def log_stats(self, response_data, duration, status_code, poll_count):
        """
        Insert one row describing a finished request into the `metrics` table.

        Does nothing when monitoring is disabled. Token counts are taken
        from `response_data["usage"]` on a 200 response and stored as NULL
        when unavailable.
        """
        if not self.monitor:
            return

        succeeded = status_code == 200
        stats = {
            "status_code": status_code,
            "success": 1 if succeeded else 0,
            "error": 0 if succeeded else 1,
            "e2e_time": duration,
            "model": self.model_name,
            "poll_count": poll_count,
        }

        for token_field in ("prompt_tokens", "completion_tokens"):
            stats[token_field] = None
            if succeeded and response_data:
                try:
                    stats[token_field] = response_data["usage"][token_field]
                except (KeyError, TypeError):
                    pass

        conn = sqlite3.connect(self.file_name)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO metrics (error, success, status_code, prompt_tokens, completion_tokens, e2e_time, model)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    stats["error"],
                    stats["success"],
                    stats["status_code"],
                    stats["prompt_tokens"],
                    stats["completion_tokens"],
                    stats["e2e_time"],
                    stats["model"],
                ),
            )
            conn.commit()
        finally:
            conn.close()

    def __call__(self, **kwargs):
        """
        Send a chat-completion request and return the decoded JSON response.

        Keyword arguments are merged into the request body next to `model`.
        Errors are printed to stderr and re-raised unchanged.
        """
        if self.first_request:
            from metaflow import current

            # Resolved lazily: the task id is read on the first request.
            self.file_name = get_storage_path(current.task_id)
            self.first_request = False

        # Create log callback if monitoring is enabled
        log_callback = self.log_stats if self.monitor else None

        request_data = {"model": self.model_name, **kwargs}
        request_url = (
            f"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/{self.function_id}"
        )

        try:
            response_data = nvcf_submit_helper(
                url=request_url,
                payload=request_data,
                headers=self.nim_metadata.get_headers_for_nvcf_request(),
                log_callback=log_callback,
            )

            return response_data

        except requests.exceptions.HTTPError as e:
            error_msg = f"[@nim ERROR] NVCF API request failed: {str(e)}"
            print(error_msg, file=sys.stderr)
            raise

        except Exception as e:
            error_msg = f"[@nim ERROR] Unexpected error: {str(e)}"
            print(error_msg, file=sys.stderr)
            raise
@@ -0,0 +1,36 @@
1
+ import os
2
+ import json
3
+ import requests
4
+ from urllib.parse import urlparse
5
+ from metaflow.metaflow_config import SERVICE_URL
6
+ from metaflow.metaflow_config_funcs import init_config
7
+
8
+
9
# Directory (relative to the working directory) holding per-task metrics files.
NIM_MONITOR_LOCAL_STORAGE_ROOT = ".nim-monitor"


def get_storage_path(task_id):
    """
    Return the path of the SQLite file for `task_id` under the storage root.

    A single f-string is used so that non-string task ids (e.g. ints) are
    formatted instead of raising TypeError on string concatenation.
    """
    return f"{NIM_MONITOR_LOCAL_STORAGE_ROOT}/{task_id}.sqlite"
14
+
15
+
16
def get_ngc_response():
    """
    Fetch the NIM configuration document from the auth server.

    The host is `OBP_AUTH_SERVER` from the Metaflow config when present,
    otherwise "auth." plus the parent domain of `SERVICE_URL`. The request
    is authenticated with `METAFLOW_SERVICE_AUTH_KEY` from the config, or
    with the JSON headers in the `METAFLOW_SERVICE_HEADERS` environment
    variable.

    Returns
    -------
    The decoded JSON body of `GET https://<auth_host>/generate/nim`.

    Raises
    ------
    RuntimeError
        If neither source of credentials is configured.
    requests.exceptions.HTTPError
        If the server answers with an error status.
    """
    conf = init_config()
    if "OBP_AUTH_SERVER" in conf:
        auth_host = conf["OBP_AUTH_SERVER"]
    else:
        auth_host = "auth." + urlparse(SERVICE_URL).hostname.split(".", 1)[1]

    # NOTE: reusing the same auth_host as the one used in NimMetadata,
    # however, user should not need to use nim container to use @nvct.
    # May want to refactor this to a common endpoint.
    nim_info_url = "https://" + auth_host + "/generate/nim"

    if "METAFLOW_SERVICE_AUTH_KEY" in conf:
        headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
    else:
        raw_headers = os.environ.get("METAFLOW_SERVICE_HEADERS")
        if raw_headers is None:
            # json.loads(None) would fail with an opaque TypeError.
            raise RuntimeError(
                "Cannot authenticate against %s: neither METAFLOW_SERVICE_AUTH_KEY "
                "nor the METAFLOW_SERVICE_HEADERS environment variable is set."
                % nim_info_url
            )
        headers = json.loads(raw_headers)

    # Bound the request so an unresponsive auth server cannot hang the caller.
    res = requests.get(nim_info_url, headers=headers, timeout=30)
    res.raise_for_status()
    return res.json()
@@ -0,0 +1,3 @@
1
# Per-GPU-type limit; presumably the largest GPU count that may be requested
# for that type (verify against the decorator that consumes it).
MAX_N_GPU_BY_TYPE = {
    "L40": 1,
    "L40S": 1,
    "L40G": 1,
    "H100": 4,
    "NEBIUS_H100": 8,
}
# The supported types are exactly the keys above, in the same order.
SUPPORTABLE_GPU_TYPES = list(MAX_N_GPU_BY_TYPE)
DEFAULT_GPU_TYPE = "H100"
@@ -0,0 +1,94 @@
1
+ from metaflow.exception import MetaflowException
2
+ from .constants import SUPPORTABLE_GPU_TYPES
3
+
4
+
5
class NvcfJobFailedException(MetaflowException):
    """Raised with a caller-supplied message under the "[@nvidia] error" headline."""

    headline = "[@nvidia] error"

    def __init__(self, msg):
        super().__init__(msg)
10
+
11
+
12
class NvcfPollingConnectionError(MetaflowException):
    """Wraps the original error message of a failed job-status poll."""

    headline = "[@nvidia] polling error."

    def __init__(self, og_error_msg):
        # An f-string formats any object; `"%s" % (x)` is not a tuple and
        # would mis-format (or raise) when `x` itself is a tuple.
        msg = (
            "An error occurred while polling the job status. "
            f"\n\nOriginal error message: {og_error_msg}"
        )

        super(NvcfPollingConnectionError, self).__init__(msg)
22
+
23
+
24
class RequestedGPUTypeUnavailableException(MetaflowException):
    """Reports a requested GPU type together with the supported GPU types."""

    headline = "[@nvidia RequestedGPUTypeUnavailableException] GPU type unavailable."

    def __init__(self, requested_gpu_type):
        unavailable = f"The requested GPU type @nvidia(..., gpu_type='{requested_gpu_type}') is not available. "
        alternatives = f"Please choose from the following supported GPU types when using @nvidia: {SUPPORTABLE_GPU_TYPES}"
        super().__init__(unavailable + alternatives)
33
+
34
+
35
class UnsupportedNvcfConfigurationException(MetaflowException):
    """
    Reports an unavailable (gpu count, gpu type) request for a step.

    The message lists the configurations passed in `available_configurations`,
    or states that none are available when that collection is empty.
    """

    headline = (
        "[@nvidia UnsupportedNvcfConfigurationException] Unsupported GPU configuration"
    )

    def __init__(self, n_gpu, gpu_type, available_configurations, step):
        parts = [
            f"The requested configuration of @nvidia(gpu={n_gpu}, gpu_type='{gpu_type}') for @step {step} is not available."
        ]
        if len(available_configurations) == 0:
            parts.append(
                "\n\nNo configurations are available in your Outerbounds deployment."
                " Please contact Outerbounds support if you wish to use @nvidia."
            )
        else:
            parts.append(
                f"\n\nAvailable configurations for your deployment include: \n\t- {self._display(available_configurations)}"
            )
            parts.append(
                "\n\nPlease contact Outerbounds support if you wish to use a configuration not listed above."
            )
        super().__init__("".join(parts))

    def _display(self, configs):
        """Render each config's first two items as a decorator call, one per bullet."""
        return "\n\t- ".join(
            f"@nvidia(gpu={cfg[0]}, gpu_type='{cfg[1]}')" for cfg in configs
        )
58
+
59
+
60
class UnsupportedNvcfDatastoreException(MetaflowException):
    """Reports a datastore type that the @nvidia decorator does not accept."""

    headline = "[@nvidia UnsupportedNvcfDatastoreException] Unsupported datastore"

    def __init__(self, ds_type):
        # The trailing space keeps the two sentences from running together
        # ("...at the moment.Current datastore type...").
        msg = (
            "The *@nvidia* decorator requires --datastore=s3 or --datastore=azure or --datastore=gs at the moment. "
            f"Current datastore type: {ds_type}."
        )
        super(UnsupportedNvcfDatastoreException, self).__init__(msg)
69
+
70
+
71
class NvcfTimeoutTooShortException(MetaflowException):
    """Reports a step whose timeout is below the 60 second minimum named in the message."""

    headline = "[@nvidia NvcfTimeoutTooShortException] Timeout too short"

    def __init__(self, step):
        super().__init__(
            f"The timeout for step *{step}* should be at least 60 seconds for "
            "execution with @nvidia"
        )
80
+
81
+
82
class NvcfQueueTimeoutTooShortException(MetaflowException):
    """Reports a step whose queue timeout is below the 60 second minimum named in the message."""

    headline = "[@nvidia NvcfQueueTimeoutTooShortException] Queue Timeout too short"

    def __init__(self, step):
        super().__init__(
            f"The queue timeout for step *{step}* should be at least 60 seconds for "
            "execution with @nvidia"
        )
91
+
92
+
93
class NvcfKilledException(MetaflowException):
    """Exception type carrying the "Nvidia job killed" headline."""

    headline = "Nvidia job killed"