ob-metaflow-extensions 1.1.84__tar.gz → 1.1.88__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic.

Files changed (45)
  1. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/PKG-INFO +1 -1
  2. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/config/__init__.py +7 -0
  3. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/__init__.py +125 -59
  4. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +92 -36
  5. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +14 -0
  6. ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +157 -0
  7. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +30 -6
  8. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +62 -1
  9. ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +53 -0
  10. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/PKG-INFO +1 -1
  11. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/SOURCES.txt +1 -0
  12. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/setup.py +1 -1
  13. ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +0 -10
  14. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/README.md +0 -0
  15. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/__init__.py +0 -0
  16. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/auth_server.py +0 -0
  17. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  18. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +0 -0
  19. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +0 -0
  20. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/kubernetes/__init__.py +0 -0
  21. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +0 -0
  22. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nim/__init__.py +0 -0
  23. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +0 -0
  24. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  25. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +0 -0
  26. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/perimeters.py +0 -0
  27. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  28. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +0 -0
  29. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +0 -0
  30. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +0 -0
  31. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +0 -0
  32. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +0 -0
  33. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +0 -0
  34. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +0 -0
  35. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/profilers/__init__.py +0 -0
  36. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/profilers/gpu.py +0 -0
  37. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/remote_config.py +0 -0
  38. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/__init__.py +0 -0
  39. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/plugins/azure/__init__.py +0 -0
  40. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py +0 -0
  41. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py +0 -0
  42. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/dependency_links.txt +0 -0
  43. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/requires.txt +0 -0
  44. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/top_level.txt +0 -0
  45. {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/setup.cfg +0 -0
--- ob-metaflow-extensions-1.1.84/PKG-INFO
+++ ob-metaflow-extensions-1.1.88/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ob-metaflow-extensions
-Version: 1.1.84
+Version: 1.1.88
 Summary: Outerbounds Platform Extensions for Metaflow
 Author: Outerbounds, Inc.
 License: Commercial
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/config/__init__.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/config/__init__.py
@@ -14,6 +14,13 @@ DEFAULT_GCP_CLIENT_PROVIDER = "obp"
 FAST_BAKERY_URL = from_conf("FAST_BAKERY_URL", None)


+###
+# NVCF configuration
+###
+# Maximum number of consecutive heartbeats that can be missed.
+NVIDIA_HEARTBEAT_THRESHOLD = from_conf("NVIDIA_HEARTBEAT_THRESHOLD", "3")
+
+
 ###
 # Snowpark configuration
 ###
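
Note: since the new constant is read via from_conf, it can be overridden per deployment or per run, e.g. by setting the METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD environment variable (the nvcf CLI change further down forwards exactly that variable to remote tasks). The default is the string "3"; consumers cast it to int.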
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/__init__.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/__init__.py
@@ -32,6 +32,126 @@ def hide_access_keys(*args, **kwds):
        os.environ["AWS_SESSION_TOKEN"] = AWS_SESSION_TOKEN


+# This is a special placeholder value that can be passed as role_arn to
+# get_boto3_session() which makes it use the CSPR role, if its set.
+USE_CSPR_ROLE_ARN_IF_SET = "__cspr__"
+
+
+def get_boto3_session(role_arn=None, session_vars=None):
+    import boto3
+    import botocore
+    from metaflow_extensions.outerbounds.plugins.auth_server import get_token
+
+    from hashlib import sha256
+    from metaflow.util import get_username
+
+    user = get_username()
+
+    token_info = get_token("/generate/aws")
+
+    # Write token to a file. The file name is derived from the user name
+    # so it works with multiple users on the same machine.
+    #
+    # We hash the user name so we don't have to deal with special characters
+    # in the file name and the file name is not exposed to the user
+    # anyways, so it doesn't matter that its a little ugly.
+    token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
+
+    # Write to a temp file then rename to avoid a situation when someone
+    # tries to read the file after it was open for writing (and truncated)
+    # but before the token was written to it.
+    with tempfile.NamedTemporaryFile("w", delete=False) as f:
+        f.write(token_info["token"])
+        tmp_token_file = f.name
+    os.rename(tmp_token_file, token_file)
+
+    cspr_role = None
+    if token_info.get("cspr_role_arn"):
+        cspr_role = token_info["cspr_role_arn"]
+
+    if cspr_role:
+        # If CSPR role is set, we set it as the default role to assume
+        # for the AWS SDK. We do this by writing an AWS config file
+        # with two profiles. One to get credentials for the task role
+        # in exchange for the OIDC token, and second to assume the
+        # CSPR role using the task role credentials.
+        import configparser
+        from io import StringIO
+
+        aws_config = configparser.ConfigParser()
+
+        # Task role profile
+        aws_config["profile task"] = {
+            "role_arn": token_info["role_arn"],
+            "web_identity_token_file": token_file,
+        }
+
+        # CSPR role profile (default)
+        aws_config["profile cspr"] = {
+            "role_arn": cspr_role,
+            "source_profile": "task",
+        }
+
+        aws_config_string = StringIO()
+        aws_config.write(aws_config_string)
+        aws_config_file = (
+            "/tmp/aws_config." + sha256(user.encode("utf-8")).hexdigest()[:16]
+        )
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, dir=os.path.dirname(aws_config_file)
+        ) as f:
+            f.write(aws_config_string.getvalue())
+            tmp_aws_config_file = f.name
+        os.rename(tmp_aws_config_file, aws_config_file)
+        os.environ["AWS_CONFIG_FILE"] = aws_config_file
+        os.environ["AWS_PROFILE"] = "cspr"
+    else:
+        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
+        os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
+
+    # Enable regional STS endpoints. This is the new recommended way
+    # by AWS [1] and is the more performant way.
+    # [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
+    os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
+    if token_info.get("region"):
+        os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
+
+    with hide_access_keys():
+        if cspr_role:
+            # The generated AWS config will be used here since we set the
+            # AWS_CONFIG_FILE environment variable above.
+            if role_arn == USE_CSPR_ROLE_ARN_IF_SET:
+                # Otherwise start from the default profile, assuming CSPR role
+                session = boto3.session.Session(profile_name="cspr")
+            else:
+                session = boto3.session.Session(profile_name="task")
+        else:
+            # Not using AWS config, just AWS_WEB_IDENTITY_TOKEN_FILE + AWS_ROLE_ARN
+            session = boto3.session.Session()
+
+    if role_arn and role_arn != USE_CSPR_ROLE_ARN_IF_SET:
+        # If the user provided a role_arn, we assume that role
+        # using the task role credentials. CSPR role is not used.
+        fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
+            client_creator=session._session.create_client,
+            source_credentials=session._session.get_credentials(),
+            role_arn=role_arn,
+            extra_args={},
+        )
+        creds = botocore.credentials.DeferredRefreshableCredentials(
+            method="assume-role", refresh_using=fetcher.fetch_credentials
+        )
+        botocore_session = botocore.session.Session(session_vars=session_vars)
+        botocore_session._credentials = creds
+        return boto3.session.Session(botocore_session=botocore_session)
+    else:
+        # If the user didn't provide a role_arn, or if the role_arn
+        # is set to USE_CSPR_ROLE_ARN_IF_SET, we return the default
+        # session which would use the CSPR role if it is set on the
+        # server, and the task role otherwise.
+        return session
+
+
 class ObpAuthProvider(object):
     name = "obp"

@@ -42,67 +162,13 @@ class ObpAuthProvider(object):
         if client_params is None:
             client_params = {}

-        import boto3
-        import botocore
         from botocore.exceptions import ClientError
-        from metaflow_extensions.outerbounds.plugins.auth_server import get_token
-
-        from hashlib import sha256
-        from metaflow.util import get_username
-
-        user = get_username()
-
-        token_info = get_token("/generate/aws")

-        # Write token to a file. The file name is derived from the user name
-        # so it works with multiple users on the same machine.
-        #
-        # We hash the user name so we don't have to deal with special characters
-        # in the file name and the file name is not exposed to the user
-        # anyways, so it doesn't matter that its a little ugly.
-        token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
-
-        # Write to a temp file then rename to avoid a situation when someone
-        # tries to read the file after it was open for writing (and truncated)
-        # but before the token was written to it.
-        with tempfile.NamedTemporaryFile("w", delete=False) as f:
-            f.write(token_info["token"])
-            tmp_token_file = f.name
-        os.rename(tmp_token_file, token_file)
-
-        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
-        os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
-
-        # Enable regional STS endpoints. This is the new recommended way
-        # by AWS [1] and is the more performant way.
-        # [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
-        os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
-        if token_info.get("region"):
-            os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
-
-        with hide_access_keys():
-            if role_arn:
-                session = boto3.session.Session()
-                fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
-                    client_creator=session._session.create_client,
-                    source_credentials=session._session.get_credentials(),
-                    role_arn=role_arn,
-                    extra_args={},
-                )
-                creds = botocore.credentials.DeferredRefreshableCredentials(
-                    method="assume-role", refresh_using=fetcher.fetch_credentials
-                )
-                botocore_session = botocore.session.Session(session_vars=session_vars)
-                botocore_session._credentials = creds
-                session = boto3.session.Session(botocore_session=botocore_session)
-                if with_error:
-                    return session.client(module, **client_params), ClientError
-                else:
-                    return session.client(module, **client_params)
-            if with_error:
-                return boto3.client(module, **client_params), ClientError
-            else:
-                return boto3.client(module, **client_params)
+        session = get_boto3_session(role_arn, session_vars)
+        if with_error:
+            return session.client(module, **client_params), ClientError
+        else:
+            return session.client(module, **client_params)


 AWS_CLIENT_PROVIDERS_DESC = [("obp", ".ObpAuthProvider")]
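
For orientation (not part of the diff): a minimal sketch of how the refactored helper behaves from userland code. The import path and sentinel come from the hunks above; the role ARN is hypothetical.

from metaflow_extensions.outerbounds.plugins import (
    USE_CSPR_ROLE_ARN_IF_SET,
    get_boto3_session,
)

# Default: a session backed by the task role's web-identity credentials.
session = get_boto3_session()

# Sentinel: prefer the CSPR role when the auth server advertises one,
# falling back to the task role otherwise.
cspr_session = get_boto3_session(role_arn=USE_CSPR_ROLE_ARN_IF_SET)

# Explicit role (hypothetical ARN): assumed with task-role credentials;
# the CSPR role is bypassed.
custom = get_boto3_session(role_arn="arn:aws:iam::123456789012:role/Example")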
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py
@@ -1,30 +1,30 @@
 import hashlib
 import json
 import os
-
+import threading
+import time
+import uuid
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict
+
 from metaflow.exception import MetaflowException
-from metaflow.metaflow_config import (
-    FAST_BAKERY_URL,
-    get_pinned_conda_libs,
-)
+from metaflow.metaflow_config import FAST_BAKERY_URL, get_pinned_conda_libs
 from metaflow.metaflow_environment import MetaflowEnvironment
-from metaflow.plugins.pypi.conda_environment import CondaEnvironment
-from .fast_bakery import FastBakery, FastBakeryApiResponse, FastBakeryException
 from metaflow.plugins.aws.batch.batch_decorator import BatchDecorator
 from metaflow.plugins.kubernetes.kubernetes_decorator import KubernetesDecorator
 from metaflow.plugins.pypi.conda_decorator import CondaStepDecorator
+from metaflow.plugins.pypi.conda_environment import CondaEnvironment
 from metaflow.plugins.pypi.pypi_decorator import PyPIStepDecorator

+from .fast_bakery import FastBakery, FastBakeryApiResponse, FastBakeryException
+
 BAKERY_METAFILE = ".imagebakery-cache"

+import fcntl
 import json
 import os
-import fcntl
-from functools import wraps
 from concurrent.futures import ThreadPoolExecutor
-
+from functools import wraps

 # TODO - ensure that both @conda/@pypi are not assigned to the same step

@@ -36,6 +36,9 @@ def cache_request(cache_file):
             call_args = kwargs.copy()
             call_args.update(zip(func.__code__.co_varnames, args))
             call_args.pop("self", None)
+            call_args.pop("ref", None)
+            # invalidate cache when moving from one deployment to another
+            call_args.update({"fast_bakery_url": FAST_BAKERY_URL})
             cache_key = hashlib.md5(
                 json.dumps(call_args, sort_keys=True).encode("utf-8")
             ).hexdigest()
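
For context (not part of the diff): a standalone sketch of the cache-key computation after this change; the call args and URL below are hypothetical.

import hashlib
import json

call_args = {"python": "3.11", "pypi_packages": {"requests": "2.31.0"}}
call_args.pop("ref", None)  # the display-only ref must not affect the key
call_args.update({"fast_bakery_url": "https://fastbakery.example.com"})

cache_key = hashlib.md5(
    json.dumps(call_args, sort_keys=True).encode("utf-8")
).hexdigest()
print(cache_key)  # stable for identical args, new key per bakery URL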
@@ -79,7 +82,7 @@


 class DockerEnvironmentException(MetaflowException):
-    headline = "Ran into an error while setting up the environment"
+    headline = "Ran into an error while baking image"

     def __init__(self, msg):
         super(DockerEnvironmentException, self).__init__(msg)
@@ -93,8 +96,8 @@ class DockerEnvironment(MetaflowEnvironment):
         self.skipped_steps = set()
         self.flow = flow

-        self.bakery = FastBakery(url=FAST_BAKERY_URL)
         self.results = {}
+        self.images_baked = 0

     def set_local_root(self, local_root):
         self.local_root = local_root
@@ -102,15 +105,31 @@ class DockerEnvironment(MetaflowEnvironment):
     def decospecs(self):
         return ("conda", "fast_bakery_internal") + super().decospecs()

-    def validate_environment(self, echo, datastore_type):
+    def validate_environment(self, logger, datastore_type):
         self.datastore_type = datastore_type
-        self.echo = echo
+        self.logger = logger

         # Avoiding circular imports.
         from metaflow.plugins import DATASTORES

         self.datastore = [d for d in DATASTORES if d.TYPE == self.datastore_type][0]

+        # Mixing @pypi/@conda in a single step is not supported yet
+        for step in self.flow:
+            if (
+                sum(
+                    1
+                    for deco in step.decorators
+                    if isinstance(deco, (PyPIStepDecorator, CondaStepDecorator))
+                )
+                > 1
+            ):
+                raise MetaflowException(
+                    "Mixing and matching PyPI packages and Conda packages within a\n"
+                    "step is not yet supported. Use one of @pypi or @conda only for the *%s* step."
+                    % step.name
+                )
+
     def init_environment(self, echo):
         self.skipped_steps = {
             step.name
@@ -125,14 +144,21 @@ class DockerEnvironment(MetaflowEnvironment):
             step for step in self.flow if step.name not in self.skipped_steps
         ]
         if steps_to_bake:
-            echo("Baking container image(s) ...")
-            self.results = self._bake(steps_to_bake, echo)
+            self.logger("🚀 Baking container image(s) ...")
+            start_time = time.time()
+            self.results = self._bake(steps_to_bake)
             for step in self.flow:
                 for d in step.decorators:
                     if isinstance(d, (BatchDecorator, KubernetesDecorator)):
                         d.attributes["image"] = self.results[step.name].container_image
                         d.attributes["executable"] = self.results[step.name].python_path
-            echo("Container image(s) baked!")
+            if self.images_baked > 0:
+                bake_time = time.time() - start_time
+                self.logger(
+                    f"🎉 All container image(s) baked in {bake_time:.2f} seconds!"
+                )
+            else:
+                self.logger("🎉 All container image(s) baked!")

         if self.skipped_steps:
             self.delegate = CondaEnvironment(self.flow)
@@ -140,29 +166,54 @@ class DockerEnvironment(MetaflowEnvironment):
             self.delegate.validate_environment(echo, self.datastore_type)
             self.delegate.init_environment(echo, self.skipped_steps)

-    def _bake(self, steps, echo) -> Dict[str, FastBakeryApiResponse]:
+    def _bake(self, steps) -> Dict[str, FastBakeryApiResponse]:
         metafile_path = get_fastbakery_metafile_path(self.local_root, self.flow.name)
+        logger_lock = threading.Lock()

         @cache_request(metafile_path)
         def _cached_bake(
-            python=None, pypi_packages=None, conda_packages=None, base_image=None
+            ref=None,
+            python=None,
+            pypi_packages=None,
+            conda_packages=None,
+            base_image=None,
         ):
-            self.bakery._reset_payload()
-            self.bakery.python_version(python)
-            self.bakery.pypi_packages(pypi_packages)
-            self.bakery.conda_packages(conda_packages)
-            self.bakery.base_image(base_image)
-            # self.bakery.ignore_cache()
+            bakery = FastBakery(url=FAST_BAKERY_URL)
+            bakery._reset_payload()
+            bakery.python_version(python)
+            bakery.pypi_packages(pypi_packages)
+            bakery.conda_packages(conda_packages)
+            bakery.base_image(base_image)
+            # bakery.ignore_cache()
+
+            with logger_lock:
+                self.logger(f"🍳 Baking [{ref}] ...")
+                self.logger(f" 🐍 Python: {python}")
+
+                if pypi_packages:
+                    self.logger(f" 📦 PyPI packages:")
+                    for package, version in pypi_packages.items():
+                        self.logger(f" 🔧 {package}: {version}")
+
+                if conda_packages:
+                    self.logger(f" 📦 Conda packages:")
+                    for package, version in conda_packages.items():
+                        self.logger(f" 🔧 {package}: {version}")
+
+                self.logger(f" 🏗️ Base image: {base_image}")
+
+            start_time = time.time()
             try:
-                res = self.bakery.bake()
-                if res.baking_stats:
-                    echo(
-                        "baked image in: %s milliseconds"
-                        % res.baking_stats.solver_stats.duration_ms
-                    )
+                res = bakery.bake()
+                # TODO: Get actual bake time from bakery
+                bake_time = time.time() - start_time
+
+                with logger_lock:
+                    self.logger(f"🏁 Baked [{ref}] in {bake_time:.2f} seconds!")
+                    self.images_baked += 1
                 return res
             except FastBakeryException as ex:
-                raise DockerEnvironmentException(str(ex))
+                raise DockerEnvironmentException(f"Bake [{ref}] failed: {str(ex)}")

         def prepare_step(step):
             base_image = next(
@@ -216,10 +267,15 @@ class DockerEnvironment(MetaflowEnvironment):
             }

         with ThreadPoolExecutor() as executor:
-            return {
-                step.name: _cached_bake(**args)
-                for step, args in zip(steps, executor.map(prepare_step, steps))
-            }
+            prepared_args = list(executor.map(prepare_step, steps))
+            for i, args in enumerate(prepared_args, 1):
+                args["ref"] = f"#{i:02d}"
+            futures = [executor.submit(_cached_bake, **args) for args in prepared_args]
+            results = {}
+            for step, future in zip(steps, futures):
+                results[step.name] = future.result()
+
+            return results

     def executable(self, step_name, default=None):
         if step_name in self.skipped_steps:
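
The rewritten executor block above submits every bake as a future so independent images bake concurrently, rather than invoking _cached_bake serially inside a dict comprehension. A self-contained sketch of the pattern (names hypothetical):

from concurrent.futures import ThreadPoolExecutor

def bake(ref):
    # Stand-in for a network-bound _cached_bake call.
    return f"image-for-{ref}"

steps = ["start", "train", "end"]
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(bake, s) for s in steps]
    # zip() preserves the step -> image mapping even though the bakes
    # may complete in any order.
    results = {s: f.result() for s, f in zip(steps, futures)}
print(results)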
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py
@@ -1,5 +1,6 @@
 from typing import Dict, Optional
 import requests
+import time


 class FastBakeryException(Exception):
@@ -140,6 +141,19 @@ class FastBakery:
             headers = {**self.headers, **(SERVICE_HEADERS or {})}
         except ImportError:
             headers = self.headers
+
+        retryable_status_codes = [409]
+
+        for attempt in range(2):  # 0 = initial attempt, 1-2 = retries
+            response = requests.post(self.url, json=payload, headers=headers)
+
+            if response.status_code not in retryable_status_codes:
+                break
+
+            if attempt < 2:  # Don't sleep after the last attempt
+                sleep_time = 0.5 * (attempt + 1)
+                time.sleep(sleep_time)
+
         response = requests.post(self.url, json=payload, headers=headers)
         self._handle_error_response(response)
         return FastBakeryApiResponse(response.json())
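
For reference (not part of the diff): the retry logic above in generic form — POST, retry on a retryable status code with linear backoff. Function name, attempt count, and backoff step are illustrative.

import time
import requests

def post_with_retry(url, payload, headers, retryable=(409,), attempts=3):
    # POST, retrying on retryable status codes with linear backoff.
    response = None
    for attempt in range(attempts):
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code not in retryable:
            break
        if attempt < attempts - 1:  # don't sleep after the final attempt
            time.sleep(0.5 * (attempt + 1))
    return response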
--- /dev/null
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py
@@ -0,0 +1,157 @@
+import os
+import sys
+import time
+import signal
+from io import BytesIO
+from datetime import datetime, timezone
+
+from metaflow.exception import MetaflowException
+
+
+class HeartbeatStore(object):
+    def __init__(
+        self,
+        heartbeat_prefix,
+        main_pid=None,
+        storage_backend=None,
+        emit_frequency=30,
+        missed_heartbeat_timeout=60,
+        monitor_frequency=15,
+        max_missed_heartbeats=3,
+    ) -> None:
+        self.heartbeat_prefix = heartbeat_prefix
+        self.main_pid = main_pid
+        self.storage_backend = storage_backend
+        self.emit_frequency = emit_frequency
+        self.monitor_frequency = monitor_frequency
+        self.missed_heartbeat_timeout = missed_heartbeat_timeout
+        self.max_missed_heartbeats = max_missed_heartbeats
+        self.missed_heartbeats = 0
+
+    def emit_heartbeat(self, folder_name=None):
+        heartbeat_key = f"{self.heartbeat_prefix}/heartbeat"
+        if folder_name:
+            heartbeat_key = f"{folder_name}/{heartbeat_key}"
+
+        while True:
+            try:
+                epoch_string = str(datetime.now(timezone.utc).timestamp()).encode(
+                    "utf-8"
+                )
+                self.storage_backend.save_bytes(
+                    [(heartbeat_key, BytesIO(bytes(epoch_string)))], overwrite=True
+                )
+            except Exception as e:
+                print(f"Error writing heartbeat: {e}")
+                sys.exit(1)
+
+            time.sleep(self.emit_frequency)
+
+    def emit_tombstone(self, folder_name=None):
+        tombstone_key = f"{self.heartbeat_prefix}/tombstone"
+        if folder_name:
+            tombstone_key = f"{folder_name}/{tombstone_key}"
+
+        tombstone_string = "tombstone".encode("utf-8")
+        try:
+            self.storage_backend.save_bytes(
+                [(tombstone_key, BytesIO(bytes(tombstone_string)))], overwrite=True
+            )
+        except Exception as e:
+            print(f"Error writing tombstone: {e}")
+            sys.exit(1)
+
+    def __handle_tombstone(self, path):
+        if path is not None:
+            with open(path) as f:
+                contents = f.read()
+            if "tombstone" in contents:
+                print("[Outerbounds] Tombstone detected. Terminating the task..")
+                os.kill(self.main_pid, signal.SIGTERM)
+                sys.exit(1)
+
+    def __handle_heartbeat(self, path):
+        if path is not None:
+            with open(path) as f:
+                contents = f.read()
+            current_timestamp = datetime.now(timezone.utc).timestamp()
+            if current_timestamp - float(contents) > self.missed_heartbeat_timeout:
+                self.missed_heartbeats += 1
+            else:
+                self.missed_heartbeats = 0
+        else:
+            self.missed_heartbeats += 1
+
+        if self.missed_heartbeats > self.max_missed_heartbeats:
+            print(
+                f"[Outerbounds] Missed {self.max_missed_heartbeats} consecutive heartbeats. Terminating the task.."
+            )
+            os.kill(self.main_pid, signal.SIGTERM)
+            sys.exit(1)
+
+    def is_main_process_running(self):
+        try:
+            # Check if the process is running
+            os.kill(self.main_pid, 0)
+        except ProcessLookupError:
+            return False
+        return True
+
+    def monitor(self, folder_name=None):
+        heartbeat_key = f"{self.heartbeat_prefix}/heartbeat"
+        if folder_name:
+            heartbeat_key = f"{folder_name}/{heartbeat_key}"
+
+        tombstone_key = f"{self.heartbeat_prefix}/tombstone"
+        if folder_name:
+            tombstone_key = f"{folder_name}/{tombstone_key}"
+
+        while self.is_main_process_running():
+            with self.storage_backend.load_bytes(
+                [heartbeat_key, tombstone_key]
+            ) as results:
+                for key, path, _ in results:
+                    if key == tombstone_key:
+                        self.__handle_tombstone(path)
+                    elif key == heartbeat_key:
+                        self.__handle_heartbeat(path)
+
+            time.sleep(self.monitor_frequency)
+
+
+if __name__ == "__main__":
+    from metaflow.plugins import DATASTORES
+    from metaflow.metaflow_config import NVIDIA_HEARTBEAT_THRESHOLD
+
+    if len(sys.argv) != 4:
+        print("Usage: heartbeat_store.py <main_pid> <datastore_type> <folder_name>")
+        sys.exit(1)
+    _, main_pid, datastore_type, folder_name = sys.argv
+
+    if datastore_type not in ("azure", "gs", "s3"):
+        print(f"Datastore unsupported for type: {datastore_type}")
+        sys.exit(1)
+
+    datastores = [d for d in DATASTORES if d.TYPE == datastore_type]
+    datastore_sysroot = datastores[0].get_datastore_root_from_config(
+        lambda *args, **kwargs: None
+    )
+    if datastore_sysroot is None:
+        raise MetaflowException(
+            msg="METAFLOW_DATASTORE_SYSROOT_{datastore_type} must be set!".format(
+                datastore_type=datastore_type.upper()
+            )
+        )
+
+    storage = datastores[0](datastore_sysroot)
+
+    heartbeat_prefix = f"{os.getenv('MF_PATHSPEC')}/{os.getenv('MF_ATTEMPT')}"
+
+    store = HeartbeatStore(
+        heartbeat_prefix=heartbeat_prefix,
+        main_pid=int(main_pid),
+        storage_backend=storage,
+        max_missed_heartbeats=int(NVIDIA_HEARTBEAT_THRESHOLD),
+    )

+    store.monitor(folder_name=folder_name)
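
How the two halves of this new module fit together (per the nvcf.py hunks below, the launcher-side Job emits heartbeats from a daemon thread, while this module's __main__ runs the monitor next to the task): a sketch with hypothetical pathspec, PID, and datastore root.

import threading

from metaflow.plugins import DATASTORES
from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import HeartbeatStore

# Hypothetical S3 datastore backend and task pathspec, for illustration only.
s3 = [d for d in DATASTORES if d.TYPE == "s3"][0]
backend = s3("s3://example-bucket/metaflow")
prefix = "MyFlow/123/start/456/0"  # flow/run/step/task/attempt

# Launcher side: write a timestamp under nvcf_heartbeats/<prefix>/heartbeat
# every emit_frequency seconds.
emitter = HeartbeatStore(heartbeat_prefix=prefix, storage_backend=backend)
threading.Thread(
    target=emitter.emit_heartbeat, args=("nvcf_heartbeats",), daemon=True
).start()

# Task side: watch for stale heartbeats or a tombstone and SIGTERM the
# main task process (PID hypothetical) when either is detected.
monitor = HeartbeatStore(
    heartbeat_prefix=prefix, main_pid=12345, storage_backend=backend
)
monitor.monitor(folder_name="nvcf_heartbeats")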
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py
@@ -1,6 +1,7 @@
 import json
 import os
-import time
+import threading
+from urllib.parse import urlparse
 from urllib.request import HTTPError, Request, URLError, urlopen

 from metaflow import util
@@ -14,6 +15,7 @@ from metaflow.mflog import (
 )
 import requests
 from metaflow.metaflow_config_funcs import init_config
+from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import HeartbeatStore


 class NvcfException(MetaflowException):
@@ -58,10 +60,11 @@ class Nvcf(object):
             code_package_url, code_package_ds
         )
         init_expr = " && ".join(init_cmds)
+        heartbeat_expr = f'python -m metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store "$MAIN_PID" {code_package_ds} nvcf_heartbeats 1>> $MFLOG_STDOUT 2>> $MFLOG_STDERR'
         step_expr = bash_capture_logs(
             " && ".join(
                 self.environment.bootstrap_commands(step_name, code_package_ds)
-                + [step_cli]
+                + [step_cli + " & MAIN_PID=$!; " + heartbeat_expr]
             )
         )

@@ -87,7 +90,9 @@ class Nvcf(object):
             '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
             % cmd_str
         )
-        self.job = Job('bash -c "%s"' % cmd_str, env)
+        self.job = Job(
+            'bash -c "%s"' % cmd_str, env, task_spec, self.datastore._storage_impl
+        )
         self.job.submit()

     def wait(self, stdout_location, stderr_location, echo=None):
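
In the assembled command, the task itself is backgrounded (step_cli & MAIN_PID=$!) while the heartbeat monitor runs in the foreground with the task's PID, so the remote side can SIGTERM the task when heartbeats from the launcher stop arriving or a tombstone appears. The Job object, in turn, now receives the task_spec and storage backend it needs to emit those heartbeats from the launcher.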
@@ -137,8 +142,7 @@ result_endpoint = f"{nvcf_url}/v2/nvcf/pexec/status"


 class Job(object):
-    def __init__(self, command, env):
-
+    def __init__(self, command, env, task_spec, backend):
         self._payload = {
             "command": command,
             "env": {k: v for k, v in env.items() if v is not None},
@@ -170,6 +174,27 @@ class Job(object):
             if f["model_key"] == "metaflow_task_executor":
                 self._function_id = f["id"]

+        flow_name = task_spec.get("flow_name")
+        run_id = task_spec.get("run_id")
+        step_name = task_spec.get("step_name")
+        task_id = task_spec.get("task_id")
+        retry_count = task_spec.get("retry_count")
+
+        heartbeat_prefix = "/".join(
+            (flow_name, str(run_id), step_name, str(task_id), str(retry_count))
+        )
+
+        store = HeartbeatStore(
+            heartbeat_prefix=heartbeat_prefix,
+            main_pid=None,
+            storage_backend=backend,
+        )
+
+        self.heartbeat_thread = threading.Thread(
+            target=store.emit_heartbeat, args=("nvcf_heartbeats",), daemon=True
+        )
+        self.heartbeat_thread.start()
+
     def submit(self):
         try:
             headers = {
@@ -224,7 +249,6 @@ class Job(object):

     def _poll(self):
         try:
-            invocation_id = self._invocation_id
             headers = {
                 "Authorization": f"Bearer {self._ngc_api_key}",
                 "Content-Type": "application/json",
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py
@@ -4,7 +4,7 @@ import sys
 import time
 import traceback

-from metaflow import util
+from metaflow import util, Run
 from metaflow._vendor import click
 from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
 from metaflow.metadata.util import sync_local_metadata_from_datastore
@@ -27,6 +27,7 @@ from metaflow.metaflow_config import (
     CARD_GSROOT,
     KUBERNETES_SANDBOX_INIT_SCRIPT,
     OTEL_ENDPOINT,
+    NVIDIA_HEARTBEAT_THRESHOLD,
 )
 from metaflow.mflog import TASK_LOG_SOURCE

@@ -43,6 +44,65 @@ def nvcf():
     pass


+@nvcf.command(help="List steps / tasks running as an NVCF job.")
+@click.argument("run-id")
+@click.pass_context
+def list(ctx, run_id):
+    flow_name = ctx.obj.flow.name
+    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
+    running_invocations = []
+    for each_step in run_obj:
+        if (
+            not each_step.task.finished
+            and "nvcf-function-id" in each_step.task.metadata_dict
+        ):
+            task_pathspec = each_step.task.pathspec
+            attempt = each_step.task.metadata_dict.get("attempt")
+            flow_name, run_id, step_name, task_id = task_pathspec.split("/")
+            running_invocations.append(
+                f"Flow Name: {flow_name}, Run ID: {run_id}, Step Name: {step_name}, Task ID: {task_id}, Retry Count: {attempt}"
+            )
+
+    if running_invocations:
+        for each_invocation in running_invocations:
+            print(each_invocation)
+    else:
+        print("No running NVCF invocations for Run ID: %s" % run_id)
+
+
+@nvcf.command(help="Kill steps / tasks running as an NVCF job.")
+@click.argument("run-id")
+@click.pass_context
+def kill(ctx, run_id):
+    from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import (
+        HeartbeatStore,
+    )
+
+    flow_name = ctx.obj.flow.name
+    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
+
+    for each_step in run_obj:
+        if (
+            not each_step.task.finished
+            and "nvcf-function-id" in each_step.task.metadata_dict
+        ):
+            task_pathspec = each_step.task.pathspec
+            attempt = each_step.task.metadata_dict.get("attempt")
+            heartbeat_prefix = "{task_pathspec}/{attempt}".format(
+                task_pathspec=task_pathspec, attempt=attempt
+            )
+
+            datastore_root = ctx.obj.datastore_impl.datastore_root
+            store = HeartbeatStore(
+                heartbeat_prefix=heartbeat_prefix,
+                main_pid=None,
+                storage_backend=ctx.obj.datastore_impl(datastore_root),
+            )
+            store.emit_tombstone(folder_name="nvcf_heartbeats")
+        else:
+            print("No running NVCF invocations for Run ID: %s" % run_id)
+
+
 @nvcf.command(
     help="Execute a single task using NVCF. This command calls the "
     "top-level step command inside a NVCF job with the given options. "
@@ -139,6 +199,7 @@ def step(ctx, step_name, code_package_sha, code_package_url, **kwargs):
         "METAFLOW_CARD_GSROOT": CARD_GSROOT,
         "METAFLOW_INIT_SCRIPT": KUBERNETES_SANDBOX_INIT_SCRIPT,
         "METAFLOW_OTEL_ENDPOINT": OTEL_ENDPOINT,
+        "METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD": str(NVIDIA_HEARTBEAT_THRESHOLD),
     }

     env_deco = [deco for deco in node.decorators if deco.name == "environment"]
--- /dev/null
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py
@@ -0,0 +1,53 @@
+# These two fields will show up within `metaflow_version` task metadata.
+# Setting to major version of ob-metaflow-extensions only, so we don't keep trying
+# (and failing) to keep this in sync with setup.py
+# E.g. "2.7.22.1+ob(v1)"
+__version__ = "v1"
+__mf_extensions__ = "ob"
+
+
+# Must match the signature of metaflow.plugins.aws.aws_client.get_aws_client
+# This function is called by the "userland" code inside tasks. Metaflow internals
+# will call the function in metaflow.plugins.aws.aws_client.get_aws_client directly.
+#
+# Unlike the original function, this wrapper will use the CSPR role if both of the following
+# conditions are met:
+#
+# 1. CSPR is set
+# 2. user didn't provide a role to assume explicitly.
+#
+def get_aws_client(
+    module, with_error=False, role_arn=None, session_vars=None, client_params=None
+):
+    import metaflow.plugins.aws.aws_client
+
+    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET
+
+    return metaflow.plugins.aws.aws_client.get_aws_client(
+        module,
+        with_error=with_error,
+        role_arn=role_arn or USE_CSPR_ROLE_ARN_IF_SET,
+        session_vars=session_vars,
+        client_params=client_params,
+    )
+
+
+# This should match the signature of metaflow.plugins.datatools.s3.S3.
+#
+# This assumes that "userland" code inside tasks will call this, while Metaflow
+# internals will call metaflow.plugins.datatools.s3.S3 directly.
+#
+# This wrapper will make S3() use the CSPR role if its set, and user didn't provide
+# a role to assume explicitly.
+def S3(*args, **kwargs):
+    import sys
+    import metaflow.plugins.datatools.s3
+    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET
+
+    if "role" not in kwargs or kwargs["role"] is None:
+        kwargs["role"] = USE_CSPR_ROLE_ARN_IF_SET
+
+    return metaflow.plugins.datatools.s3.S3(*args, **kwargs)
+
+
+from .. import profilers
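
A userland sketch of the effect of these aliases (not part of the diff; bucket, key, and ARN are hypothetical, and this assumes the aliases surface at the metaflow top level as the deleted module's comment suggests):

from metaflow import S3

# With the alias in place, this session transparently uses the CSPR role
# when the platform provides one, and the task role otherwise.
with S3(s3root="s3://example-bucket/data") as s3:
    s3.put("greeting", "hello world")

# Passing an explicit role bypasses the CSPR default (hypothetical ARN).
with S3(
    s3root="s3://example-bucket/data",
    role="arn:aws:iam::123456789012:role/Example",
) as s3:
    print(s3.get("greeting").text)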
--- ob-metaflow-extensions-1.1.84/ob_metaflow_extensions.egg-info/PKG-INFO
+++ ob-metaflow-extensions-1.1.88/ob_metaflow_extensions.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ob-metaflow-extensions
-Version: 1.1.84
+Version: 1.1.88
 Summary: Outerbounds Platform Extensions for Metaflow
 Author: Outerbounds, Inc.
 License: Commercial
--- ob-metaflow-extensions-1.1.84/ob_metaflow_extensions.egg-info/SOURCES.txt
+++ ob-metaflow-extensions-1.1.88/ob_metaflow_extensions.egg-info/SOURCES.txt
@@ -16,6 +16,7 @@ metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py
 metaflow_extensions/outerbounds/plugins/nim/__init__.py
 metaflow_extensions/outerbounds/plugins/nim/nim_manager.py
 metaflow_extensions/outerbounds/plugins/nvcf/__init__.py
+metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py
 metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py
 metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py
 metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py
--- ob-metaflow-extensions-1.1.84/setup.py
+++ ob-metaflow-extensions-1.1.88/setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_namespace_packages
 from pathlib import Path


-version = "1.1.84"
+version = "1.1.88"
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text()

--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# These two fields will show up within `metaflow_version` task metadata.
-# Setting to major version of ob-metaflow-extensions only, so we don't keep trying
-# (and failing) to keep this in sync with setup.py
-# E.g. "2.7.22.1+ob(v1)"
-__version__ = "v1"
-__mf_extensions__ = "ob"
-
-# To support "from metaflow import get_aws_client"
-from metaflow.plugins.aws.aws_client import get_aws_client
-from .. import profilers