ob-metaflow-extensions 1.1.142__py2.py3-none-any.whl → 1.4.33__py2.py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- metaflow_extensions/outerbounds/__init__.py +1 -1
- metaflow_extensions/outerbounds/plugins/__init__.py +26 -5
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +78 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +17 -3
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +1 -6
- metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +171 -16
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1710 -114
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +44 -4
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/remote_config.py +27 -3
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +87 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.4.33.dist-info/RECORD +134 -0
- metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
- ob_metaflow_extensions-1.1.142.dist-info/RECORD +0 -64
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/top_level.txt +0 -0

metaflow_extensions/outerbounds/plugins/apps/core/utils.py
@@ -0,0 +1,233 @@
+import random
+import time
+import sys
+import json
+import requests
+from typing import Optional
+
+# This click import is not used to construct any ob
+# package cli. Its used only for printing stuff.
+# So we can use the static metaflow._vendor import path
+from metaflow._vendor import click
+from .app_config import CAPSULE_DEBUG
+import sys
+import threading
+import time
+import logging
+import itertools
+from typing import Union, Callable, Any, List
+
+from ._vendor.spinner import (
+    Spinners,
+)
+
+
+class MultiStepSpinner:
+    """
+    A spinner that supports multi-step progress and configurable alignment.
+
+    Parameters
+    ----------
+    spinner : Spinners
+        Which spinner frames/interval to use.
+    text : str
+        Static text to display beside the spinner.
+    color : str, optional
+        Click color name.
+    align : {'left','right'}
+        Whether to render the spinner to the left (default) or right of the text.
+    """
+
+    def __init__(
+        self,
+        spinner: Spinners = Spinners.dots,
+        text: str = "",
+        color: Optional[str] = None,
+        align: str = "right",
+        file=sys.stdout,
+    ):
+        cfg = spinner.value
+        self.frames = cfg["frames"]
+        self.interval = float(cfg["interval"]) / 1000.0  # type: ignore
+        self.text = text
+        self.color = color
+        if align not in ("left", "right"):
+            raise ValueError("align must be 'left' or 'right'")
+        self.align = align
+        self._write_file = file
+        # precompute clear length: max frame width + space + text length
+        max_frame = max(self.frames, key=lambda x: len(x))  # type: ignore
+        self.clear_len = len(self.main_text) + len(max_frame) + 1
+
+        self._stop_evt = threading.Event()
+        self._pause_evt = threading.Event()
+        self._thread = None
+        self._write_lock = threading.Lock()
+
+    @property
+    def main_text(self):
+        # if self.text is a callable then call it
+        if callable(self.text):
+            return self.text()
+        return self.text
+
+    def _spin(self):
+        for frame in itertools.cycle(self.frames):
+            if self._stop_evt.is_set():
+                break
+            if self._pause_evt.is_set():
+                time.sleep(0.05)
+                continue
+
+            # ---- Core logging critical section ----
+            with self._write_lock:
+                symbol = click.style(frame, fg=self.color) if self.color else frame
+                if self.align == "left":
+                    msg = f"{symbol} {self.main_text}"
+                else:
+                    msg = f"{self.main_text} {symbol}"
+
+                click.echo(msg, nl=False, file=self._write_file)
+                click.echo("\r", nl=False, file=self._write_file)
+                self._write_file.flush()
+            # ---- End of critical section ----
+            time.sleep(self.interval)
+        # clear the line when done
+        self._clear_line()
+
+    def _clear_line(self):
+        with self._write_lock:
+            click.echo(" " * self.clear_len, nl=False, file=self._write_file)
+            click.echo("\r", nl=False, file=self._write_file)
+            self._write_file.flush()
+
+    def start(self):
+        if self._thread and self._thread.is_alive():
+            return
+        self._stop_evt.clear()
+        self._pause_evt.clear()
+        self._thread = threading.Thread(target=self._spin, daemon=True)
+        self._thread.start()
+
+    def stop(self):
+        self._stop_evt.set()
+        if self._thread:
+            self._thread.join()
+
+    def log(self, *messages: str):
+        """Pause the spinner, emit a ✔ + message, then resume."""
+        self._pause_evt.set()
+        self._clear_line()
+        # ---- Core logging critical section ----
+        with self._write_lock:
+            self._write_file.flush()
+            for message in messages:
+                click.echo(f"{message}", file=self._write_file, nl=True)
+            self._write_file.flush()
+        # ---- End of critical section ----
+        self._pause_evt.clear()
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self.stop()
+
+
+class SpinnerLogHandler(logging.Handler):
+    def __init__(self, spinner: MultiStepSpinner, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.spinner = spinner
+
+    def emit(self, record):
+        msg = self.format(record)
+        self.spinner.log(msg)
+
+
+class MaximumRetriesExceeded(Exception):
+    def __init__(self, url, method, status_code, text):
+        self.url = url
+        self.method = method
+        self.status_code = status_code
+        self.text = text
+
+    def __str__(self):
+        return f"Maximum retries exceeded for {self.url}[{self.method}] {self.status_code} {self.text}"
+
+
+class TODOException(Exception):
+    pass
+
+
+requests_funcs = [
+    requests.get,
+    requests.post,
+    requests.put,
+    requests.delete,
+    requests.patch,
+    requests.head,
+    requests.options,
+]
+
+
+def safe_requests_wrapper(
+    requests_module_fn,
+    *args,
+    conn_error_retries=2,
+    retryable_status_codes=[409],
+    logger_fn=None,
+    **kwargs,
+):
+    """
+    There are two categories of errors that we need to handle when dealing with any API server.
+    1. HTTP errors. These are are errors that are returned from the API server.
+        - How to handle retries for this case will be application specific.
+    2. Errors when the API server may not be reachable (DNS resolution / network issues)
+        - In this scenario, we know that something external to the API server is going wrong causing the issue.
+        - Failing prematurely in the case might not be the best course of action since critical user jobs might crash on intermittent issues.
+        - So in this case, we can just plainly retry the request.
+
+    This function handles the second case. It's a simple wrapper to handle the retry logic for connection errors.
+    If this function is provided a `conn_error_retries` of 5, then the last retry will have waited 32 seconds.
+    Generally this is a safe enough number of retries after which we can assume that something is really broken. Until then,
+    there can be intermittent issues that would resolve themselves if we retry gracefully.
+    """
+    if requests_module_fn not in requests_funcs:
+        raise ValueError(
+            f"safe_requests_wrapper doesn't support {requests_module_fn.__name__}. You can only use the following functions: {requests_funcs}"
+        )
+
+    _num_retries = 0
+    noise = random.uniform(-0.5, 0.5)
+    response = None
+    while _num_retries < conn_error_retries:
+        try:
+            response = requests_module_fn(*args, **kwargs)
+            if response.status_code not in retryable_status_codes:
+                return response
+            if CAPSULE_DEBUG:
+                if logger_fn:
+                    logger_fn(
+                        f"[outerbounds-debug] safe_requests_wrapper: {response.url}[{requests_module_fn.__name__}] {response.status_code} {response.text}",
+                    )
+                else:
+                    print(
+                        f"[outerbounds-debug] safe_requests_wrapper: {response.url}[{requests_module_fn.__name__}] {response.status_code} {response.text}",
+                        file=sys.stderr,
+                    )
+            _num_retries += 1
+            time.sleep((2 ** (_num_retries + 1)) + noise)
+        except requests.exceptions.ConnectionError:
+            if _num_retries <= conn_error_retries - 1:
+                # Exponential backoff with 2^(_num_retries+1) seconds
+                time.sleep((2 ** (_num_retries + 1)) + noise)
+                _num_retries += 1
+            else:
+                raise
+    raise MaximumRetriesExceeded(
+        response.url,
+        requests_module_fn.__name__,
+        response.status_code,
+        response.text,
+    )
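
The hunk above (apps/core/utils.py) ships the spinner and the retry wrapper but no caller. Below is a minimal sketch of how the two pieces might be combined, assuming the import path follows the file list above; the endpoint, logger name, and spinner text are placeholders, none of which come from the diff itself.

```python
# Hypothetical caller -- not part of the diff. The import path mirrors the file
# list above; the URL and spinner text are placeholders.
import logging
import requests

from metaflow_extensions.outerbounds.plugins.apps.core.utils import (
    MultiStepSpinner,
    SpinnerLogHandler,
    safe_requests_wrapper,
)

logger = logging.getLogger("capsule-deploy")

with MultiStepSpinner(text="Waiting for capsule", color="green") as spinner:
    # Route log records through the spinner so they do not clobber the frame.
    logger.addHandler(SpinnerLogHandler(spinner))

    # Connection errors are retried with exponential backoff: sleeps of roughly
    # 2, 4, 8, ... seconds plus +/-0.5 s of jitter, so conn_error_retries=5
    # waits about 32 seconds before giving up, as the docstring notes.
    response = safe_requests_wrapper(
        requests.get,
        "https://example.invalid/status",  # placeholder endpoint
        conn_error_retries=5,
        retryable_status_codes=[409],
        logger_fn=spinner.log,
    )
    spinner.log(f"✔ capsule status: {response.status_code}")
```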
metaflow_extensions/outerbounds/plugins/apps/core/validations.py
@@ -0,0 +1,17 @@
+import os
+from typing import List
+from .app_config import AppConfig, AppConfigError
+from .secrets import SecretRetriever, SecretNotFound
+
+
+def secrets_validator(secrets: List[str]):
+    secret_retriever = SecretRetriever()
+    for secret in secrets:
+        try:
+            secret_retriever.get_secret_as_dict(secret)
+        except SecretNotFound:
+            raise Exception(f"Secret named `{secret}` not found")
+
+
+def run_validations(app_config: AppConfig):
+    pass
metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py
@@ -0,0 +1,78 @@
+from metaflow.user_decorators.mutable_flow import MutableFlow
+from metaflow.user_decorators.mutable_step import MutableStep
+from metaflow.user_decorators.user_flow_decorator import FlowMutator
+from .assume_role import OBP_ASSUME_ROLE_ARN_ENV_VAR
+
+
+class assume_role(FlowMutator):
+    """
+    Flow-level decorator for assuming AWS IAM roles.
+
+    When applied to a flow, all steps in the flow will automatically use the specified IAM role-arn
+    as their source principal.
+
+    Usage:
+    ------
+    @assume_role(role_arn="arn:aws:iam::123456789012:role/my-iam-role")
+    class MyFlow(FlowSpec):
+        @step
+        def start(self):
+            import boto3
+            client = boto3.client("dynamodb") # Automatically uses the role in the flow decorator
+            self.next(self.end)
+
+        @step
+        def end(self):
+            from metaflow import get_aws_client
+            client = get_aws_client("dynamodb") # Automatically uses the role in the flow decorator
+    """
+
+    def init(self, *args, **kwargs):
+        self.role_arn = kwargs.get("role_arn", None)
+
+        if self.role_arn is None:
+            raise ValueError(
+                "`role_arn` keyword argument is required for the assume_role decorator"
+            )
+
+        if not self.role_arn.startswith("arn:aws:iam::"):
+            raise ValueError(
+                "`role_arn` must be a valid AWS IAM role ARN starting with 'arn:aws:iam::'"
+            )
+
+    def pre_mutate(self, mutable_flow: MutableFlow) -> None:
+        """
+        This method is called by Metaflow to apply the decorator to the flow.
+        It sets up environment variables that will be used by the AWS client
+        to automatically assume the specified role.
+        """
+        # Import environment decorator at runtime to avoid circular imports
+        from metaflow import environment
+
+        def _swap_environment_variables(step: MutableStep, role_arn: str) -> None:
+            _step_has_env_set = True
+            _env_kwargs = {OBP_ASSUME_ROLE_ARN_ENV_VAR: role_arn}
+            for d in step.decorator_specs:
+                name, _, _, deco_kwargs = d
+                if name == "environment":
+                    _env_kwargs.update(deco_kwargs["vars"])
+                    _step_has_env_set = True
+
+            if _step_has_env_set:
+                # remove the environment decorator
+                step.remove_decorator("environment")
+
+            # add the environment decorator
+            step.add_decorator(
+                environment,
+                deco_kwargs=dict(vars=_env_kwargs),
+            )
+
+        # Set the role ARN as an environment variable that will be picked up
+        # by the get_aws_client function
+        def _setup_role_assumption(step: MutableStep) -> None:
+            _swap_environment_variables(step, self.role_arn)
+
+        # Apply the role assumption setup to all steps in the flow
+        for _, step in mutable_flow.steps:
+            _setup_role_assumption(step)
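
The docstring above shows the decorator in isolation; the `pre_mutate` hook additionally merges the injected role ARN with any `@environment` variables a step already declares. A hedged sketch of that interaction, with a made-up flow and variable name (the import path below is an assumption based on the file list; the extension may also re-export `assume_role` elsewhere):

```python
# Illustrative only -- shows what the pre_mutate hook above produces.
from metaflow import FlowSpec, environment, step

# assumed import path based on the file list above
from metaflow_extensions.outerbounds.plugins.aws.assume_role_decorator import assume_role


@assume_role(role_arn="arn:aws:iam::123456789012:role/my-iam-role")
class EnvMergeFlow(FlowSpec):
    # pre_mutate removes this @environment decorator and re-adds it with
    # vars = {OBP_ASSUME_ROLE_ARN_ENV_VAR: <role_arn>, "MY_FLAG": "1"},
    # so user-supplied variables survive alongside the injected role ARN.
    @environment(vars={"MY_FLAG": "1"})
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    EnvMergeFlow()
```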
metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py
@@ -0,0 +1,71 @@
+from metaflow.user_decorators.user_flow_decorator import FlowMutator
+from metaflow.user_decorators.mutable_flow import MutableFlow
+from metaflow.user_decorators.mutable_step import MutableStep
+from .external_chckpt import _ExternalCheckpointFlowDeco
+import os
+
+
+class coreweave_checkpoints(_ExternalCheckpointFlowDeco):
+
+    """
+
+    This decorator is used for setting the coreweave object store as the artifact store for checkpoints/models created by the flow.
+
+    Parameters
+    ----------
+    secrets: list
+        A list of secrets to be added to the step. These secrets should contain any secrets that are required globally and the secret
+        for the coreweave object store. The secret should contain the following keys:
+            - COREWEAVE_ACCESS_KEY
+            - COREWEAVE_SECRET_KEY
+
+    bucket_path: str
+        The path to the bucket to store the checkpoints/models.
+
+    Usage
+    -----
+    ```python
+    from metaflow import checkpoint, step, FlowSpec, coreweave_checkpoints
+
+    @coreweave_checkpoints(secrets=[], bucket_path=None)
+    class MyFlow(FlowSpec):
+        @checkpoint
+        @step
+        def start(self):
+            # Saves the checkpoint in the coreweave object store
+            current.checkpoint.save("./foo.txt")
+
+        @step
+        def end(self):
+            pass
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def init(self, *args, **kwargs):
+        super().init(*args, **kwargs)
+        self.coreweave_endpoint_url = f"https://cwobject.com"
+
+    def pre_mutate(self, mutable_flow: MutableFlow) -> None:
+        from metaflow import (
+            with_artifact_store,
+        )
+
+        def _coreweave_config():
+            return {
+                "root": self.bucket_path,
+                "client_params": {
+                    "aws_access_key_id": os.environ.get("COREWEAVE_ACCESS_KEY"),
+                    "aws_secret_access_key": os.environ.get("COREWEAVE_SECRET_KEY"),
+                    "endpoint_url": self.coreweave_endpoint_url,
+                    "config": dict(s3={"addressing_style": "virtual"}),
+                },
+            }
+
+        mutable_flow.add_decorator(
+            with_artifact_store,
+            deco_kwargs=dict(type="coreweave", config=_coreweave_config),
+        )
+        self._swap_secrets(mutable_flow)
metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py
@@ -0,0 +1,85 @@
+from metaflow.user_decorators.user_flow_decorator import FlowMutator
+from metaflow.user_decorators.mutable_flow import MutableFlow
+from metaflow.user_decorators.mutable_step import MutableStep
+import os
+
+
+class _ExternalCheckpointFlowDeco(FlowMutator):
+    def init(self, *args, **kwargs):
+        self.bucket_path = kwargs.get("bucket_path", None)
+
+        self.secrets = kwargs.get("secrets", [])
+        if self.bucket_path is None:
+            raise ValueError(
+                "`bucket_path` keyword argument is required for the coreweave_datastore"
+            )
+        if not self.bucket_path.startswith("s3://"):
+            raise ValueError(
+                "`bucket_path` must start with `s3://` for the coreweave_datastore"
+            )
+        if self.secrets is None:
+            raise ValueError(
+                "`secrets` keyword argument is required for the coreweave_datastore"
+            )
+
+    def _swap_secrets(self, mutable_flow: MutableFlow) -> None:
+        from metaflow import (
+            checkpoint,
+            model,
+            huggingface_hub,
+            secrets,
+            with_artifact_store,
+        )
+
+        def _add_secrets(step: MutableStep) -> None:
+            decos_to_add = []
+            swapping_decos = {
+                "huggingface_hub": huggingface_hub,
+                "model": model,
+                "checkpoint": checkpoint,
+            }
+            already_has_secrets = False
+            secrets_present_in_deco = []
+            for d in step.decorator_specs:
+                name, _, _, deco_kwargs = d
+                if name in swapping_decos:
+                    decos_to_add.append((name, deco_kwargs))
+                elif name == "secrets":
+                    already_has_secrets = True
+                    secrets_present_in_deco.extend(deco_kwargs["sources"])
+
+            # If the step aleady has secrets then take all the sources in
+            # the secrets and add the addtional secrets to the existing secrets
+            secrets_to_add = self.secrets
+            if already_has_secrets:
+                secrets_to_add.extend(secrets_present_in_deco)
+
+            secrets_to_add = list(set(secrets_to_add))
+
+            if len(decos_to_add) == 0:
+                if already_has_secrets:
+                    step.remove_decorator("secrets")
+
+                step.add_decorator(
+                    secrets,
+                    deco_kwargs=dict(
+                        sources=secrets_to_add,
+                    ),
+                )
+                return
+
+            for d, _ in decos_to_add:
+                step.remove_decorator(d)
+
+            step.add_decorator(
+                secrets,
+                deco_kwargs=dict(
+                    sources=secrets_to_add,
+                ),
+            )
+            for d, attrs in decos_to_add:
+                _deco_to_add = swapping_decos[d]
+                step.add_decorator(_deco_to_add, deco_kwargs=attrs)
+
+        for step_name, step in mutable_flow.steps:
+            _add_secrets(step)
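
To make the `_swap_secrets` rewiring above concrete: for a step that already carries `@secrets` and one of the checkpoint-family decorators, the mutator attaches a `@secrets` decorator whose sources combine the step's own entries with the ones passed to the flow decorator, and re-adds the checkpoint-family decorators afterwards. A hedged sketch using the `coreweave_checkpoints` subclass from the previous hunk, with made-up secret and bucket names:

```python
# Illustrative only -- decorator arrangement before and after the mutator runs.
from metaflow import FlowSpec, checkpoint, coreweave_checkpoints, current, secrets, step


@coreweave_checkpoints(
    secrets=["coreweave-keys"],               # made-up secret holding COREWEAVE_ACCESS_KEY/SECRET_KEY
    bucket_path="s3://my-bucket/checkpoints"  # made-up bucket
)
class CheckpointedFlow(FlowSpec):
    # Before mutation: @checkpoint, @secrets(sources=["team-secrets"])
    # After mutation:  a @secrets decorator whose sources are the set-deduplicated
    #   union ["team-secrets", "coreweave-keys"], with @checkpoint re-added by
    #   the mutator so checkpointing runs with the datastore credentials in place.
    @checkpoint
    @secrets(sources=["team-secrets"])
    @step
    def start(self):
        current.checkpoint.save("./model.bin")
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    CheckpointedFlow()
```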
metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py
@@ -0,0 +1,73 @@
+from metaflow.user_decorators.mutable_flow import MutableFlow
+from .external_chckpt import _ExternalCheckpointFlowDeco
+import os
+
+NEBIUS_ENDPOINT_URL = "https://storage.eu-north1.nebius.cloud:443"
+
+
+class nebius_checkpoints(_ExternalCheckpointFlowDeco):
+
+    """
+
+    This decorator is used for setting the nebius's S3 compatible object store as the artifact store for
+    checkpoints/models created by the flow.
+
+    Parameters
+    ----------
+    secrets: list
+        A list of secrets to be added to the step. These secrets should contain any secrets that are required globally and the secret
+        for the nebius object store. The secret should contain the following keys:
+            - NEBIUS_ACCESS_KEY
+            - NEBIUS_SECRET_KEY
+
+    bucket_path: str
+        The path to the bucket to store the checkpoints/models.
+
+    endpoint_url: str
+        The endpoint url for the nebius object store. Defaults to `https://storage.eu-north1.nebius.cloud:443`
+
+    Usage
+    -----
+    ```python
+    from metaflow import checkpoint, step, FlowSpec, nebius_checkpoints
+
+    @nebius_checkpoints(secrets=[], bucket_path=None)
+    class MyFlow(FlowSpec):
+        @checkpoint
+        @step
+        def start(self):
+            # Saves the checkpoint in the nebius object store
+            current.checkpoint.save("./foo.txt")
+
+        @step
+        def end(self):
+            pass
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def init(self, *args, **kwargs):
+        super().init(*args, **kwargs)
+        self.nebius_endpoint_url = kwargs.get("endpoint_url", NEBIUS_ENDPOINT_URL)
+
+    def pre_mutate(self, mutable_flow: MutableFlow) -> None:
+        from metaflow import (
+            with_artifact_store,
+        )
+
+        def _nebius_config():
+            return {
+                "root": self.bucket_path,
+                "client_params": {
+                    "aws_access_key_id": os.environ.get("NEBIUS_ACCESS_KEY"),
+                    "aws_secret_access_key": os.environ.get("NEBIUS_SECRET_KEY"),
+                    "endpoint_url": self.nebius_endpoint_url,
+                },
+            }
+
+        mutable_flow.add_decorator(
+            with_artifact_store, deco_kwargs=dict(type="s3", config=_nebius_config)
+        )
+        self._swap_secrets(mutable_flow)
metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py
@@ -0,0 +1,110 @@
+import threading
+import time
+import sys
+from typing import Dict, Optional, Any, Callable
+from functools import partial
+from metaflow.exception import MetaflowException
+from metaflow.metaflow_config import FAST_BAKERY_URL
+
+from .fast_bakery import FastBakery, FastBakeryApiResponse, FastBakeryException
+from .docker_environment import cache_request
+
+BAKERY_METAFILE = ".imagebakery-cache"
+
+
+class BakerException(MetaflowException):
+    headline = "Ran into an error while baking image"
+
+    def __init__(self, msg):
+        super(BakerException, self).__init__(msg)
+
+
+def bake_image(
+    cache_file_path: str,
+    ref: Optional[str] = None,
+    python: Optional[str] = None,
+    pypi_packages: Optional[Dict[str, str]] = None,
+    conda_packages: Optional[Dict[str, str]] = None,
+    base_image: Optional[str] = None,
+    logger: Optional[Callable[[str], Any]] = None,
+) -> FastBakeryApiResponse:
+    """
+    Bakes a Docker image with the specified dependencies.
+
+    Args:
+        cache_file_path: Path to the cache file
+        ref: Reference identifier for this bake (for logging purposes)
+        python: Python version to use
+        pypi_packages: Dictionary of PyPI packages and versions
+        conda_packages: Dictionary of Conda packages and versions
+        base_image: Base Docker image to use
+        logger: Optional logger function to output progress
+
+    Returns:
+        FastBakeryApiResponse: The response from the bakery service
+
+    Raises:
+        BakerException: If the baking process fails
+    """
+    # Default logger if none provided
+    if logger is None:
+        logger = partial(print, file=sys.stderr)
+
+    # Thread lock for logging
+    logger_lock = threading.Lock()
+    images_baked = 0
+
+    @cache_request(cache_file_path)
+    def _cached_bake(
+        ref=None,
+        python=None,
+        pypi_packages=None,
+        conda_packages=None,
+        base_image=None,
+    ):
+        try:
+            bakery = FastBakery(url=FAST_BAKERY_URL)
+            bakery._reset_payload()
+            bakery.python_version(python)
+            bakery.pypi_packages(pypi_packages)
+            bakery.conda_packages(conda_packages)
+            bakery.base_image(base_image)
+            # bakery.ignore_cache()
+
+            with logger_lock:
+                logger(f"🍳 Baking [{ref}] ...")
+                logger(f"    🐍 Python: {python}")
+
+                if pypi_packages:
+                    logger(f"    📦 PyPI packages:")
+                    for package, version in pypi_packages.items():
+                        logger(f"        🔧 {package}: {version}")
+
+                if conda_packages:
+                    logger(f"    📦 Conda packages:")
+                    for package, version in conda_packages.items():
+                        logger(f"        🔧 {package}: {version}")
+
+                logger(f"    🏗️ Base image: {base_image}")
+
+            start_time = time.time()
+            res = bakery.bake()
+            # TODO: Get actual bake time from bakery
+            bake_time = time.time() - start_time
+
+            with logger_lock:
+                logger(f"🏁 Baked [{ref}] in {bake_time:.2f} seconds!")
+                nonlocal images_baked
+                images_baked += 1
+            return res
+        except FastBakeryException as ex:
+            raise BakerException(f"Bake [{ref}] failed: {str(ex)}")
+
+    # Call the cached bake function with the provided parameters
+    return _cached_bake(
+        ref=ref,
+        python=python,
+        pypi_packages=pypi_packages,
+        conda_packages=conda_packages,
+        base_image=base_image,
+    )
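
Nothing in this last hunk calls `bake_image`; a minimal sketch of a call site, assuming the module path from the file list and using placeholder package pins and base image:

```python
# Hypothetical call site -- not part of the diff. Package pins, base image, and
# the ref label are placeholders; BAKERY_METAFILE is reused from the module above.
from metaflow_extensions.outerbounds.plugins.fast_bakery.baker import (
    BAKERY_METAFILE,
    BakerException,
    bake_image,
)

try:
    response = bake_image(
        cache_file_path=BAKERY_METAFILE,
        ref="step:train",                  # only used to label log output
        python="3.11",
        pypi_packages={"numpy": "1.26.4", "pandas": "2.2.2"},
        conda_packages=None,
        base_image="python:3.11-slim",
        logger=print,                      # defaults to printing on stderr if omitted
    )
    # `response` is the FastBakeryApiResponse returned by the bakery service.
    print(response)
except BakerException as ex:
    print(f"bake failed: {ex}")
```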