ob-metaflow-extensions 1.1.170__py2.py3-none-any.whl → 1.4.35__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (65)
  1. metaflow_extensions/outerbounds/plugins/__init__.py +6 -2
  2. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  3. metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
  4. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
  5. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  6. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  7. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  32. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +25 -12
  33. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +9 -77
  34. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  35. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +7 -78
  36. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  37. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +6 -2
  38. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
  39. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +8 -8
  40. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  41. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  42. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  43. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  44. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  45. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  46. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  47. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  48. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  49. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  50. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
  51. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
  52. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
  53. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +4 -0
  54. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +173 -95
  55. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +9 -9
  56. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +159 -9
  57. metaflow_extensions/outerbounds/remote_config.py +8 -3
  58. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +63 -1
  59. metaflow_extensions/outerbounds/toplevel/ob_internal.py +3 -0
  60. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  61. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  62. {ob_metaflow_extensions-1.1.170.dist-info → ob_metaflow_extensions-1.4.35.dist-info}/METADATA +2 -2
  63. {ob_metaflow_extensions-1.1.170.dist-info → ob_metaflow_extensions-1.4.35.dist-info}/RECORD +65 -21
  64. {ob_metaflow_extensions-1.1.170.dist-info → ob_metaflow_extensions-1.4.35.dist-info}/WHEEL +0 -0
  65. {ob_metaflow_extensions-1.1.170.dist-info → ob_metaflow_extensions-1.4.35.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,233 @@
1
+ import random
2
+ import time
3
+ import sys
4
+ import json
5
+ import requests
6
+ from typing import Optional
7
+
8
+ # This click import is not used to construct any ob
9
+ # package cli. Its used only for printing stuff.
10
+ # So we can use the static metaflow._vendor import path
11
+ from metaflow._vendor import click
12
+ from .app_config import CAPSULE_DEBUG
13
+ import sys
14
+ import threading
15
+ import time
16
+ import logging
17
+ import itertools
18
+ from typing import Union, Callable, Any, List
19
+
20
+ from ._vendor.spinner import (
21
+ Spinners,
22
+ )
23
+
24
+
25
class MultiStepSpinner:
    """
    A spinner that supports multi-step progress and configurable alignment.

    The animation runs on a background daemon thread. Messages passed to
    :meth:`log` are printed on their own lines while the animation is paused;
    the animation then resumes on the following line.

    Parameters
    ----------
    spinner : Spinners
        Which spinner frames/interval to use. ``spinner.value`` must provide
        ``"frames"`` and ``"interval"`` (in milliseconds).
    text : str or callable returning str
        Text to display beside the spinner. A callable is re-evaluated on
        every frame, so the text may change while spinning.
    color : str, optional
        Click color name used to style the spinner frame.
    align : {'left','right'}
        Whether to render the spinner to the left or right (default) of the text.
    file : file-like, optional
        Stream to write to (default: ``sys.stdout``).
    """

    def __init__(
        self,
        spinner: Spinners = Spinners.dots,
        text: Union[str, Callable[[], str]] = "",
        color: Optional[str] = None,
        align: str = "right",
        file=sys.stdout,
    ):
        cfg = spinner.value
        self.frames = cfg["frames"]
        # The interval is configured in milliseconds; Event.wait takes seconds.
        self.interval = float(cfg["interval"]) / 1000.0  # type: ignore
        self.text = text
        self.color = color
        if align not in ("left", "right"):
            raise ValueError("align must be 'left' or 'right'")
        self.align = align
        self._write_file = file
        # Width of the widest (unstyled) frame.
        self._max_frame_len = max(len(f) for f in self.frames)  # type: ignore
        # Columns to blank when clearing: frame + space + text. This starts from
        # the initial text and grows in _spin if a callable text renders longer.
        self.clear_len = len(self.main_text) + self._max_frame_len + 1

        self._stop_evt = threading.Event()
        self._pause_evt = threading.Event()
        self._thread = None
        self._write_lock = threading.Lock()

    @property
    def main_text(self):
        """Return the current text, calling ``self.text`` if it is callable."""
        if callable(self.text):
            return self.text()
        return self.text

    def _spin(self):
        """Draw frames until stop() is called; runs on the background thread."""
        for frame in itertools.cycle(self.frames):
            if self._stop_evt.is_set():
                break
            if self._pause_evt.is_set():
                time.sleep(0.05)
                continue

            # ---- Core logging critical section ----
            with self._write_lock:
                symbol = click.style(frame, fg=self.color) if self.color else frame
                if self.align == "left":
                    msg = f"{symbol} {self.main_text}"
                else:
                    msg = f"{self.main_text} {symbol}"
                # Remember the widest line drawn (ignoring ANSI styling codes)
                # so that clearing always blanks everything written so far.
                visible_len = len(msg) - len(symbol) + len(frame)
                self.clear_len = max(self.clear_len, visible_len)

                click.echo(msg, nl=False, file=self._write_file)
                click.echo("\r", nl=False, file=self._write_file)
                self._write_file.flush()
            # ---- End of critical section ----
            # Wait one frame interval, but wake immediately if stop() is called.
            self._stop_evt.wait(self.interval)
        # clear the line when done
        self._clear_line()

    def _clear_line_locked(self):
        """Blank the current line. The caller must hold ``self._write_lock``."""
        click.echo(" " * self.clear_len, nl=False, file=self._write_file)
        click.echo("\r", nl=False, file=self._write_file)
        self._write_file.flush()

    def _clear_line(self):
        """Blank the current line while holding the write lock."""
        with self._write_lock:
            self._clear_line_locked()

    def start(self):
        """Start the background animation thread (no-op if already running)."""
        if self._thread and self._thread.is_alive():
            return
        self._stop_evt.clear()
        self._pause_evt.clear()
        self._thread = threading.Thread(target=self._spin, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the animation thread to stop and wait for it to finish."""
        self._stop_evt.set()
        if self._thread:
            self._thread.join()

    def log(self, *messages: str):
        """Pause the spinner, print each message on its own line, then resume."""
        self._pause_evt.set()
        try:
            # Clear and print under ONE lock acquisition, so the spinner thread
            # cannot draw a frame in between and leave residue behind a
            # shorter message.
            with self._write_lock:
                self._clear_line_locked()
                for message in messages:
                    click.echo(f"{message}", file=self._write_file, nl=True)
                self._write_file.flush()
        finally:
            # Always resume; otherwise a failed write would freeze the spinner.
            self._pause_evt.clear()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
136
+
137
+
138
class SpinnerLogHandler(logging.Handler):
    """Logging handler that prints formatted records through a spinner.

    Each record is formatted with this handler's formatter and passed to
    ``spinner.log``. That way log output does not collide with the spinner
    animation.
    """

    def __init__(self, spinner: "MultiStepSpinner", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.spinner = spinner

    def emit(self, record):
        # Follow the logging.Handler contract: errors raised while formatting
        # or writing are reported via handleError() instead of propagating
        # into the code that issued the log call.
        try:
            msg = self.format(record)
            self.spinner.log(msg)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
146
+
147
+
148
class MaximumRetriesExceeded(Exception):
    """Raised when every attempt of a request returned a retryable status code.

    Attributes
    ----------
    url : str
        URL of the last response.
    method : str
        Name of the ``requests`` function used (e.g. ``"post"``).
    status_code : int
        HTTP status code of the last response.
    text : str
        Body of the last response.
    """

    def __init__(self, url, method, status_code, text):
        # Forward the fields to Exception so ``args`` is populated. This keeps
        # the exception picklable (unpickling calls ``cls(*args)``) and gives
        # it a meaningful repr.
        super().__init__(url, method, status_code, text)
        self.url = url
        self.method = method
        self.status_code = status_code
        self.text = text

    def __str__(self):
        return f"Maximum retries exceeded for {self.url}[{self.method}] {self.status_code} {self.text}"
157
+
158
+
159
class TODOException(Exception):
    """Placeholder exception type; presumably marks functionality that is still TODO."""
161
+
162
+
163
# The top-level ``requests`` helpers accepted by ``safe_requests_wrapper``.
requests_funcs = [
    getattr(requests, _http_verb)
    for _http_verb in ("get", "post", "put", "delete", "patch", "head", "options")
]
172
+
173
+
174
def safe_requests_wrapper(
    requests_module_fn,
    *args,
    conn_error_retries=2,
    retryable_status_codes=(409,),
    logger_fn=None,
    **kwargs,
):
    """
    There are two categories of errors that we need to handle when dealing with any API server.
    1. HTTP errors. These are errors that are returned from the API server.
        - How to handle retries for this case will be application specific.
    2. Errors when the API server may not be reachable (DNS resolution / network issues)
        - In this scenario, we know that something external to the API server is going wrong causing the issue.
        - Failing prematurely in this case might not be the best course of action since critical user jobs might crash on intermittent issues.
        - So in this case, we can just plainly retry the request.

    This function handles the second case, plus retries on the HTTP status
    codes listed in ``retryable_status_codes``.

    Parameters
    ----------
    requests_module_fn : callable
        One of the functions in ``requests_funcs`` (e.g. ``requests.get``).
    *args, **kwargs
        Forwarded to ``requests_module_fn``.
    conn_error_retries : int
        Total number of attempts. At least one attempt is always made.
        Between attempts that fail with a connection error the wrapper waits
        2, 4, 8, ... seconds (with +/-0.5s jitter). With a value of 5, the
        last attempt is made after a 16 second wait.
    retryable_status_codes : container of int
        Status codes that trigger another attempt (back-off 4, 8, ... seconds).
    logger_fn : callable, optional
        Receives debug messages when ``CAPSULE_DEBUG`` is set; defaults to stderr.

    Returns
    -------
    The first response whose status code is not retryable.

    Raises
    ------
    ValueError
        If ``requests_module_fn`` is not one of ``requests_funcs``.
    requests.exceptions.ConnectionError
        If the final attempt fails with a connection error.
    MaximumRetriesExceeded
        If the final attempt still returns a retryable status code.
    """
    if requests_module_fn not in requests_funcs:
        fn_name = getattr(requests_module_fn, "__name__", repr(requests_module_fn))
        raise ValueError(
            f"safe_requests_wrapper doesn't support {fn_name}. You can only use the following functions: {requests_funcs}"
        )

    noise = random.uniform(-0.5, 0.5)
    attempt = 0
    while True:
        # Another attempt follows only while the attempt count stays below
        # conn_error_retries; the first attempt is always made.
        is_last_attempt = attempt + 1 >= conn_error_retries
        try:
            response = requests_module_fn(*args, **kwargs)
        except requests.exceptions.ConnectionError:
            if is_last_attempt:
                # Surface the real network error instead of a later crash.
                raise
            # Exponential backoff with 2^(attempt+1) seconds
            time.sleep((2 ** (attempt + 1)) + noise)
            attempt += 1
            continue

        if response.status_code not in retryable_status_codes:
            return response
        if CAPSULE_DEBUG:
            if logger_fn:
                logger_fn(
                    f"[outerbounds-debug] safe_requests_wrapper: {response.url}[{requests_module_fn.__name__}] {response.status_code} {response.text}",
                )
            else:
                print(
                    f"[outerbounds-debug] safe_requests_wrapper: {response.url}[{requests_module_fn.__name__}] {response.status_code} {response.text}",
                    file=sys.stderr,
                )
        if is_last_attempt:
            # No point sleeping when no further attempt will be made.
            raise MaximumRetriesExceeded(
                response.url,
                requests_module_fn.__name__,
                response.status_code,
                response.text,
            )
        # Retryable status codes back off one step further: 2^(attempt+2) seconds.
        time.sleep((2 ** (attempt + 2)) + noise)
        attempt += 1
@@ -0,0 +1,17 @@
1
+ import os
2
+ from typing import List
3
+ from .app_config import AppConfig, AppConfigError
4
+ from .secrets import SecretRetriever, SecretNotFound
5
+
6
+
7
def secrets_validator(secrets: List[str]):
    """Check that every named secret can be retrieved.

    Parameters
    ----------
    secrets : List[str]
        Names of secrets, each looked up via ``SecretRetriever.get_secret_as_dict``.

    Raises
    ------
    Exception
        If a lookup raises ``SecretNotFound``. The message names the first
        missing secret, and the original error is chained as the cause.
    """
    secret_retriever = SecretRetriever()
    for secret in secrets:
        try:
            secret_retriever.get_secret_as_dict(secret)
        except SecretNotFound as err:
            raise Exception(f"Secret named `{secret}` not found") from err
14
+
15
+
16
def run_validations(app_config: AppConfig):
    """Run validations for ``app_config``.

    Currently a no-op placeholder: no checks are performed and ``None`` is
    returned.
    """
    return None
@@ -1,12 +1,10 @@
1
- from metaflow.user_configs.config_decorators import (
2
- MutableFlow,
3
- MutableStep,
4
- CustomFlowDecorator,
5
- )
1
+ from metaflow.user_decorators.mutable_flow import MutableFlow
2
+ from metaflow.user_decorators.mutable_step import MutableStep
3
+ from metaflow.user_decorators.user_flow_decorator import FlowMutator
6
4
  from .assume_role import OBP_ASSUME_ROLE_ARN_ENV_VAR
7
5
 
8
6
 
9
- class assume_role(CustomFlowDecorator):
7
+ class assume_role(FlowMutator):
10
8
  """
11
9
  Flow-level decorator for assuming AWS IAM roles.
12
10
 
@@ -42,7 +40,7 @@ class assume_role(CustomFlowDecorator):
42
40
  "`role_arn` must be a valid AWS IAM role ARN starting with 'arn:aws:iam::'"
43
41
  )
44
42
 
45
- def evaluate(self, mutable_flow: MutableFlow) -> None:
43
+ def pre_mutate(self, mutable_flow: MutableFlow) -> None:
46
44
  """
47
45
  This method is called by Metaflow to apply the decorator to the flow.
48
46
  It sets up environment variables that will be used by the AWS client
@@ -51,14 +49,29 @@ class assume_role(CustomFlowDecorator):
51
49
  # Import environment decorator at runtime to avoid circular imports
52
50
  from metaflow import environment
53
51
 
52
def _swap_environment_variables(step: MutableStep, role_arn: str) -> None:
    """Merge the assume-role environment variable into the step's @environment.

    Collects the ``vars`` of any existing ``@environment`` decorator on
    ``step``, removes that decorator, and adds a single ``@environment`` that
    carries the role ARN variable plus the pre-existing variables. Existing
    variables are merged after the role ARN, so a value the user already set
    for the same key is kept.
    """
    _step_has_env_set = False
    _env_kwargs = {OBP_ASSUME_ROLE_ARN_ENV_VAR: role_arn}
    for name, _, _, deco_kwargs in step.decorator_specs:
        if name == "environment":
            # Treat a missing or None ``vars`` as "no extra variables".
            _env_kwargs.update(deco_kwargs.get("vars") or {})
            _step_has_env_set = True

    if _step_has_env_set:
        # Only remove an @environment decorator that actually exists.
        step.remove_decorator("environment")

    # add the environment decorator
    step.add_decorator(
        environment,
        deco_kwargs=dict(vars=_env_kwargs),
    )
70
+
54
71
  # Set the role ARN as an environment variable that will be picked up
55
72
  # by the get_aws_client function
56
73
  def _setup_role_assumption(step: MutableStep) -> None:
57
- # We'll inject the role assumption by adding an environment decorator
58
- # The role will be available through an environment variable
59
- step.add_decorator(
60
- environment, vars={OBP_ASSUME_ROLE_ARN_ENV_VAR: self.role_arn}
61
- )
74
+ _swap_environment_variables(step, self.role_arn)
62
75
 
63
76
  # Apply the role assumption setup to all steps in the flow
64
77
  for _, step in mutable_flow.steps:
@@ -1,12 +1,11 @@
1
- from metaflow.user_configs.config_decorators import (
2
- MutableFlow,
3
- MutableStep,
4
- CustomFlowDecorator,
5
- )
1
+ from metaflow.user_decorators.user_flow_decorator import FlowMutator
2
+ from metaflow.user_decorators.mutable_flow import MutableFlow
3
+ from metaflow.user_decorators.mutable_step import MutableStep
4
+ from .external_chckpt import _ExternalCheckpointFlowDeco
6
5
  import os
7
6
 
8
7
 
9
- class coreweave_checkpoints(CustomFlowDecorator):
8
+ class coreweave_checkpoints(_ExternalCheckpointFlowDeco):
10
9
 
11
10
  """
12
11
 
@@ -46,78 +45,14 @@ class coreweave_checkpoints(CustomFlowDecorator):
46
45
  super().__init__(*args, **kwargs)
47
46
 
48
47
  def init(self, *args, **kwargs):
49
- self.bucket_path = kwargs.get("bucket_path", None)
50
-
51
- self.secrets = kwargs.get("secrets", [])
52
- if self.bucket_path is None:
53
- raise ValueError(
54
- "`bucket_path` keyword argument is required for the coreweave_datastore"
55
- )
56
- if not self.bucket_path.startswith("s3://"):
57
- raise ValueError(
58
- "`bucket_path` must start with `s3://` for the coreweave_datastore"
59
- )
60
-
48
+ super().init(*args, **kwargs)
61
49
  self.coreweave_endpoint_url = f"https://cwobject.com"
62
- if self.secrets is None:
63
- raise ValueError(
64
- "`secrets` keyword argument is required for the coreweave_datastore"
65
- )
66
50
 
67
- def evaluate(self, mutable_flow: MutableFlow) -> None:
51
+ def pre_mutate(self, mutable_flow: MutableFlow) -> None:
68
52
  from metaflow import (
69
- checkpoint,
70
- model,
71
- huggingface_hub,
72
- secrets,
73
53
  with_artifact_store,
74
54
  )
75
55
 
76
- def _add_secrets(step: MutableStep) -> None:
77
- decos_to_add = []
78
- swapping_decos = {
79
- "huggingface_hub": huggingface_hub,
80
- "model": model,
81
- "checkpoint": checkpoint,
82
- }
83
- already_has_secrets = False
84
- secrets_present_in_deco = []
85
- for d in step.decorators:
86
- if d.name in swapping_decos:
87
- decos_to_add.append((d.name, d.attributes))
88
- elif d.name == "secrets":
89
- already_has_secrets = True
90
- secrets_present_in_deco.extend(d.attributes["sources"])
91
-
92
- # If the step aleady has secrets then take all the sources in
93
- # the secrets and add the addtional secrets to the existing secrets
94
- secrets_to_add = self.secrets
95
- if already_has_secrets:
96
- secrets_to_add.extend(secrets_present_in_deco)
97
-
98
- secrets_to_add = list(set(secrets_to_add))
99
-
100
- if len(decos_to_add) == 0:
101
- if already_has_secrets:
102
- step.remove_decorator("secrets")
103
-
104
- step.add_decorator(
105
- secrets,
106
- sources=secrets_to_add,
107
- )
108
- return
109
-
110
- for d, _ in decos_to_add:
111
- step.remove_decorator(d)
112
-
113
- step.add_decorator(
114
- secrets,
115
- sources=secrets_to_add,
116
- )
117
- for d, attrs in decos_to_add:
118
- _deco_to_add = swapping_decos[d]
119
- step.add_decorator(_deco_to_add, **attrs)
120
-
121
56
  def _coreweave_config():
122
57
  return {
123
58
  "root": self.bucket_path,
@@ -131,9 +66,6 @@ class coreweave_checkpoints(CustomFlowDecorator):
131
66
 
132
67
  mutable_flow.add_decorator(
133
68
  with_artifact_store,
134
- type="coreweave",
135
- config=_coreweave_config,
69
+ deco_kwargs=dict(type="coreweave", config=_coreweave_config),
136
70
  )
137
-
138
- for step_name, step in mutable_flow.steps:
139
- _add_secrets(step)
71
+ self._swap_secrets(mutable_flow)
@@ -0,0 +1,85 @@
1
+ from metaflow.user_decorators.user_flow_decorator import FlowMutator
2
+ from metaflow.user_decorators.mutable_flow import MutableFlow
3
+ from metaflow.user_decorators.mutable_step import MutableStep
4
+ import os
5
+
6
+
7
class _ExternalCheckpointFlowDeco(FlowMutator):
    """Shared base for flow mutators that keep checkpoints in an external
    ``s3://``-style bucket.

    Subclasses are expected to call ``super().init(...)`` to validate the
    ``bucket_path`` / ``secrets`` keyword arguments, and
    ``self._swap_secrets(mutable_flow)`` to attach those secrets to every step.
    """

    def init(self, *args, **kwargs):
        """Validate and store the ``bucket_path`` and ``secrets`` keyword arguments.

        Raises
        ------
        ValueError
            If ``bucket_path`` is missing or does not start with ``s3://``,
            or if ``secrets`` is explicitly ``None``.
        """
        # Name the concrete decorator class in error messages instead of
        # hard-coding one provider.
        deco_name = type(self).__name__
        self.bucket_path = kwargs.get("bucket_path", None)

        secrets = kwargs.get("secrets", [])
        if self.bucket_path is None:
            raise ValueError(
                f"`bucket_path` keyword argument is required for {deco_name}"
            )
        if not isinstance(self.bucket_path, str) or not self.bucket_path.startswith(
            "s3://"
        ):
            raise ValueError(f"`bucket_path` must start with `s3://` for {deco_name}")
        if secrets is None:
            raise ValueError(f"`secrets` keyword argument is required for {deco_name}")
        # Accept a single secret name as well as a list. Copy the list so later
        # processing never mutates the caller's object.
        if isinstance(secrets, str):
            secrets = [secrets]
        self.secrets = list(secrets)

    def _swap_secrets(self, mutable_flow: MutableFlow) -> None:
        """Attach ``self.secrets`` to every step of ``mutable_flow``.

        For each step, the configured secrets are merged with the sources of
        any existing ``@secrets`` decorator into a single ``@secrets``
        decorator. Any ``@checkpoint``, ``@model`` or ``@huggingface_hub``
        decorators on the step are removed and re-added after it with their
        original keyword arguments.
        """
        from metaflow import (
            checkpoint,
            model,
            huggingface_hub,
            secrets,
        )

        swapping_decos = {
            "huggingface_hub": huggingface_hub,
            "model": model,
            "checkpoint": checkpoint,
        }

        def _add_secrets(step: MutableStep) -> None:
            decos_to_add = []
            existing_sources = []
            already_has_secrets = False
            for name, _, _, deco_kwargs in step.decorator_specs:
                if name in swapping_decos:
                    decos_to_add.append((name, deco_kwargs))
                elif name == "secrets":
                    already_has_secrets = True
                    existing_sources.extend(deco_kwargs.get("sources") or [])

            # Build a fresh list for each step. Extending self.secrets in place
            # would leak one step's sources into every later step.
            # dict.fromkeys de-duplicates while keeping a deterministic order.
            secrets_to_add = list(dict.fromkeys([*self.secrets, *existing_sources]))

            # Replace (rather than duplicate) an existing @secrets decorator;
            # its sources are already merged into secrets_to_add.
            if already_has_secrets:
                step.remove_decorator("secrets")
            for deco_name, _ in decos_to_add:
                step.remove_decorator(deco_name)

            step.add_decorator(
                secrets,
                deco_kwargs=dict(sources=secrets_to_add),
            )
            # Re-add the swapped decorators after @secrets, presumably so they
            # see the secrets already set up — TODO confirm ordering semantics.
            for deco_name, attrs in decos_to_add:
                step.add_decorator(swapping_decos[deco_name], deco_kwargs=attrs)

        for _, step in mutable_flow.steps:
            _add_secrets(step)
@@ -1,14 +1,11 @@
1
- from metaflow.user_configs.config_decorators import (
2
- MutableFlow,
3
- MutableStep,
4
- CustomFlowDecorator,
5
- )
1
+ from metaflow.user_decorators.mutable_flow import MutableFlow
2
+ from .external_chckpt import _ExternalCheckpointFlowDeco
6
3
  import os
7
4
 
8
5
  NEBIUS_ENDPOINT_URL = "https://storage.eu-north1.nebius.cloud:443"
9
6
 
10
7
 
11
- class nebius_checkpoints(CustomFlowDecorator):
8
+ class nebius_checkpoints(_ExternalCheckpointFlowDeco):
12
9
 
13
10
  """
14
11
 
@@ -52,78 +49,14 @@ class nebius_checkpoints(CustomFlowDecorator):
52
49
  super().__init__(*args, **kwargs)
53
50
 
54
51
  def init(self, *args, **kwargs):
55
- self.bucket_path = kwargs.get("bucket_path", None)
56
-
57
- self.secrets = kwargs.get("secrets", [])
58
- if self.bucket_path is None:
59
- raise ValueError(
60
- "`bucket_path` keyword argument is required for the coreweave_datastore"
61
- )
62
- if not self.bucket_path.startswith("s3://"):
63
- raise ValueError(
64
- "`bucket_path` must start with `s3://` for the coreweave_datastore"
65
- )
66
-
52
+ super().init(*args, **kwargs)
67
53
  self.nebius_endpoint_url = kwargs.get("endpoint_url", NEBIUS_ENDPOINT_URL)
68
- if self.secrets is None:
69
- raise ValueError(
70
- "`secrets` keyword argument is required for the coreweave_datastore"
71
- )
72
54
 
73
- def evaluate(self, mutable_flow: MutableFlow) -> None:
55
+ def pre_mutate(self, mutable_flow: MutableFlow) -> None:
74
56
  from metaflow import (
75
- checkpoint,
76
- model,
77
- huggingface_hub,
78
- secrets,
79
57
  with_artifact_store,
80
58
  )
81
59
 
82
- def _add_secrets(step: MutableStep) -> None:
83
- decos_to_add = []
84
- swapping_decos = {
85
- "huggingface_hub": huggingface_hub,
86
- "model": model,
87
- "checkpoint": checkpoint,
88
- }
89
- already_has_secrets = False
90
- secrets_present_in_deco = []
91
- for d in step.decorators:
92
- if d.name in swapping_decos:
93
- decos_to_add.append((d.name, d.attributes))
94
- elif d.name == "secrets":
95
- already_has_secrets = True
96
- secrets_present_in_deco.extend(d.attributes["sources"])
97
-
98
- # If the step aleady has secrets then take all the sources in
99
- # the secrets and add the addtional secrets to the existing secrets
100
- secrets_to_add = self.secrets
101
- if already_has_secrets:
102
- secrets_to_add.extend(secrets_present_in_deco)
103
-
104
- secrets_to_add = list(set(secrets_to_add))
105
-
106
- if len(decos_to_add) == 0:
107
- if already_has_secrets:
108
- step.remove_decorator("secrets")
109
-
110
- step.add_decorator(
111
- secrets,
112
- sources=secrets_to_add,
113
- )
114
- return
115
-
116
- for d, _ in decos_to_add:
117
- step.remove_decorator(d)
118
-
119
- step.add_decorator(
120
- secrets,
121
- sources=secrets_to_add,
122
- )
123
- for d, attrs in decos_to_add:
124
- _deco_to_add = swapping_decos[d]
125
- step.add_decorator(_deco_to_add, **attrs)
126
-
127
60
  def _nebius_config():
128
61
  return {
129
62
  "root": self.bucket_path,
@@ -135,10 +68,6 @@ class nebius_checkpoints(CustomFlowDecorator):
135
68
  }
136
69
 
137
70
  mutable_flow.add_decorator(
138
- with_artifact_store,
139
- type="s3",
140
- config=_nebius_config,
71
+ with_artifact_store, deco_kwargs=dict(type="s3", config=_nebius_config)
141
72
  )
142
-
143
- for step_name, step in mutable_flow.steps:
144
- _add_secrets(step)
73
+ self._swap_secrets(mutable_flow)