ob-metaflow-extensions 1.1.130__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ob-metaflow-extensions might be problematic. Click here for more details.
- metaflow_extensions/outerbounds/__init__.py +1 -1
- metaflow_extensions/outerbounds/plugins/__init__.py +34 -4
- metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
- metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
- metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +1 -1
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +43 -9
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +12 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +2 -16
- metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +81 -11
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +18 -8
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +6 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +45 -18
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +18 -9
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +10 -4
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/remote_config.py +46 -9
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +94 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.130.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
- metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
- ob_metaflow_extensions-1.1.130.dist-info/RECORD +0 -56
- {ob_metaflow_extensions-1.1.130.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.130.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -7,9 +7,10 @@ from metaflow.metaflow_config import SERVICE_URL
|
|
|
7
7
|
from metaflow.metaflow_config_funcs import init_config
|
|
8
8
|
from typing import Dict
|
|
9
9
|
from os import environ
|
|
10
|
-
|
|
10
|
+
import sys
|
|
11
11
|
import json
|
|
12
12
|
import requests
|
|
13
|
+
import random
|
|
13
14
|
import time
|
|
14
15
|
|
|
15
16
|
|
|
@@ -75,27 +76,39 @@ def get_snowflake_token(user: str = "", role: str = "", integration: str = "") -
|
|
|
75
76
|
}
|
|
76
77
|
json_payload = json.dumps(payload)
|
|
77
78
|
headers = provisioner.get_service_auth_header()
|
|
78
|
-
response =
|
|
79
|
+
response = _api_server_get(
|
|
80
|
+
snowflake_token_url, data=json_payload, headers=headers, conn_error_retries=5
|
|
81
|
+
)
|
|
79
82
|
response.raise_for_status()
|
|
80
83
|
return response.json()["token"]
|
|
81
84
|
|
|
82
85
|
|
|
83
|
-
def
|
|
86
|
+
def get_oauth_connection_params(
|
|
87
|
+
user: str = "", role: str = "", integration: str = "", **kwargs
|
|
88
|
+
) -> Dict:
|
|
84
89
|
"""
|
|
85
|
-
|
|
90
|
+
Get OAuth connection parameters for Snowflake authentication using Outerbounds integration.
|
|
91
|
+
|
|
92
|
+
This is a helper function that returns connection parameters dict that can be used
|
|
93
|
+
with both snowflake-connector-python and snowflake-snowpark-python.
|
|
94
|
+
|
|
86
95
|
user: str
|
|
87
96
|
The user name used to authenticate with snowflake
|
|
88
97
|
role: str
|
|
89
|
-
The role to request when
|
|
98
|
+
The role to request when connecting with snowflake
|
|
90
99
|
integration: str
|
|
91
|
-
The name of the snowflake integration to use. If not set, an existing integration
|
|
100
|
+
The name of the snowflake integration to use. If not set, an existing integration
|
|
101
|
+
will be used provided that only one exists in the current perimeter.
|
|
92
102
|
kwargs: dict
|
|
93
|
-
Additional arguments to
|
|
103
|
+
Additional arguments to include in the connection parameters
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
Dict with connection parameters including OAuth token
|
|
94
107
|
"""
|
|
95
108
|
# ensure password is not set
|
|
96
109
|
if "password" in kwargs:
|
|
97
110
|
raise OuterboundsSnowflakeConnectorException(
|
|
98
|
-
"Password should not be set when using Outerbounds
|
|
111
|
+
"Password should not be set when using Outerbounds OAuth authentication."
|
|
99
112
|
)
|
|
100
113
|
|
|
101
114
|
provisioner = SnowflakeIntegrationProvisioner(integration)
|
|
@@ -134,11 +147,31 @@ def connect(user: str = "", role: str = "", integration: str = "", **kwargs):
|
|
|
134
147
|
kwargs["role"] = role
|
|
135
148
|
kwargs["user"] = user
|
|
136
149
|
|
|
150
|
+
return kwargs
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def connect(user: str = "", role: str = "", integration: str = "", **kwargs):
|
|
154
|
+
"""
|
|
155
|
+
Connect to snowflake using the token minted by Outerbounds
|
|
156
|
+
user: str
|
|
157
|
+
The user name used to authenticate with snowflake
|
|
158
|
+
role: str
|
|
159
|
+
The role to request when connect with snowflake
|
|
160
|
+
integration: str
|
|
161
|
+
The name of the snowflake integration to use. If not set, an existing integration will be used provided that only one exists in the current perimeter. If integration is not set and more than one exists in the current perimeter, then we raise an exception.
|
|
162
|
+
kwargs: dict
|
|
163
|
+
Additional arguments to pass to the python snowflake connector
|
|
164
|
+
"""
|
|
165
|
+
# Get OAuth connection params using the helper
|
|
166
|
+
connection_params = get_oauth_connection_params(
|
|
167
|
+
user=user, role=role, integration=integration, **kwargs
|
|
168
|
+
)
|
|
169
|
+
|
|
137
170
|
# connect to snowflake
|
|
138
171
|
try:
|
|
139
172
|
from snowflake.connector import connect
|
|
140
173
|
|
|
141
|
-
cn = connect(**
|
|
174
|
+
cn = connect(**connection_params)
|
|
142
175
|
return cn
|
|
143
176
|
except ImportError as ie:
|
|
144
177
|
raise OuterboundsSnowflakeConnectorException(
|
|
@@ -150,6 +183,39 @@ def connect(user: str = "", role: str = "", integration: str = "", **kwargs):
|
|
|
150
183
|
)
|
|
151
184
|
|
|
152
185
|
|
|
186
|
+
def _api_server_get(*args, conn_error_retries=2, **kwargs):
    """
    Issue `requests.get(*args, **kwargs)`, retrying when the API server cannot be reached.

    There are two categories of errors that we need to handle when dealing with any API server.
    1. HTTP errors. These are errors that are returned from the API server.
        - How to handle retries for this case will be application specific, so the
          response is returned untouched and the caller decides what to do with it.
    2. Errors when the API server may not be reachable (DNS resolution / network issues)
        - In this scenario, we know that something external to the API server is going wrong causing the issue.
        - Failing prematurely in this case might not be the best course of action since critical user jobs
          might crash on intermittent issues.
        - So in this case, we can just plainly retry the request.

    This function handles the second case. It is a simple wrapper around the retry logic for connection errors.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to `requests.get`.
    conn_error_retries : int
        Maximum number of retries performed after the initial attempt fails with a
        connection error. Retry number i (1-based) is preceded by a sleep of roughly
        2**i seconds (plus a jitter in [-0.5, 0.5] drawn once per call), so with a value
        of 5 the last retry will have waited about 32 seconds.

    Returns
    -------
    The response object returned by `requests.get`.

    Raises
    ------
    requests.exceptions.ConnectionError
        If the initial attempt and every retry fail with a connection error.
    """
    _num_retries = 0
    noise = random.uniform(-0.5, 0.5)
    while True:
        try:
            return requests.get(*args, **kwargs)
        except requests.exceptions.ConnectionError:
            if _num_retries >= conn_error_retries:
                # Retries are exhausted: surface the original error instead of
                # silently returning None to the caller.
                print(
                    "[@snowflake] Failed to connect to the API server. ",
                    file=sys.stderr,
                )
                raise
            # Exponential backoff with 2^(_num_retries+1) seconds
            time.sleep((2 ** (_num_retries + 1)) + noise)
            _num_retries += 1
|
|
217
|
+
|
|
218
|
+
|
|
153
219
|
class Snowflake:
|
|
154
220
|
def __init__(
|
|
155
221
|
self, user: str = "", role: str = "", integration: str = "", **kwargs
|
|
@@ -273,7 +339,9 @@ class SnowflakeIntegrationProvisioner:
|
|
|
273
339
|
retryable_status_codes = [409]
|
|
274
340
|
json_payload = json.dumps(payload)
|
|
275
341
|
for attempt in range(2): # 0 = initial attempt, 1-2 = retries
|
|
276
|
-
response =
|
|
342
|
+
response = _api_server_get(
|
|
343
|
+
url, data=json_payload, headers=request_headers, conn_error_retries=5
|
|
344
|
+
)
|
|
277
345
|
if response.status_code not in retryable_status_codes:
|
|
278
346
|
break
|
|
279
347
|
|
|
@@ -281,7 +349,9 @@ class SnowflakeIntegrationProvisioner:
|
|
|
281
349
|
sleep_time = 0.5 * (attempt + 1)
|
|
282
350
|
time.sleep(sleep_time)
|
|
283
351
|
|
|
284
|
-
response =
|
|
352
|
+
response = _api_server_get(
|
|
353
|
+
url, data=json_payload, headers=request_headers, conn_error_retries=5
|
|
354
|
+
)
|
|
285
355
|
self._handle_error_response(response)
|
|
286
356
|
return response.json()
|
|
287
357
|
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import re
|
|
2
3
|
import shlex
|
|
3
4
|
import atexit
|
|
4
5
|
import json
|
|
5
6
|
import math
|
|
6
7
|
import time
|
|
8
|
+
import hashlib
|
|
7
9
|
|
|
8
10
|
from metaflow import util
|
|
9
11
|
|
|
@@ -57,21 +59,29 @@ class Snowpark(object):
|
|
|
57
59
|
atexit.register(lambda: self.job.kill() if hasattr(self, "job") else None)
|
|
58
60
|
|
|
59
61
|
def _job_name(self, user, flow_name, run_id, step_name, task_id, retry_count):
    """
    Build the job name for a task attempt.

    The name is ``<prefix>-<hash>`` where:
      - ``<prefix>`` is "<flow_name>-<step_name>" lower-cased, with every character
        outside [a-z0-9] replaced by a dash, cut to 54 characters and stripped of
        leading dashes;
      - ``<hash>`` is the first 8 hex digits of the MD5 of
        "user-flow_name-run_id-step_name-task_id-retry_count", where a ``None``
        run_id / task_id / retry_count contributes an empty string.
    """

    def _text(value):
        # None is rendered as an empty field rather than the string "None".
        return "" if value is None else str(value)

    identity = (
        f"{user}-{flow_name}-{_text(run_id)}-{step_name}"
        f"-{_text(task_id)}-{_text(retry_count)}"
    )
    digest = hashlib.md5(identity.encode("utf-8")).hexdigest()[:8]

    readable = f"{flow_name}-{step_name}".lower()
    prefix = re.sub(r"[^a-z0-9]", "-", readable)[:54].lstrip("-")
    return f"{prefix}-{digest}"
|
|
68
78
|
|
|
69
79
|
def _command(self, environment, code_package_url, step_name, step_cmds, task_spec):
|
|
70
80
|
mflog_expr = export_mflog_env_vars(
|
|
71
81
|
datastore_type=self.datastore.TYPE,
|
|
72
82
|
stdout_path=STDOUT_PATH,
|
|
73
83
|
stderr_path=STDERR_PATH,
|
|
74
|
-
**task_spec
|
|
84
|
+
**task_spec,
|
|
75
85
|
)
|
|
76
86
|
init_cmds = environment.get_package_commands(
|
|
77
87
|
code_package_url, self.datastore.TYPE
|
|
@@ -66,6 +66,10 @@ def snowpark():
|
|
|
66
66
|
"--schema",
|
|
67
67
|
help="Schema for Snowpark Container Services.",
|
|
68
68
|
)
|
|
69
|
+
@click.option(
|
|
70
|
+
"--integration",
|
|
71
|
+
help="Outerbounds OAuth integration name for Snowpark Container Services. When set, uses OAuth authentication instead of password.",
|
|
72
|
+
)
|
|
69
73
|
@click.option(
|
|
70
74
|
"--image",
|
|
71
75
|
help="Docker image requirement for Snowpark Container Services. In name:version format.",
|
|
@@ -119,6 +123,7 @@ def step(
|
|
|
119
123
|
database=None,
|
|
120
124
|
warehouse=None,
|
|
121
125
|
schema=None,
|
|
126
|
+
integration=None,
|
|
122
127
|
image=None,
|
|
123
128
|
stage=None,
|
|
124
129
|
compute_pool=None,
|
|
@@ -235,6 +240,7 @@ def step(
|
|
|
235
240
|
"database": database,
|
|
236
241
|
"warehouse": warehouse,
|
|
237
242
|
"schema": schema,
|
|
243
|
+
"integration": integration,
|
|
238
244
|
},
|
|
239
245
|
)
|
|
240
246
|
with ctx.obj.monitor.measure("metaflow.snowpark.launch_job"):
|
|
@@ -10,14 +10,15 @@ from metaflow.exception import MetaflowException
|
|
|
10
10
|
class SnowparkClient(object):
|
|
11
11
|
def __init__(
|
|
12
12
|
self,
|
|
13
|
-
account: str,
|
|
14
|
-
user: str,
|
|
15
|
-
password: str,
|
|
16
|
-
role: str,
|
|
17
|
-
database: str,
|
|
18
|
-
warehouse: str,
|
|
19
|
-
schema: str,
|
|
13
|
+
account: str = None,
|
|
14
|
+
user: str = None,
|
|
15
|
+
password: str = None,
|
|
16
|
+
role: str = None,
|
|
17
|
+
database: str = None,
|
|
18
|
+
warehouse: str = None,
|
|
19
|
+
schema: str = None,
|
|
20
20
|
autocommit: bool = True,
|
|
21
|
+
integration: str = None,
|
|
21
22
|
):
|
|
22
23
|
try:
|
|
23
24
|
from snowflake.core import Root
|
|
@@ -27,22 +28,48 @@ class SnowparkClient(object):
|
|
|
27
28
|
except (NameError, ImportError, ModuleNotFoundError):
|
|
28
29
|
raise SnowflakeException(
|
|
29
30
|
"Could not import module 'snowflake'.\n\nInstall Snowflake "
|
|
30
|
-
"Python
|
|
31
|
-
"
|
|
32
|
-
"
|
|
31
|
+
"Python packages first:\n"
|
|
32
|
+
" snowflake==1.8.0\n"
|
|
33
|
+
" snowflake-connector-python==3.18.0\n"
|
|
34
|
+
" snowflake-snowpark-python==1.40.0\n\n"
|
|
35
|
+
"You can install them by executing:\n"
|
|
36
|
+
"%s -m pip install snowflake==1.8.0 snowflake-connector-python==3.18.0 snowflake-snowpark-python==1.40.0\n"
|
|
33
37
|
"or equivalent through your favorite Python package manager."
|
|
34
38
|
% sys.executable
|
|
35
39
|
)
|
|
36
40
|
|
|
41
|
+
if integration:
|
|
42
|
+
# Use OAuth authentication via Outerbounds integration
|
|
43
|
+
from metaflow_extensions.outerbounds.plugins.snowflake.snowflake import (
|
|
44
|
+
get_oauth_connection_params,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
self.connection_parameters = get_oauth_connection_params(
|
|
48
|
+
user=user or "",
|
|
49
|
+
role=role or "",
|
|
50
|
+
integration=integration,
|
|
51
|
+
schema=schema or "",
|
|
52
|
+
account=account,
|
|
53
|
+
warehouse=warehouse,
|
|
54
|
+
database=database,
|
|
55
|
+
)
|
|
56
|
+
self.connection_parameters["autocommit"] = autocommit
|
|
57
|
+
else:
|
|
58
|
+
# Password-based authentication
|
|
59
|
+
self.connection_parameters = {
|
|
60
|
+
"account": account,
|
|
61
|
+
"user": user,
|
|
62
|
+
"password": password,
|
|
63
|
+
"role": role,
|
|
64
|
+
"warehouse": warehouse,
|
|
65
|
+
"database": database,
|
|
66
|
+
"schema": schema,
|
|
67
|
+
"autocommit": autocommit,
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
# Remove None values from connection parameters
|
|
37
71
|
self.connection_parameters = {
|
|
38
|
-
|
|
39
|
-
"user": user,
|
|
40
|
-
"password": password,
|
|
41
|
-
"role": role,
|
|
42
|
-
"warehouse": warehouse,
|
|
43
|
-
"database": database,
|
|
44
|
-
"schema": schema,
|
|
45
|
-
"autocommit": autocommit,
|
|
72
|
+
k: v for k, v in self.connection_parameters.items() if v is not None
|
|
46
73
|
}
|
|
47
74
|
|
|
48
75
|
try:
|
|
@@ -42,10 +42,13 @@ class Snowflake(object):
|
|
|
42
42
|
return session
|
|
43
43
|
except (NameError, ImportError, ModuleNotFoundError):
|
|
44
44
|
raise SnowflakeException(
|
|
45
|
-
"Could not import module 'snowflake'.\n\
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"
|
|
45
|
+
"Could not import module 'snowflake'.\n\n"
|
|
46
|
+
"Install required Snowflake packages using the @pypi decorator:\n"
|
|
47
|
+
"@pypi(packages={\n"
|
|
48
|
+
" 'snowflake': '1.8.0',\n"
|
|
49
|
+
" 'snowflake-connector-python': '3.18.0',\n"
|
|
50
|
+
" 'snowflake-snowpark-python': '1.40.0'\n"
|
|
51
|
+
"})\n"
|
|
49
52
|
)
|
|
50
53
|
|
|
51
54
|
|
|
@@ -68,6 +71,7 @@ class SnowparkDecorator(StepDecorator):
|
|
|
68
71
|
"cpu": None,
|
|
69
72
|
"gpu": None,
|
|
70
73
|
"memory": None,
|
|
74
|
+
"integration": None, # Outerbounds OAuth integration name
|
|
71
75
|
}
|
|
72
76
|
|
|
73
77
|
package_url = None
|
|
@@ -77,12 +81,11 @@ class SnowparkDecorator(StepDecorator):
|
|
|
77
81
|
def __init__(self, attributes=None, statically_defined=False):
|
|
78
82
|
super(SnowparkDecorator, self).__init__(attributes, statically_defined)
|
|
79
83
|
|
|
84
|
+
# Set defaults from config (user can override via decorator or integration)
|
|
80
85
|
if not self.attributes["account"]:
|
|
81
86
|
self.attributes["account"] = SNOWPARK_ACCOUNT
|
|
82
87
|
if not self.attributes["user"]:
|
|
83
88
|
self.attributes["user"] = SNOWPARK_USER
|
|
84
|
-
if not self.attributes["password"]:
|
|
85
|
-
self.attributes["password"] = SNOWPARK_PASSWORD
|
|
86
89
|
if not self.attributes["role"]:
|
|
87
90
|
self.attributes["role"] = SNOWPARK_ROLE
|
|
88
91
|
if not self.attributes["database"]:
|
|
@@ -91,6 +94,9 @@ class SnowparkDecorator(StepDecorator):
|
|
|
91
94
|
self.attributes["warehouse"] = SNOWPARK_WAREHOUSE
|
|
92
95
|
if not self.attributes["schema"]:
|
|
93
96
|
self.attributes["schema"] = SNOWPARK_SCHEMA
|
|
97
|
+
# Only use password from config if not using integration (OAuth)
|
|
98
|
+
if not self.attributes["integration"] and not self.attributes["password"]:
|
|
99
|
+
self.attributes["password"] = SNOWPARK_PASSWORD
|
|
94
100
|
|
|
95
101
|
# If no docker image is explicitly specified, impute a default image.
|
|
96
102
|
if not self.attributes["image"]:
|
|
@@ -143,9 +149,12 @@ class SnowparkDecorator(StepDecorator):
|
|
|
143
149
|
except (NameError, ImportError, ModuleNotFoundError):
|
|
144
150
|
raise SnowflakeException(
|
|
145
151
|
"Could not import module 'snowflake'.\n\nInstall Snowflake "
|
|
146
|
-
"Python
|
|
147
|
-
"
|
|
148
|
-
"
|
|
152
|
+
"Python packages first:\n"
|
|
153
|
+
" snowflake==1.8.0\n"
|
|
154
|
+
" snowflake-connector-python==3.18.0\n"
|
|
155
|
+
" snowflake-snowpark-python==1.40.0\n\n"
|
|
156
|
+
"You can install them by executing:\n"
|
|
157
|
+
"%s -m pip install snowflake==1.8.0 snowflake-connector-python==3.18.0 snowflake-snowpark-python==1.40.0\n"
|
|
149
158
|
"or equivalent through your favorite Python package manager."
|
|
150
159
|
% sys.executable
|
|
151
160
|
)
|
|
@@ -12,9 +12,9 @@ from .snowpark_exceptions import SnowparkException
|
|
|
12
12
|
mapping = str.maketrans("0123456789", "abcdefghij")
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
# keep only alpha numeric characters and
|
|
15
|
+
# keep only alphanumeric characters and dashes
|
|
16
16
|
def sanitize_name(job_name: str):
    """Return `job_name` reduced to its alphanumeric characters (per `str.isalnum`) and dashes."""
    kept = [ch for ch in job_name if ch == "-" or ch.isalnum()]
    return "".join(kept)
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
# this is not a decorator since the exception imports need to be inside
|
|
@@ -199,11 +199,17 @@ class RunningJob(object):
|
|
|
199
199
|
|
|
200
200
|
@property
|
|
201
201
|
def status(self):
    """
    Return the "status" field of the first entry reported by `self.status_obj()`.

    Falls back to "UNKNOWN" when no entries are reported or when the first
    entry has no "status" key.
    """
    entries = self.status_obj()
    return entries[0].get("status", "UNKNOWN") if entries else "UNKNOWN"
|
|
203
206
|
|
|
204
207
|
@property
|
|
205
208
|
def message(self):
    """
    Return the "message" field of the first entry reported by `self.status_obj()`.

    Returns None when no entries are reported or when the first entry has no
    "message" key.
    """
    entries = self.status_obj()
    return entries[0].get("message") if entries else None
|
|
207
213
|
|
|
208
214
|
@property
|
|
209
215
|
def is_waiting(self):
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from queue import Queue, Empty
|
|
2
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
3
|
+
from typing import Optional, List, Dict
|
|
4
|
+
import subprocess
|
|
5
|
+
import shutil
|
|
6
|
+
import sys
|
|
7
|
+
from metaflow import current
|
|
8
|
+
|
|
9
|
+
__mf_promote_submodules__ = ["plugins.torchtune"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TorchTune:
    """
    Launcher that shells out to the ``tune run`` CLI.

    An instance optionally carries a distributed-launch configuration
    (``multi_node_config``) which `run` translates into torchrun-style flags.
    """

    def __init__(
        self,
        use_multi_node_config: bool = False,
        config_overrides: Optional[Dict] = None,
    ):
        """
        Initialize the Tune launcher.

        :param use_multi_node_config: If True, try to derive a distributed configuration
            from ``current.torch.torchrun_args``. When ``current.torch`` is missing or
            falsy, a notice is printed and the configuration stays empty.
        :param config_overrides: Optional mapping merged on top of the derived distributed
            configuration. It is only consulted when that configuration is actually built;
            in every other case it is ignored.
        """
        self.multi_node_config = {}
        if not use_multi_node_config:
            return

        if not getattr(current, "torch", None):
            print(
                "[Metaflow Tune] Since @torchrun is not used, default multi-node config cannot be used to launch the job."
            )
            return

        print(
            "[Metaflow Tune] Since @torchrun is used, multi-node config can be used to launch the job."
        )
        # For distributed torchtune launches the same parameters as torchrun are reused.
        # (The keys may need adjusting for a given environment.)
        torchrun_args = current.torch.torchrun_args
        self.multi_node_config = {
            "nnodes": torchrun_args["nnodes"],
            "master_addr": torchrun_args["master_addr"],
            "master_port": int(torchrun_args["master_port"]),
            "node_rank": torchrun_args["node_rank"],
            "nproc_per_node": torchrun_args["nproc_per_node"],
            "num_processes": torchrun_args["nproc_per_node"]
            * torchrun_args["nnodes"],
        }
        if config_overrides:
            self.multi_node_config.update(config_overrides)
        print(
            f"[Metaflow Tune] Discovered multi-node config for torchrun: {self.multi_node_config}"
        )

    def run(
        self,
        recipe: str,
        config_dict: Dict,
        additional_cli_options: Optional[List[str]] = None,
    ):
        """
        Launch the torchtune job via its CLI and stream its output.

        The command has the shape
        ``tune run [distributed flags] [additional options] <recipe> --config <file>``.
        ``config_dict`` is dumped to a YAML file inside a temporary directory that is
        removed when the call finishes, whether it succeeds or raises.

        :param recipe: The path to the recipe (or name of the recipe) to run.
        :param config_dict: Dictionary that is dumped to a YAML file and passed via --config.
        :param additional_cli_options: Optional list of CLI options inserted before the recipe.
        :raises: subprocess.CalledProcessError if the subprocess returns a nonzero exit code.
        """
        import yaml
        import tempfile
        import os

        scratch_dir = tempfile.mkdtemp()
        try:
            config_path = os.path.join(scratch_dir, "config.yaml")
            with open(config_path, "w") as handle:
                yaml.dump(config_dict, handle)

            # "tune run" is the base command.
            cmd = ["tune", "run"]

            # Distributed configuration, when present, becomes torchrun-style flags.
            dist = self.multi_node_config
            if dist:
                endpoint = f"{dist.get('master_addr')}:{dist.get('master_port')}"
                cmd += [
                    "--nnodes",
                    str(dist.get("nnodes")),
                    "--nproc-per-node",
                    str(dist.get("nproc_per_node")),
                    "--rdzv-backend",
                    "c10d",
                    "--rdzv-endpoint",
                    endpoint,
                    "--rdzv-id",
                    "1234567890",
                    "--node-rank",
                    str(dist.get("node_rank")),
                    # TODO: should there be a masterip/port here ?
                ]

            # Caller-supplied options go before the recipe; the config flag goes after it.
            cmd += additional_cli_options or []
            cmd.append(recipe)
            cmd += ["--config", config_path]

            print(f"[Metaflow tune] {' '.join(cmd)}")
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )

            # Relay the child's stdout/stderr to ours as the lines arrive.
            for stdout_chunk, stderr_chunk in read_popen_pipes(process):
                print(stdout_chunk, end="", flush=True)
                print(stderr_chunk, end="", file=sys.stderr, flush=True)

            process.wait()
            if process.returncode != 0:
                raise subprocess.CalledProcessError(process.returncode, cmd)
        finally:
            shutil.rmtree(scratch_dir)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def enqueue_output(file, queue):
    """
    Copy every line read from `file` into `queue`, then close `file`.

    Reading stops at end-of-file (when `readline` returns ""). The file is
    closed even if reading raises.
    """
    try:
        for line in iter(file.readline, ""):
            queue.put(line)
    finally:
        file.close()


def read_popen_pipes(p):
    """
    Yield `(stdout_line, stderr_line)` tuples as output arrives from process `p`.

    `p` must expose text-mode `stdout` and `stderr` pipes. Each pipe is drained
    by its own worker thread into a queue. Every yielded tuple carries at most
    one line per stream, and "" for a stream that had nothing pending; a tuple
    with both elements empty is never yielded.

    Iteration ends once both pipes have reached end-of-file and every queued
    line has been yielded, so output still buffered when the process exits is
    not dropped. The generator does not wait on the process itself; callers
    should call `p.wait()` afterwards to collect the exit status.
    """
    import time

    with ThreadPoolExecutor(2) as pool:
        q_stdout, q_stderr = Queue(), Queue()

        readers = [
            pool.submit(enqueue_output, p.stdout, q_stdout),
            pool.submit(enqueue_output, p.stderr, q_stderr),
        ]

        while True:
            # Snapshot reader completion BEFORE polling the queues: if both
            # readers were already done, everything they produced is queued,
            # so two empty queues really mean there is nothing left.
            readers_done = all(reader.done() for reader in readers)

            out_line = err_line = ""
            try:
                out_line = q_stdout.get_nowait()
            except Empty:
                pass
            try:
                err_line = q_stderr.get_nowait()
            except Empty:
                pass

            if out_line or err_line:
                yield (out_line, err_line)
            elif readers_done:
                break
            else:
                # Nothing pending yet: back off briefly instead of spinning.
                time.sleep(0.01)
|