ob-metaflow-extensions 1.1.84__tar.gz → 1.1.88__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/PKG-INFO +1 -1
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/config/__init__.py +7 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/__init__.py +125 -59
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +92 -36
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +14 -0
- ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +157 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +30 -6
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +62 -1
- ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +53 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/PKG-INFO +1 -1
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/SOURCES.txt +1 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/setup.py +1 -1
- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +0 -10
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/README.md +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/auth_server.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/kubernetes/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nim/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/perimeters.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/profilers/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/profilers/gpu.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/remote_config.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/plugins/azure/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/dependency_links.txt +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/requires.txt +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/ob_metaflow_extensions.egg-info/top_level.txt +0 -0
- {ob-metaflow-extensions-1.1.84 → ob-metaflow-extensions-1.1.88}/setup.cfg +0 -0
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/config/__init__.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/config/__init__.py
@@ -14,6 +14,13 @@ DEFAULT_GCP_CLIENT_PROVIDER = "obp"
 FAST_BAKERY_URL = from_conf("FAST_BAKERY_URL", None)
 
 
+###
+# NVCF configuration
+###
+# Maximum number of consecutive heartbeats that can be missed.
+NVIDIA_HEARTBEAT_THRESHOLD = from_conf("NVIDIA_HEARTBEAT_THRESHOLD", "3")
+
+
 ###
 # Snowpark configuration
 ###
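The new knob rides on Metaflow's `from_conf` mechanism, so it can be overridden via the config file or a `METAFLOW_`-prefixed environment variable; the `nvcf_cli.py` change further down forwards it into the task environment as `METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD`. A minimal illustrative sketch (the override value here is hypothetical):

```python
import os

# Hypothetical override: from_conf() falls back to the METAFLOW_-prefixed
# environment variable, so exporting this before importing metaflow_config
# raises the tolerated number of missed heartbeats from 3 to 5.
os.environ["METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD"] = "5"

from metaflow.metaflow_config import NVIDIA_HEARTBEAT_THRESHOLD  # noqa: E402

print(NVIDIA_HEARTBEAT_THRESHOLD)  # "5" -- note the value stays a string
```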
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/__init__.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/__init__.py
@@ -32,6 +32,126 @@ def hide_access_keys(*args, **kwds):
         os.environ["AWS_SESSION_TOKEN"] = AWS_SESSION_TOKEN
 
 
+# This is a special placeholder value that can be passed as role_arn to
+# get_boto3_session() which makes it use the CSPR role, if its set.
+USE_CSPR_ROLE_ARN_IF_SET = "__cspr__"
+
+
+def get_boto3_session(role_arn=None, session_vars=None):
+    import boto3
+    import botocore
+    from metaflow_extensions.outerbounds.plugins.auth_server import get_token
+
+    from hashlib import sha256
+    from metaflow.util import get_username
+
+    user = get_username()
+
+    token_info = get_token("/generate/aws")
+
+    # Write token to a file. The file name is derived from the user name
+    # so it works with multiple users on the same machine.
+    #
+    # We hash the user name so we don't have to deal with special characters
+    # in the file name and the file name is not exposed to the user
+    # anyways, so it doesn't matter that its a little ugly.
+    token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
+
+    # Write to a temp file then rename to avoid a situation when someone
+    # tries to read the file after it was open for writing (and truncated)
+    # but before the token was written to it.
+    with tempfile.NamedTemporaryFile("w", delete=False) as f:
+        f.write(token_info["token"])
+        tmp_token_file = f.name
+    os.rename(tmp_token_file, token_file)
+
+    cspr_role = None
+    if token_info.get("cspr_role_arn"):
+        cspr_role = token_info["cspr_role_arn"]
+
+    if cspr_role:
+        # If CSPR role is set, we set it as the default role to assume
+        # for the AWS SDK. We do this by writing an AWS config file
+        # with two profiles. One to get credentials for the task role
+        # in exchange for the OIDC token, and second to assume the
+        # CSPR role using the task role credentials.
+        import configparser
+        from io import StringIO
+
+        aws_config = configparser.ConfigParser()
+
+        # Task role profile
+        aws_config["profile task"] = {
+            "role_arn": token_info["role_arn"],
+            "web_identity_token_file": token_file,
+        }
+
+        # CSPR role profile (default)
+        aws_config["profile cspr"] = {
+            "role_arn": cspr_role,
+            "source_profile": "task",
+        }
+
+        aws_config_string = StringIO()
+        aws_config.write(aws_config_string)
+        aws_config_file = (
+            "/tmp/aws_config." + sha256(user.encode("utf-8")).hexdigest()[:16]
+        )
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, dir=os.path.dirname(aws_config_file)
+        ) as f:
+            f.write(aws_config_string.getvalue())
+            tmp_aws_config_file = f.name
+        os.rename(tmp_aws_config_file, aws_config_file)
+        os.environ["AWS_CONFIG_FILE"] = aws_config_file
+        os.environ["AWS_PROFILE"] = "cspr"
+    else:
+        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
+        os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
+
+    # Enable regional STS endpoints. This is the new recommended way
+    # by AWS [1] and is the more performant way.
+    # [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
+    os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
+    if token_info.get("region"):
+        os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
+
+    with hide_access_keys():
+        if cspr_role:
+            # The generated AWS config will be used here since we set the
+            # AWS_CONFIG_FILE environment variable above.
+            if role_arn == USE_CSPR_ROLE_ARN_IF_SET:
+                # Otherwise start from the default profile, assuming CSPR role
+                session = boto3.session.Session(profile_name="cspr")
+            else:
+                session = boto3.session.Session(profile_name="task")
+        else:
+            # Not using AWS config, just AWS_WEB_IDENTITY_TOKEN_FILE + AWS_ROLE_ARN
+            session = boto3.session.Session()
+
+    if role_arn and role_arn != USE_CSPR_ROLE_ARN_IF_SET:
+        # If the user provided a role_arn, we assume that role
+        # using the task role credentials. CSPR role is not used.
+        fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
+            client_creator=session._session.create_client,
+            source_credentials=session._session.get_credentials(),
+            role_arn=role_arn,
+            extra_args={},
+        )
+        creds = botocore.credentials.DeferredRefreshableCredentials(
+            method="assume-role", refresh_using=fetcher.fetch_credentials
+        )
+        botocore_session = botocore.session.Session(session_vars=session_vars)
+        botocore_session._credentials = creds
+        return boto3.session.Session(botocore_session=botocore_session)
+    else:
+        # If the user didn't provide a role_arn, or if the role_arn
+        # is set to USE_CSPR_ROLE_ARN_IF_SET, we return the default
+        # session which would use the CSPR role if it is set on the
+        # server, and the task role otherwise.
+        return session
+
+
 class ObpAuthProvider(object):
     name = "obp"
 
@@ -42,67 +162,13 @@ class ObpAuthProvider(object):
         if client_params is None:
             client_params = {}
 
-        import boto3
-        import botocore
         from botocore.exceptions import ClientError
-        from metaflow_extensions.outerbounds.plugins.auth_server import get_token
-
-        from hashlib import sha256
-        from metaflow.util import get_username
-
-        user = get_username()
-
-        token_info = get_token("/generate/aws")
 
-        # Write token to a file. The file name is derived from the user name
-        # so it works with multiple users on the same machine.
-        #
-        # We hash the user name so we don't have to deal with special characters
-        # in the file name and the file name is not exposed to the user
-        # anyways, so it doesn't matter that its a little ugly.
-        token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
-
-        # Write to a temp file then rename to avoid a situation when someone
-        # tries to read the file after it was open for writing (and truncated)
-        # but before the token was written to it.
-        with tempfile.NamedTemporaryFile("w", delete=False) as f:
-            f.write(token_info["token"])
-            tmp_token_file = f.name
-        os.rename(tmp_token_file, token_file)
-
-        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
-        os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
-
-        # Enable regional STS endpoints. This is the new recommended way
-        # by AWS [1] and is the more performant way.
-        # [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
-        os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
-        if token_info.get("region"):
-            os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
-
-        with hide_access_keys():
-            if role_arn:
-                session = boto3.session.Session()
-                fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
-                    client_creator=session._session.create_client,
-                    source_credentials=session._session.get_credentials(),
-                    role_arn=role_arn,
-                    extra_args={},
-                )
-                creds = botocore.credentials.DeferredRefreshableCredentials(
-                    method="assume-role", refresh_using=fetcher.fetch_credentials
-                )
-                botocore_session = botocore.session.Session(session_vars=session_vars)
-                botocore_session._credentials = creds
-                session = boto3.session.Session(botocore_session=botocore_session)
-                if with_error:
-                    return session.client(module, **client_params), ClientError
-                else:
-                    return session.client(module, **client_params)
-            if with_error:
-                return boto3.client(module, **client_params), ClientError
-            else:
-                return boto3.client(module, **client_params)
+        session = get_boto3_session(role_arn, session_vars)
+        if with_error:
+            return session.client(module, **client_params), ClientError
+        else:
+            return session.client(module, **client_params)
 
 
 AWS_CLIENT_PROVIDERS_DESC = [("obp", ".ObpAuthProvider")]
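Taken together: `get_boto3_session()` centralizes the session plumbing that `ObpAuthProvider.get_client()` previously inlined, and the `__cspr__` placeholder lets callers opt into the CSPR role without knowing its ARN. A hedged usage sketch (the explicit role ARN below is a made-up example):

```python
from metaflow_extensions.outerbounds.plugins import (
    USE_CSPR_ROLE_ARN_IF_SET,
    get_boto3_session,
)

# Passing the placeholder uses the CSPR role when the auth server advertises
# one (token_info["cspr_role_arn"]), and the plain task role otherwise.
session = get_boto3_session(role_arn=USE_CSPR_ROLE_ARN_IF_SET)
s3 = session.client("s3")

# An explicit role is assumed via the task-role credentials; CSPR is bypassed.
other = get_boto3_session(role_arn="arn:aws:iam::123456789012:role/example")
```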
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py
@@ -1,30 +1,30 @@
 import hashlib
 import json
 import os
-
+import threading
+import time
+import uuid
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict
+
 from metaflow.exception import MetaflowException
-from metaflow.metaflow_config import (
-    FAST_BAKERY_URL,
-    get_pinned_conda_libs,
-)
+from metaflow.metaflow_config import FAST_BAKERY_URL, get_pinned_conda_libs
 from metaflow.metaflow_environment import MetaflowEnvironment
-from metaflow.plugins.pypi.conda_environment import CondaEnvironment
-from .fast_bakery import FastBakery, FastBakeryApiResponse, FastBakeryException
 from metaflow.plugins.aws.batch.batch_decorator import BatchDecorator
 from metaflow.plugins.kubernetes.kubernetes_decorator import KubernetesDecorator
 from metaflow.plugins.pypi.conda_decorator import CondaStepDecorator
+from metaflow.plugins.pypi.conda_environment import CondaEnvironment
 from metaflow.plugins.pypi.pypi_decorator import PyPIStepDecorator
 
+from .fast_bakery import FastBakery, FastBakeryApiResponse, FastBakeryException
+
 BAKERY_METAFILE = ".imagebakery-cache"
 
+import fcntl
 import json
 import os
-import fcntl
-from functools import wraps
 from concurrent.futures import ThreadPoolExecutor
-
+from functools import wraps
 
 # TODO - ensure that both @conda/@pypi are not assigned to the same step
 
@@ -36,6 +36,9 @@ def cache_request(cache_file):
             call_args = kwargs.copy()
             call_args.update(zip(func.__code__.co_varnames, args))
             call_args.pop("self", None)
+            call_args.pop("ref", None)
+            # invalidate cache when moving from one deployment to another
+            call_args.update({"fast_bakery_url": FAST_BAKERY_URL})
             cache_key = hashlib.md5(
                 json.dumps(call_args, sort_keys=True).encode("utf-8")
             ).hexdigest()
@@ -79,7 +82,7 @@ def cache_request(cache_file):
 
 
 class DockerEnvironmentException(MetaflowException):
-    headline = "Ran into an error while
+    headline = "Ran into an error while baking image"
 
     def __init__(self, msg):
         super(DockerEnvironmentException, self).__init__(msg)
@@ -93,8 +96,8 @@ class DockerEnvironment(MetaflowEnvironment):
         self.skipped_steps = set()
         self.flow = flow
 
-        self.bakery = FastBakery(url=FAST_BAKERY_URL)
         self.results = {}
+        self.images_baked = 0
 
     def set_local_root(self, local_root):
         self.local_root = local_root
@@ -102,15 +105,31 @@ class DockerEnvironment(MetaflowEnvironment):
     def decospecs(self):
         return ("conda", "fast_bakery_internal") + super().decospecs()
 
-    def validate_environment(self,
+    def validate_environment(self, logger, datastore_type):
         self.datastore_type = datastore_type
-        self.
+        self.logger = logger
 
         # Avoiding circular imports.
        from metaflow.plugins import DATASTORES
 
         self.datastore = [d for d in DATASTORES if d.TYPE == self.datastore_type][0]
 
+        # Mixing @pypi/@conda in a single step is not supported yet
+        for step in self.flow:
+            if (
+                sum(
+                    1
+                    for deco in step.decorators
+                    if isinstance(deco, (PyPIStepDecorator, CondaStepDecorator))
+                )
+                > 1
+            ):
+                raise MetaflowException(
+                    "Mixing and matching PyPI packages and Conda packages within a\n"
+                    "step is not yet supported. Use one of @pypi or @conda only for the *%s* step."
+                    % step.name
+                )
+
     def init_environment(self, echo):
         self.skipped_steps = {
             step.name
@@ -125,14 +144,21 @@ class DockerEnvironment(MetaflowEnvironment):
             step for step in self.flow if step.name not in self.skipped_steps
         ]
         if steps_to_bake:
-
-
+            self.logger("🚀 Baking container image(s) ...")
+            start_time = time.time()
+            self.results = self._bake(steps_to_bake)
             for step in self.flow:
                 for d in step.decorators:
                     if isinstance(d, (BatchDecorator, KubernetesDecorator)):
                         d.attributes["image"] = self.results[step.name].container_image
                         d.attributes["executable"] = self.results[step.name].python_path
-
+            if self.images_baked > 0:
+                bake_time = time.time() - start_time
+                self.logger(
+                    f"🎉 All container image(s) baked in {bake_time:.2f} seconds!"
+                )
+            else:
+                self.logger("🎉 All container image(s) baked!")
 
         if self.skipped_steps:
             self.delegate = CondaEnvironment(self.flow)
@@ -140,29 +166,54 @@ class DockerEnvironment(MetaflowEnvironment):
             self.delegate.validate_environment(echo, self.datastore_type)
             self.delegate.init_environment(echo, self.skipped_steps)
 
-    def _bake(self, steps
+    def _bake(self, steps) -> Dict[str, FastBakeryApiResponse]:
         metafile_path = get_fastbakery_metafile_path(self.local_root, self.flow.name)
+        logger_lock = threading.Lock()
 
         @cache_request(metafile_path)
         def _cached_bake(
-
+            ref=None,
+            python=None,
+            pypi_packages=None,
+            conda_packages=None,
+            base_image=None,
         ):
-
-
-
-
-
-
+            bakery = FastBakery(url=FAST_BAKERY_URL)
+            bakery._reset_payload()
+            bakery.python_version(python)
+            bakery.pypi_packages(pypi_packages)
+            bakery.conda_packages(conda_packages)
+            bakery.base_image(base_image)
+            # bakery.ignore_cache()
+
+            with logger_lock:
+                self.logger(f"🍳 Baking [{ref}] ...")
+                self.logger(f" 🐍 Python: {python}")
+
+                if pypi_packages:
+                    self.logger(f" 📦 PyPI packages:")
+                    for package, version in pypi_packages.items():
+                        self.logger(f" 🔧 {package}: {version}")
+
+                if conda_packages:
+                    self.logger(f" 📦 Conda packages:")
+                    for package, version in conda_packages.items():
+                        self.logger(f" 🔧 {package}: {version}")
+
+                self.logger(f" 🏗️ Base image: {base_image}")
+
+            start_time = time.time()
             try:
-                res =
-
-
-
-
-                )
+                res = bakery.bake()
+                # TODO: Get actual bake time from bakery
+                bake_time = time.time() - start_time
+
+                with logger_lock:
+                    self.logger(f"🏁 Baked [{ref}] in {bake_time:.2f} seconds!")
+                    self.images_baked += 1
                 return res
             except FastBakeryException as ex:
-                raise DockerEnvironmentException(str(ex))
+                raise DockerEnvironmentException(f"Bake [{ref}] failed: {str(ex)}")
 
         def prepare_step(step):
             base_image = next(
@@ -216,10 +267,15 @@ class DockerEnvironment(MetaflowEnvironment):
         }
 
         with ThreadPoolExecutor() as executor:
-
-
-
-
+            prepared_args = list(executor.map(prepare_step, steps))
+            for i, args in enumerate(prepared_args, 1):
+                args["ref"] = f"#{i:02d}"
+            futures = [executor.submit(_cached_bake, **args) for args in prepared_args]
+            results = {}
+            for step, future in zip(steps, futures):
+                results[step.name] = future.result()
+
+        return results
 
     def executable(self, step_name, default=None):
         if step_name in self.skipped_steps:
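The `cache_request` change folds the bakery URL into the cache key, so a response cached against one deployment is not reused against another. A standalone sketch of the key derivation (the argument values and URL are illustrative):

```python
import hashlib
import json

# Illustrative reconstruction of the cache-key logic in cache_request():
# the call arguments, plus the active FAST_BAKERY_URL, hashed as canonical JSON.
call_args = {
    "python": "3.11",
    "pypi_packages": {"pandas": "2.2.0"},
    "fast_bakery_url": "https://bakery.example.outerbounds.dev",  # hypothetical URL
}
cache_key = hashlib.md5(
    json.dumps(call_args, sort_keys=True).encode("utf-8")
).hexdigest()
print(cache_key)  # changes whenever any argument -- or the deployment -- changes
```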
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py
@@ -1,5 +1,6 @@
 from typing import Dict, Optional
 import requests
+import time
 
 
 class FastBakeryException(Exception):
@@ -140,6 +141,19 @@ class FastBakery:
             headers = {**self.headers, **(SERVICE_HEADERS or {})}
         except ImportError:
             headers = self.headers
+
+        retryable_status_codes = [409]
+
+        for attempt in range(2):  # 0 = initial attempt, 1-2 = retries
+            response = requests.post(self.url, json=payload, headers=headers)
+
+            if response.status_code not in retryable_status_codes:
+                break
+
+            if attempt < 2:  # Don't sleep after the last attempt
+                sleep_time = 0.5 * (attempt + 1)
+                time.sleep(sleep_time)
+
         response = requests.post(self.url, json=payload, headers=headers)
         self._handle_error_response(response)
         return FastBakeryApiResponse(response.json())
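As written, only HTTP 409 triggers a retry, and the loop has two quirks worth noting: `range(2)` yields attempts 0 and 1, so the `attempt < 2` guard never fires, and the response fetched inside the loop is superseded by the unconditional post that follows it (up to three requests total). For comparison, a generic bounded-retry shape; this is a sketch, not the package's code:

```python
import time

import requests


def post_with_retry(url, payload, headers, retries=2, retryable=(409,)):
    """Sketch: POST with a bounded number of retries and linear backoff."""
    response = None
    for attempt in range(retries + 1):
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code not in retryable:
            break
        if attempt < retries:  # don't sleep after the final attempt
            time.sleep(0.5 * (attempt + 1))
    return response
```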
--- /dev/null
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py
@@ -0,0 +1,157 @@
+import os
+import sys
+import time
+import signal
+from io import BytesIO
+from datetime import datetime, timezone
+
+from metaflow.exception import MetaflowException
+
+
+class HeartbeatStore(object):
+    def __init__(
+        self,
+        heartbeat_prefix,
+        main_pid=None,
+        storage_backend=None,
+        emit_frequency=30,
+        missed_heartbeat_timeout=60,
+        monitor_frequency=15,
+        max_missed_heartbeats=3,
+    ) -> None:
+        self.heartbeat_prefix = heartbeat_prefix
+        self.main_pid = main_pid
+        self.storage_backend = storage_backend
+        self.emit_frequency = emit_frequency
+        self.monitor_frequency = monitor_frequency
+        self.missed_heartbeat_timeout = missed_heartbeat_timeout
+        self.max_missed_heartbeats = max_missed_heartbeats
+        self.missed_heartbeats = 0
+
+    def emit_heartbeat(self, folder_name=None):
+        heartbeat_key = f"{self.heartbeat_prefix}/heartbeat"
+        if folder_name:
+            heartbeat_key = f"{folder_name}/{heartbeat_key}"
+
+        while True:
+            try:
+                epoch_string = str(datetime.now(timezone.utc).timestamp()).encode(
+                    "utf-8"
+                )
+                self.storage_backend.save_bytes(
+                    [(heartbeat_key, BytesIO(bytes(epoch_string)))], overwrite=True
+                )
+            except Exception as e:
+                print(f"Error writing heartbeat: {e}")
+                sys.exit(1)
+
+            time.sleep(self.emit_frequency)
+
+    def emit_tombstone(self, folder_name=None):
+        tombstone_key = f"{self.heartbeat_prefix}/tombstone"
+        if folder_name:
+            tombstone_key = f"{folder_name}/{tombstone_key}"
+
+        tombstone_string = "tombstone".encode("utf-8")
+        try:
+            self.storage_backend.save_bytes(
+                [(tombstone_key, BytesIO(bytes(tombstone_string)))], overwrite=True
+            )
+        except Exception as e:
+            print(f"Error writing tombstone: {e}")
+            sys.exit(1)
+
+    def __handle_tombstone(self, path):
+        if path is not None:
+            with open(path) as f:
+                contents = f.read()
+            if "tombstone" in contents:
+                print("[Outerbounds] Tombstone detected. Terminating the task..")
+                os.kill(self.main_pid, signal.SIGTERM)
+                sys.exit(1)
+
+    def __handle_heartbeat(self, path):
+        if path is not None:
+            with open(path) as f:
+                contents = f.read()
+            current_timestamp = datetime.now(timezone.utc).timestamp()
+            if current_timestamp - float(contents) > self.missed_heartbeat_timeout:
+                self.missed_heartbeats += 1
+            else:
+                self.missed_heartbeats = 0
+        else:
+            self.missed_heartbeats += 1
+
+        if self.missed_heartbeats > self.max_missed_heartbeats:
+            print(
+                f"[Outerbounds] Missed {self.max_missed_heartbeats} consecutive heartbeats. Terminating the task.."
+            )
+            os.kill(self.main_pid, signal.SIGTERM)
+            sys.exit(1)
+
+    def is_main_process_running(self):
+        try:
+            # Check if the process is running
+            os.kill(self.main_pid, 0)
+        except ProcessLookupError:
+            return False
+        return True
+
+    def monitor(self, folder_name=None):
+        heartbeat_key = f"{self.heartbeat_prefix}/heartbeat"
+        if folder_name:
+            heartbeat_key = f"{folder_name}/{heartbeat_key}"
+
+        tombstone_key = f"{self.heartbeat_prefix}/tombstone"
+        if folder_name:
+            tombstone_key = f"{folder_name}/{tombstone_key}"
+
+        while self.is_main_process_running():
+            with self.storage_backend.load_bytes(
+                [heartbeat_key, tombstone_key]
+            ) as results:
+                for key, path, _ in results:
+                    if key == tombstone_key:
+                        self.__handle_tombstone(path)
+                    elif key == heartbeat_key:
+                        self.__handle_heartbeat(path)
+
+            time.sleep(self.monitor_frequency)
+
+
+if __name__ == "__main__":
+    from metaflow.plugins import DATASTORES
+    from metaflow.metaflow_config import NVIDIA_HEARTBEAT_THRESHOLD
+
+    if len(sys.argv) != 4:
+        print("Usage: heartbeat_store.py <main_pid> <datastore_type> <folder_name>")
+        sys.exit(1)
+    _, main_pid, datastore_type, folder_name = sys.argv
+
+    if datastore_type not in ("azure", "gs", "s3"):
+        print(f"Datastore unsupported for type: {datastore_type}")
+        sys.exit(1)
+
+    datastores = [d for d in DATASTORES if d.TYPE == datastore_type]
+    datastore_sysroot = datastores[0].get_datastore_root_from_config(
+        lambda *args, **kwargs: None
+    )
+    if datastore_sysroot is None:
+        raise MetaflowException(
+            msg="METAFLOW_DATASTORE_SYSROOT_{datastore_type} must be set!".format(
+                datastore_type=datastore_type.upper()
+            )
+        )
+
+    storage = datastores[0](datastore_sysroot)
+
+    heartbeat_prefix = f"{os.getenv('MF_PATHSPEC')}/{os.getenv('MF_ATTEMPT')}"
+
+    store = HeartbeatStore(
+        heartbeat_prefix=heartbeat_prefix,
+        main_pid=int(main_pid),
+        storage_backend=storage,
+        max_missed_heartbeats=int(NVIDIA_HEARTBEAT_THRESHOLD),
+    )
+
+    store.monitor(folder_name=folder_name)
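The file doubles as an executable module (`python -m ...heartbeat_store <main_pid> <datastore_type> <folder_name>`, which is exactly how the NVCF launcher below invokes it) and as a library. A hedged sketch of the emitter/monitor pairing, using a hypothetical S3 datastore root and pathspec:

```python
import os
import threading

from metaflow.plugins import DATASTORES
from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import HeartbeatStore

# Hypothetical datastore root; in a real task this comes from Metaflow config.
s3_cls = [d for d in DATASTORES if d.TYPE == "s3"][0]
storage = s3_cls("s3://example-bucket/metaflow")

store = HeartbeatStore(
    heartbeat_prefix="MyFlow/123/train/456/0",  # flow/run/step/task/retry
    main_pid=os.getpid(),
    storage_backend=storage,
)

# Submitter side: a daemon thread writes a UTC timestamp every emit_frequency
# seconds under nvcf_heartbeats/<prefix>/heartbeat.
threading.Thread(
    target=store.emit_heartbeat, args=("nvcf_heartbeats",), daemon=True
).start()

# Monitor side (a separate process next to the task): polls the same keys and
# SIGTERMs main_pid when a tombstone appears or too many heartbeats are missed.
# store.monitor(folder_name="nvcf_heartbeats")
```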
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py
@@ -1,6 +1,7 @@
 import json
 import os
-import
+import threading
+from urllib.parse import urlparse
 from urllib.request import HTTPError, Request, URLError, urlopen
 
 from metaflow import util
@@ -14,6 +15,7 @@ from metaflow.mflog import (
 )
 import requests
 from metaflow.metaflow_config_funcs import init_config
+from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import HeartbeatStore
 
 
 class NvcfException(MetaflowException):
@@ -58,10 +60,11 @@ class Nvcf(object):
             code_package_url, code_package_ds
         )
         init_expr = " && ".join(init_cmds)
+        heartbeat_expr = f'python -m metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store "$MAIN_PID" {code_package_ds} nvcf_heartbeats 1>> $MFLOG_STDOUT 2>> $MFLOG_STDERR'
         step_expr = bash_capture_logs(
             " && ".join(
                 self.environment.bootstrap_commands(step_name, code_package_ds)
-                + [step_cli]
+                + [step_cli + " & MAIN_PID=$!; " + heartbeat_expr]
             )
         )
 
@@ -87,7 +90,9 @@ class Nvcf(object):
             '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
             % cmd_str
         )
-        self.job = Job(
+        self.job = Job(
+            'bash -c "%s"' % cmd_str, env, task_spec, self.datastore._storage_impl
+        )
         self.job.submit()
 
     def wait(self, stdout_location, stderr_location, echo=None):
@@ -137,8 +142,7 @@ result_endpoint = f"{nvcf_url}/v2/nvcf/pexec/status"
 
 
 class Job(object):
-    def __init__(self, command, env):
-
+    def __init__(self, command, env, task_spec, backend):
         self._payload = {
             "command": command,
             "env": {k: v for k, v in env.items() if v is not None},
@@ -170,6 +174,27 @@ class Job(object):
             if f["model_key"] == "metaflow_task_executor":
                 self._function_id = f["id"]
 
+        flow_name = task_spec.get("flow_name")
+        run_id = task_spec.get("run_id")
+        step_name = task_spec.get("step_name")
+        task_id = task_spec.get("task_id")
+        retry_count = task_spec.get("retry_count")
+
+        heartbeat_prefix = "/".join(
+            (flow_name, str(run_id), step_name, str(task_id), str(retry_count))
+        )
+
+        store = HeartbeatStore(
+            heartbeat_prefix=heartbeat_prefix,
+            main_pid=None,
+            storage_backend=backend,
+        )
+
+        self.heartbeat_thread = threading.Thread(
+            target=store.emit_heartbeat, args=("nvcf_heartbeats",), daemon=True
+        )
+        self.heartbeat_thread.start()
+
     def submit(self):
         try:
             headers = {
@@ -224,7 +249,6 @@ class Job(object):
 
     def _poll(self):
         try:
-            invocation_id = self._invocation_id
             headers = {
                 "Authorization": f"Bearer {self._ngc_api_key}",
                 "Content-Type": "application/json",
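The launcher wiring is worth unpacking: `step_cli + " & MAIN_PID=$!; " + heartbeat_expr` backgrounds the actual step command inside the NVCF container, captures its PID in `MAIN_PID`, and runs the heartbeat-store module in the foreground so it can watch for a tombstone and SIGTERM the step. A minimal sketch of the same shell pattern driven from Python (the echo/sleep commands are placeholders, not the real step CLI):

```python
import subprocess

# Placeholder commands standing in for step_cli and heartbeat_expr.
step_cli = "echo 'step running'; sleep 5"
monitor = "python -c \"import sys; print('watching pid', sys.argv[1])\" $MAIN_PID"

# Background the step, remember its PID, then run the watcher in the foreground.
subprocess.run(["bash", "-c", f"{step_cli} & MAIN_PID=$!; {monitor}"], check=True)
```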
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py
@@ -4,7 +4,7 @@ import sys
 import time
 import traceback
 
-from metaflow import util
+from metaflow import util, Run
 from metaflow._vendor import click
 from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
 from metaflow.metadata.util import sync_local_metadata_from_datastore
@@ -27,6 +27,7 @@ from metaflow.metaflow_config import (
     CARD_GSROOT,
     KUBERNETES_SANDBOX_INIT_SCRIPT,
     OTEL_ENDPOINT,
+    NVIDIA_HEARTBEAT_THRESHOLD,
 )
 from metaflow.mflog import TASK_LOG_SOURCE
 
@@ -43,6 +44,65 @@ def nvcf():
     pass
 
 
+@nvcf.command(help="List steps / tasks running as an NVCF job.")
+@click.argument("run-id")
+@click.pass_context
+def list(ctx, run_id):
+    flow_name = ctx.obj.flow.name
+    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
+    running_invocations = []
+    for each_step in run_obj:
+        if (
+            not each_step.task.finished
+            and "nvcf-function-id" in each_step.task.metadata_dict
+        ):
+            task_pathspec = each_step.task.pathspec
+            attempt = each_step.task.metadata_dict.get("attempt")
+            flow_name, run_id, step_name, task_id = task_pathspec.split("/")
+            running_invocations.append(
+                f"Flow Name: {flow_name}, Run ID: {run_id}, Step Name: {step_name}, Task ID: {task_id}, Retry Count: {attempt}"
+            )
+
+    if running_invocations:
+        for each_invocation in running_invocations:
+            print(each_invocation)
+    else:
+        print("No running NVCF invocations for Run ID: %s" % run_id)
+
+
+@nvcf.command(help="Kill steps / tasks running as an NVCF job.")
+@click.argument("run-id")
+@click.pass_context
+def kill(ctx, run_id):
+    from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import (
+        HeartbeatStore,
+    )
+
+    flow_name = ctx.obj.flow.name
+    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
+
+    for each_step in run_obj:
+        if (
+            not each_step.task.finished
+            and "nvcf-function-id" in each_step.task.metadata_dict
+        ):
+            task_pathspec = each_step.task.pathspec
+            attempt = each_step.task.metadata_dict.get("attempt")
+            heartbeat_prefix = "{task_pathspec}/{attempt}".format(
+                task_pathspec=task_pathspec, attempt=attempt
+            )
+
+            datastore_root = ctx.obj.datastore_impl.datastore_root
+            store = HeartbeatStore(
+                heartbeat_prefix=heartbeat_prefix,
+                main_pid=None,
+                storage_backend=ctx.obj.datastore_impl(datastore_root),
+            )
+            store.emit_tombstone(folder_name="nvcf_heartbeats")
+        else:
+            print("No running NVCF invocations for Run ID: %s" % run_id)
+
+
 @nvcf.command(
     help="Execute a single task using NVCF. This command calls the "
     "top-level step command inside a NVCF job with the given options. "
@@ -139,6 +199,7 @@ def step(ctx, step_name, code_package_sha, code_package_url, **kwargs):
         "METAFLOW_CARD_GSROOT": CARD_GSROOT,
         "METAFLOW_INIT_SCRIPT": KUBERNETES_SANDBOX_INIT_SCRIPT,
         "METAFLOW_OTEL_ENDPOINT": OTEL_ENDPOINT,
+        "METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD": str(NVIDIA_HEARTBEAT_THRESHOLD),
     }
 
     env_deco = [deco for deco in node.decorators if deco.name == "environment"]
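Note that `kill` works indirectly: rather than calling an NVCF termination API, it writes a tombstone object that the per-task monitor picks up on its next poll. Assuming the standard Metaflow CLI plumbing exposes the `nvcf` click group on a flow's command line, the new subcommands would be driven roughly like this (flow file and run id are placeholders):

```python
import subprocess

# Hypothetical invocations: the new subcommands hang off the existing
# `nvcf` group of a flow's CLI.
subprocess.run(["python", "flow.py", "nvcf", "list", "42"], check=True)
subprocess.run(["python", "flow.py", "nvcf", "kill", "42"], check=True)
```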
--- /dev/null
+++ ob-metaflow-extensions-1.1.88/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py
@@ -0,0 +1,53 @@
+# These two fields will show up within `metaflow_version` task metadata.
+# Setting to major version of ob-metaflow-extensions only, so we don't keep trying
+# (and failing) to keep this in sync with setup.py
+# E.g. "2.7.22.1+ob(v1)"
+__version__ = "v1"
+__mf_extensions__ = "ob"
+
+
+# Must match the signature of metaflow.plugins.aws.aws_client.get_aws_client
+# This function is called by the "userland" code inside tasks. Metaflow internals
+# will call the function in metaflow.plugins.aws.aws_client.get_aws_client directly.
+#
+# Unlike the original function, this wrapper will use the CSPR role if both of the following
+# conditions are met:
+#
+# 1. CSPR is set
+# 2. user didn't provide a role to assume explicitly.
+#
+def get_aws_client(
+    module, with_error=False, role_arn=None, session_vars=None, client_params=None
+):
+    import metaflow.plugins.aws.aws_client
+
+    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET
+
+    return metaflow.plugins.aws.aws_client.get_aws_client(
+        module,
+        with_error=with_error,
+        role_arn=role_arn or USE_CSPR_ROLE_ARN_IF_SET,
+        session_vars=session_vars,
+        client_params=client_params,
+    )
+
+
+# This should match the signature of metaflow.plugins.datatools.s3.S3.
+#
+# This assumes that "userland" code inside tasks will call this, while Metaflow
+# internals will call metaflow.plugins.datatools.s3.S3 directly.
+#
+# This wrapper will make S3() use the CSPR role if its set, and user didn't provide
+# a role to assume explicitly.
+def S3(*args, **kwargs):
+    import sys
+    import metaflow.plugins.datatools.s3
+    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET
+
+    if "role" not in kwargs or kwargs["role"] is None:
+        kwargs["role"] = USE_CSPR_ROLE_ARN_IF_SET
+
+    return metaflow.plugins.datatools.s3.S3(*args, **kwargs)
+
+
+from .. import profilers
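The net effect for user code: `from metaflow import get_aws_client` and `from metaflow import S3` are meant to resolve to these wrappers, which default the role to the CSPR placeholder when the caller doesn't pass one. A hedged sketch of task code exercising that (bucket, prefix, and the explicit ARN are illustrative):

```python
from metaflow import S3, get_aws_client

# No explicit role: the wrappers substitute USE_CSPR_ROLE_ARN_IF_SET, so the
# CSPR role applies when the deployment defines one, the task role otherwise.
client = get_aws_client("s3")

with S3(s3root="s3://example-bucket/data/") as s3:  # hypothetical s3root
    paths = s3.list_paths()

# An explicit role still wins and bypasses the CSPR defaulting.
client2 = get_aws_client("s3", role_arn="arn:aws:iam::123456789012:role/example")
```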
--- ob-metaflow-extensions-1.1.84/ob_metaflow_extensions.egg-info/SOURCES.txt
+++ ob-metaflow-extensions-1.1.88/ob_metaflow_extensions.egg-info/SOURCES.txt
@@ -16,6 +16,7 @@ metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py
 metaflow_extensions/outerbounds/plugins/nim/__init__.py
 metaflow_extensions/outerbounds/plugins/nim/nim_manager.py
 metaflow_extensions/outerbounds/plugins/nvcf/__init__.py
+metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py
 metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py
 metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py
 metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py
--- ob-metaflow-extensions-1.1.84/metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# These two fields will show up within `metaflow_version` task metadata.
-# Setting to major version of ob-metaflow-extensions only, so we don't keep trying
-# (and failing) to keep this in sync with setup.py
-# E.g. "2.7.22.1+ob(v1)"
-__version__ = "v1"
-__mf_extensions__ = "ob"
-
-# To support "from metaflow import get_aws_client"
-from metaflow.plugins.aws.aws_client import get_aws_client
-from .. import profilers