ob-metaflow-extensions 1.1.67__py2.py3-none-any.whl → 1.1.69__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

@@ -240,3 +240,5 @@ class ObpGcpAuthProvider(object):
240
240
 
241
241
 
242
242
  GCP_CLIENT_PROVIDERS_DESC = [("obp", ".ObpGcpAuthProvider")]
243
+
244
+ FLOW_DECORATORS_DESC = [("nim", ".nim.NimDecorator")]
@@ -0,0 +1,52 @@
1
+ from functools import partial
2
+ from metaflow.decorators import FlowDecorator
3
+ from metaflow import current
4
+ from .nim_manager import NimManager
5
+
6
+
7
class NimDecorator(FlowDecorator):
    """
    Flow-level decorator that makes NIM-served models available to a flow.

    The original description of this decorator states that it is used to run
    NIM containers in Metaflow tasks as sidecars. When the flow is
    initialised, a ``NimManager`` built from the decorator attributes is
    registered under the key ``"nim"`` in Metaflow's ``current`` object.

    Example
    -------
    @nim(
        models=['meta/llama3-8b-instruct', 'meta/llama3-70b-instruct'],
        backend='managed'
    )

    Valid backend options
    ---------------------
    - 'managed': Outerbounds selects a compute provider based on the model.
    - 'dataplane' (work in progress): Run in your account.

    Valid model options
    -------------------
    - 'meta/llama3-8b-instruct': 8B parameter model
    - 'meta/llama3-70b-instruct': 70B parameter model
    - Upon request, any model here:
      https://nvcf.ngc.nvidia.com/functions?filter=nvidia-functions

    Parameters
    ----------
    models: list[NIM]
        List of NIM containers running models in sidecars.
    backend: str
        Compute provider to run the NIM container.
    """

    name = "nim"
    defaults = dict(models=[], backend="managed")

    def flow_init(
        self, flow, graph, environment, flow_datastore, metadata, logger, echo, options
    ):
        # Build the manager from the decorator attributes, then publish it
        # through `current` under the "nim" key.
        nim_manager = NimManager(
            models=self.attributes["models"],
            backend=self.attributes["backend"],
        )
        current._update_env({"nim": nim_manager})
@@ -0,0 +1,189 @@
1
+ import os
2
+ import time
3
+ import json
4
+ import requests
5
+ from urllib.parse import urlparse
6
+ from metaflow.metaflow_config import SERVICE_URL
7
+ from metaflow.metaflow_config_funcs import init_config
8
+
9
# Base URL of the NVCF API and the two endpoints derived from it: one that
# receives function invocations and one that is polled for their status.
NVCF_URL = "https://api.nvcf.nvidia.com"
NVCF_SUBMIT_ENDPOINT = NVCF_URL + "/v2/nvcf/pexec/functions"
NVCF_RESULT_ENDPOINT = NVCF_URL + "/v2/nvcf/pexec/status"

# Headers sent with every JSON request issued by this module.
COMMON_HEADERS = {
    "accept": "application/json",
    "Content-Type": "application/json",
}
# Seconds to sleep between two consecutive status polls.
POLL_INTERVAL = 1
15
+
16
+
17
class NimMetadata(object):
    """
    Model catalogue and credentials for the @nim offering.

    Construction performs one HTTP GET against ``https://<auth host>/generate/nim``
    and stores, from the JSON response:

    - the NVCF API key (``["nvcf"]["api_key"]``),
    - the NVCF models (``["nvcf"]["functions"]``),
    - the CoreWeave models (``["coreweave"]["containers"]``).

    Raises
    ------
    ValueError
        If the auth host or the service headers cannot be determined from the
        configuration / environment.
    requests.HTTPError
        If the metadata endpoint answers with an error status.
    KeyError
        If the response lacks one of the expected keys.
    """

    def __init__(self):
        conf = init_config()

        nim_info_url = "https://" + self._resolve_auth_host(conf) + "/generate/nim"

        # An API key from the config takes precedence; otherwise fall back to
        # the pre-built headers from the environment.
        if "METAFLOW_SERVICE_AUTH_KEY" in conf:
            headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
        else:
            headers = self._service_headers_from_env()

        res = requests.get(nim_info_url, headers=headers)
        res.raise_for_status()
        # Decode the body once instead of re-parsing it for every field.
        payload = res.json()

        nvcf_info = payload["nvcf"]
        self._ngc_api_key = nvcf_info["api_key"]
        self._nvcf_chat_completion_models = [
            {
                "name": model["model_key"],
                "function-id": model["id"],
                "version-id": model["version"],
            }
            for model in nvcf_info["functions"]
        ]
        self._coreweave_chat_completion_models = [
            {"name": model["nim_name"], "ip-address": model["ip_addr"]}
            for model in payload["coreweave"]["containers"]
        ]

    @staticmethod
    def _resolve_auth_host(conf):
        """Return the auth server host from `conf`, or derive it from SERVICE_URL."""
        if "OBP_AUTH_SERVER" in conf:
            return conf["OBP_AUTH_SERVER"]

        # Derivation replaces the first label of the service hostname with
        # "auth"; this needs a hostname that has at least two labels.
        hostname = urlparse(SERVICE_URL).hostname if SERVICE_URL else None
        if not hostname or "." not in hostname:
            raise ValueError(
                "Cannot determine the auth server for @nim: OBP_AUTH_SERVER is "
                "not configured and the Metaflow service URL %r has no usable "
                "hostname." % (SERVICE_URL,)
            )
        return "auth." + hostname.split(".", 1)[1]

    @staticmethod
    def _service_headers_from_env():
        """Return the request headers decoded from METAFLOW_SERVICE_HEADERS."""
        raw_headers = os.environ.get("METAFLOW_SERVICE_HEADERS")
        if not raw_headers:
            # json.loads(None) would otherwise fail with an opaque TypeError.
            raise ValueError(
                "Cannot authenticate the @nim metadata request: neither "
                "METAFLOW_SERVICE_AUTH_KEY nor METAFLOW_SERVICE_HEADERS is set."
            )
        return json.loads(raw_headers)

    def get_nvcf_chat_completion_models(self):
        """Return the list of NVCF model entries (name, function-id, version-id)."""
        return self._nvcf_chat_completion_models

    def get_coreweave_chat_completion_models(self):
        """Return the list of CoreWeave model entries (name, ip-address)."""
        return self._coreweave_chat_completion_models

    def get_headers_for_nvcf_request(self):
        """Return the common JSON headers plus the NVCF bearer token."""
        return {**COMMON_HEADERS, "Authorization": f"Bearer {self._ngc_api_key}"}

    def get_headers_for_coreweave_request(self):
        """Return the common JSON headers (no authorization header is added)."""
        return COMMON_HEADERS
65
+
66
+
67
class NimManager(object):
    """
    Holds one ``NimChatCompletion`` client per requested model.

    After construction, ``self.models`` maps each requested model name to its
    client. Only the ``"managed"`` backend is accepted.

    Parameters
    ----------
    models : iterable of str
        Names of the models to expose.
    backend : str
        Backend selector; anything other than ``"managed"`` raises.

    Raises
    ------
    ValueError
        If `backend` is not ``"managed"`` or a requested model is offered by
        neither NVCF nor CoreWeave.
    """

    def __init__(self, models, backend):
        # Metadata is fetched before the backend check, so a failing fetch
        # surfaces first.
        nim_metadata = NimMetadata()

        if backend != "managed":
            raise ValueError(
                f"Backend {backend} not supported by the Outerbounds @nim offering. Please reach out to Outerbounds."
            )

        nvcf_models = [
            entry["name"] for entry in nim_metadata.get_nvcf_chat_completion_models()
        ]
        cw_models = [
            entry["name"]
            for entry in nim_metadata.get_coreweave_chat_completion_models()
        ]

        self.models = {}
        for requested in models:
            # NVCF wins when a model is offered by both providers.
            if requested in nvcf_models:
                provider = "NVCF"
            elif requested in cw_models:
                provider = "CoreWeave"
            else:
                raise ValueError(
                    f"Model {requested} not supported by the Outerbounds @nim offering."
                    f"\nYou can choose from these options: {nvcf_models + cw_models}\n\n"
                    "Reach out to Outerbounds if there are other models you'd like supported."
                )
            self.models[requested] = NimChatCompletion(
                model=requested, provider=provider, nim_metadata=nim_metadata
            )
98
+
99
+
100
class NimChatCompletion(object):
    """
    Callable client for a single chat-completion model.

    Calling the instance with keyword arguments posts them, together with the
    model name, as a JSON body to the provider and returns the decoded JSON
    response.

    Parameters
    ----------
    model : str
        Model name; must be present in the provider's list from `nim_metadata`.
    provider : str
        ``"CoreWeave"`` or ``"NVCF"``.
    nim_metadata : NimMetadata
        Source of model entries and request headers. Required.
    **kwargs
        Accepted for interface compatibility; not used.

    Raises
    ------
    ValueError
        If `nim_metadata` is missing, `provider` is unknown, or `model` is not
        offered by `provider`.
    """

    def __init__(
        self,
        model="meta/llama3-8b-instruct",
        provider="CoreWeave",
        nim_metadata=None,
        **kwargs,
    ):
        if nim_metadata is None:
            raise ValueError(
                "NimMetadata object is required to initialize NimChatCompletion object."
            )

        self._nim_metadata = nim_metadata
        self.compute_provider = provider
        # Request ids of NVCF invocations that were accepted asynchronously.
        self.invocations = []
        self.model = model

        if provider == "CoreWeave":
            entry = self._find_model_entry(
                nim_metadata.get_coreweave_chat_completion_models(), model
            )
            self.ip_address = entry["ip-address"]
            self.endpoint = f"http://{self.ip_address}:8000/v1/chat/completions"
        elif provider == "NVCF":
            entry = self._find_model_entry(
                nim_metadata.get_nvcf_chat_completion_models(), model
            )
            self.function_id = entry["function-id"]
            self.version_id = entry["version-id"]
        else:
            # Fail at construction instead of returning a client whose call
            # silently does nothing.
            raise ValueError(
                f"Compute provider {provider} not supported by the Outerbounds @nim offering."
            )

    @staticmethod
    def _find_model_entry(entries, model):
        """Return the first entry whose "name" equals `model`, else raise ValueError."""
        for entry in entries:
            if entry["name"] == model:
                return entry
        raise ValueError(f"Model {model} is not available from this provider.")

    def __call__(self, **kwargs):
        """
        Send a chat-completion request and return the decoded JSON response.

        Keyword arguments become fields of the JSON request body; a ``model``
        keyword overrides the configured model name.

        Raises
        ------
        requests.HTTPError
            If the provider answers with an error status.
        Exception
            If NVCF answers with an unexpected success status or accepts the
            request without returning a request id.
        """
        request_data = {"model": self.model, **kwargs}

        if self.compute_provider == "CoreWeave":
            response = requests.post(
                self.endpoint,
                headers=self._nim_metadata.get_headers_for_coreweave_request(),
                json=request_data,
            )
            response.raise_for_status()
            return response.json()

        elif self.compute_provider == "NVCF":
            return self._invoke_nvcf(request_data)

    def _invoke_nvcf(self, request_data):
        """Submit `request_data` to NVCF and return the result, polling if needed."""
        response = requests.post(
            f"{NVCF_SUBMIT_ENDPOINT}/{self.function_id}",
            headers=self._nim_metadata.get_headers_for_nvcf_request(),
            json=request_data,
        )
        response.raise_for_status()

        if response.status_code == 200:
            return response.json()
        if response.status_code != 202:
            # Any other 2xx carries neither a result nor a request id to poll.
            raise Exception(
                f"NVCF returned {response.status_code} status code. Please contact Outerbounds."
            )

        invocation_id = response.headers.get("NVCF-REQID")
        if not invocation_id:
            raise Exception(
                "NVCF accepted the request but returned no NVCF-REQID header. Please contact Outerbounds."
            )
        self.invocations.append(invocation_id)
        return self._wait_for_nvcf_result(invocation_id)

    def _wait_for_nvcf_result(self, invocation_id):
        """Poll the NVCF status endpoint until a non-empty result is available."""
        poll_request_url = f"{NVCF_RESULT_ENDPOINT}/{invocation_id}"
        while True:
            poll_response = requests.get(
                poll_request_url,
                headers=self._nim_metadata.get_headers_for_nvcf_request(),
            )
            poll_response.raise_for_status()

            if poll_response.status_code == 200:
                data = poll_response.json()
                # An empty payload is treated as "not ready yet", as before.
                if data:
                    return data
            elif poll_response.status_code != 202:
                raise Exception(
                    f"NVCF returned {poll_response.status_code} status code. Please contact Outerbounds."
                )
            time.sleep(POLL_INTERVAL)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ob-metaflow-extensions
3
- Version: 1.1.67
3
+ Version: 1.1.69
4
4
  Summary: Outerbounds Platform Extensions for Metaflow
5
5
  Author: Outerbounds, Inc.
6
6
  License: Commercial
@@ -1,11 +1,13 @@
1
1
  metaflow_extensions/outerbounds/__init__.py,sha256=TRGvIUMjkfneWtYUFSWoubu_Kf2ekAL4WLbV3IxOj9k,499
2
2
  metaflow_extensions/outerbounds/remote_config.py,sha256=HPFH4e3ZK3p-wS5HlS75fhR8_2avdD1AHQIZl2KnjeQ,4059
3
3
  metaflow_extensions/outerbounds/config/__init__.py,sha256=mYo95obHU1IE1wbPkeVz_pfTzNqlNabp1QBEMTGllbE,112
4
- metaflow_extensions/outerbounds/plugins/__init__.py,sha256=rK-EPg7107su6dY-yAx1IK3wAzyvQIIg6y4um_F_BXc,9343
4
+ metaflow_extensions/outerbounds/plugins/__init__.py,sha256=oR3krG5x3-W4g1sm5ygNPe9KVLBmxg7KEtzJsonQo_4,9398
5
5
  metaflow_extensions/outerbounds/plugins/auth_server.py,sha256=JhlMFcR7SPSfR1C9w6GlqJq-NYNhOfISmHl2PdkYUok,2212
6
6
  metaflow_extensions/outerbounds/plugins/perimeters.py,sha256=z8tSAkWtiITB-JtSQS7fkhlBwvxSxeTgEwFjahAzv-U,2238
7
7
  metaflow_extensions/outerbounds/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8m7rgF4xgWBZFuY3GDP5n1T0ktjRpGJLHA,69
8
8
  metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py,sha256=gj6Iaz26bGbZm3aQuNS18Mqh_80iJp5PgFwFSlJRcn8,1968
9
+ metaflow_extensions/outerbounds/plugins/nim/__init__.py,sha256=GVnvSTjqYVj5oG2yh8KJFt7iZ33cEadDD5HbdmC9hJ0,1457
10
+ metaflow_extensions/outerbounds/plugins/nim/nim_manager.py,sha256=l8WDfVtsMt7aZaOaeIPT5ySidxfxXU8gmwLoKUP3f04,7044
9
11
  metaflow_extensions/outerbounds/profilers/__init__.py,sha256=wa_jhnCBr82TBxoS0e8b6_6sLyZX0fdHicuGJZNTqKw,29
10
12
  metaflow_extensions/outerbounds/profilers/gpu.py,sha256=a5YZAepujuP0uDqG9UpXBlZS3wjUt4Yv8CjybXqeT2c,24342
11
13
  metaflow_extensions/outerbounds/toplevel/__init__.py,sha256=qWUJSv_r5hXJ7jV_On4nEasKIfUCm6_UjkjXWA_A1Ts,90
@@ -13,7 +15,7 @@ metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,
13
15
  metaflow_extensions/outerbounds/toplevel/plugins/azure/__init__.py,sha256=WUuhz2YQfI4fz7nIcipwwWq781eaoHEk7n4GAn1npDg,63
14
16
  metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py,sha256=BbZiaH3uILlEZ6ntBLKeNyqn3If8nIXZFq_Apd7Dhco,70
15
17
  metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8m7rgF4xgWBZFuY3GDP5n1T0ktjRpGJLHA,69
16
- ob_metaflow_extensions-1.1.67.dist-info/METADATA,sha256=M02sZ0DJ0S99YYXMEZgfYLah2ETrN7zTqToU3QANTFM,519
17
- ob_metaflow_extensions-1.1.67.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
18
- ob_metaflow_extensions-1.1.67.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
19
- ob_metaflow_extensions-1.1.67.dist-info/RECORD,,
18
+ ob_metaflow_extensions-1.1.69.dist-info/METADATA,sha256=Xw60qLJNjigrBisO_T1GbcUQja8A2r637ntXJy6NpCA,519
19
+ ob_metaflow_extensions-1.1.69.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
20
+ ob_metaflow_extensions-1.1.69.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
21
+ ob_metaflow_extensions-1.1.69.dist-info/RECORD,,