ob-metaflow-extensions 1.1.83__py2.py3-none-any.whl → 1.1.86__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ob-metaflow-extensions might be problematic. Click here for more details.
- metaflow_extensions/outerbounds/config/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/__init__.py +125 -59
- metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +157 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +30 -6
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +62 -1
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +45 -2
- {ob_metaflow_extensions-1.1.83.dist-info → ob_metaflow_extensions-1.1.86.dist-info}/METADATA +2 -2
- {ob_metaflow_extensions-1.1.83.dist-info → ob_metaflow_extensions-1.1.86.dist-info}/RECORD +10 -9
- {ob_metaflow_extensions-1.1.83.dist-info → ob_metaflow_extensions-1.1.86.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.83.dist-info → ob_metaflow_extensions-1.1.86.dist-info}/top_level.txt +0 -0
|
@@ -14,6 +14,13 @@ DEFAULT_GCP_CLIENT_PROVIDER = "obp"
|
|
|
14
14
|
FAST_BAKERY_URL = from_conf("FAST_BAKERY_URL", None)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
###
|
|
18
|
+
# NVCF configuration
|
|
19
|
+
###
|
|
20
|
+
# Maximum number of consecutive heartbeats that can be missed.
|
|
21
|
+
NVIDIA_HEARTBEAT_THRESHOLD = from_conf("NVIDIA_HEARTBEAT_THRESHOLD", "3")
|
|
22
|
+
|
|
23
|
+
|
|
17
24
|
###
|
|
18
25
|
# Snowpark configuration
|
|
19
26
|
###
|
|
@@ -32,6 +32,126 @@ def hide_access_keys(*args, **kwds):
|
|
|
32
32
|
os.environ["AWS_SESSION_TOKEN"] = AWS_SESSION_TOKEN
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
# This is a special placeholder value that can be passed as role_arn to
|
|
36
|
+
# get_boto3_session() which makes it use the CSPR role, if its set.
|
|
37
|
+
USE_CSPR_ROLE_ARN_IF_SET = "__cspr__"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_boto3_session(role_arn=None, session_vars=None):
|
|
41
|
+
import boto3
|
|
42
|
+
import botocore
|
|
43
|
+
from metaflow_extensions.outerbounds.plugins.auth_server import get_token
|
|
44
|
+
|
|
45
|
+
from hashlib import sha256
|
|
46
|
+
from metaflow.util import get_username
|
|
47
|
+
|
|
48
|
+
user = get_username()
|
|
49
|
+
|
|
50
|
+
token_info = get_token("/generate/aws")
|
|
51
|
+
|
|
52
|
+
# Write token to a file. The file name is derived from the user name
|
|
53
|
+
# so it works with multiple users on the same machine.
|
|
54
|
+
#
|
|
55
|
+
# We hash the user name so we don't have to deal with special characters
|
|
56
|
+
# in the file name and the file name is not exposed to the user
|
|
57
|
+
# anyways, so it doesn't matter that its a little ugly.
|
|
58
|
+
token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
|
|
59
|
+
|
|
60
|
+
# Write to a temp file then rename to avoid a situation when someone
|
|
61
|
+
# tries to read the file after it was open for writing (and truncated)
|
|
62
|
+
# but before the token was written to it.
|
|
63
|
+
with tempfile.NamedTemporaryFile("w", delete=False) as f:
|
|
64
|
+
f.write(token_info["token"])
|
|
65
|
+
tmp_token_file = f.name
|
|
66
|
+
os.rename(tmp_token_file, token_file)
|
|
67
|
+
|
|
68
|
+
cspr_role = None
|
|
69
|
+
if token_info.get("cspr_role_arn"):
|
|
70
|
+
cspr_role = token_info["cspr_role_arn"]
|
|
71
|
+
|
|
72
|
+
if cspr_role:
|
|
73
|
+
# If CSPR role is set, we set it as the default role to assume
|
|
74
|
+
# for the AWS SDK. We do this by writing an AWS config file
|
|
75
|
+
# with two profiles. One to get credentials for the task role
|
|
76
|
+
# in exchange for the OIDC token, and second to assume the
|
|
77
|
+
# CSPR role using the task role credentials.
|
|
78
|
+
import configparser
|
|
79
|
+
from io import StringIO
|
|
80
|
+
|
|
81
|
+
aws_config = configparser.ConfigParser()
|
|
82
|
+
|
|
83
|
+
# Task role profile
|
|
84
|
+
aws_config["profile task"] = {
|
|
85
|
+
"role_arn": token_info["role_arn"],
|
|
86
|
+
"web_identity_token_file": token_file,
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
# CSPR role profile (default)
|
|
90
|
+
aws_config["profile cspr"] = {
|
|
91
|
+
"role_arn": cspr_role,
|
|
92
|
+
"source_profile": "task",
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
aws_config_string = StringIO()
|
|
96
|
+
aws_config.write(aws_config_string)
|
|
97
|
+
aws_config_file = (
|
|
98
|
+
"/tmp/aws_config." + sha256(user.encode("utf-8")).hexdigest()[:16]
|
|
99
|
+
)
|
|
100
|
+
with tempfile.NamedTemporaryFile(
|
|
101
|
+
"w", delete=False, dir=os.path.dirname(aws_config_file)
|
|
102
|
+
) as f:
|
|
103
|
+
f.write(aws_config_string.getvalue())
|
|
104
|
+
tmp_aws_config_file = f.name
|
|
105
|
+
os.rename(tmp_aws_config_file, aws_config_file)
|
|
106
|
+
os.environ["AWS_CONFIG_FILE"] = aws_config_file
|
|
107
|
+
os.environ["AWS_DEFAULT_PROFILE"] = "cspr"
|
|
108
|
+
else:
|
|
109
|
+
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
|
|
110
|
+
os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
|
|
111
|
+
|
|
112
|
+
# Enable regional STS endpoints. This is the new recommended way
|
|
113
|
+
# by AWS [1] and is the more performant way.
|
|
114
|
+
# [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
|
|
115
|
+
os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
|
|
116
|
+
if token_info.get("region"):
|
|
117
|
+
os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
|
|
118
|
+
|
|
119
|
+
with hide_access_keys():
|
|
120
|
+
if cspr_role:
|
|
121
|
+
# The generated AWS config will be used here since we set the
|
|
122
|
+
# AWS_CONFIG_FILE environment variable above.
|
|
123
|
+
if role_arn == USE_CSPR_ROLE_ARN_IF_SET:
|
|
124
|
+
# Otherwise start from the default profile, assuming CSPR role
|
|
125
|
+
session = boto3.session.Session(profile_name="default")
|
|
126
|
+
else:
|
|
127
|
+
session = boto3.session.Session(profile_name="task")
|
|
128
|
+
else:
|
|
129
|
+
# Not using AWS config, just AWS_WEB_IDENTITY_TOKEN_FILE + AWS_ROLE_ARN
|
|
130
|
+
session = boto3.session.Session()
|
|
131
|
+
|
|
132
|
+
if role_arn and role_arn != USE_CSPR_ROLE_ARN_IF_SET:
|
|
133
|
+
# If the user provided a role_arn, we assume that role
|
|
134
|
+
# using the task role credentials. CSPR role is not used.
|
|
135
|
+
fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
|
|
136
|
+
client_creator=session._session.create_client,
|
|
137
|
+
source_credentials=session._session.get_credentials(),
|
|
138
|
+
role_arn=role_arn,
|
|
139
|
+
extra_args={},
|
|
140
|
+
)
|
|
141
|
+
creds = botocore.credentials.DeferredRefreshableCredentials(
|
|
142
|
+
method="assume-role", refresh_using=fetcher.fetch_credentials
|
|
143
|
+
)
|
|
144
|
+
botocore_session = botocore.session.Session(session_vars=session_vars)
|
|
145
|
+
botocore_session._credentials = creds
|
|
146
|
+
return boto3.session.Session(botocore_session=botocore_session)
|
|
147
|
+
else:
|
|
148
|
+
# If the user didn't provide a role_arn, or if the role_arn
|
|
149
|
+
# is set to USE_CSPR_ROLE_ARN_IF_SET, we return the default
|
|
150
|
+
# session which would use the CSPR role if it is set on the
|
|
151
|
+
# server, and the task role otherwise.
|
|
152
|
+
return session
|
|
153
|
+
|
|
154
|
+
|
|
35
155
|
class ObpAuthProvider(object):
|
|
36
156
|
name = "obp"
|
|
37
157
|
|
|
@@ -42,67 +162,13 @@ class ObpAuthProvider(object):
|
|
|
42
162
|
if client_params is None:
|
|
43
163
|
client_params = {}
|
|
44
164
|
|
|
45
|
-
import boto3
|
|
46
|
-
import botocore
|
|
47
165
|
from botocore.exceptions import ClientError
|
|
48
|
-
from metaflow_extensions.outerbounds.plugins.auth_server import get_token
|
|
49
|
-
|
|
50
|
-
from hashlib import sha256
|
|
51
|
-
from metaflow.util import get_username
|
|
52
|
-
|
|
53
|
-
user = get_username()
|
|
54
|
-
|
|
55
|
-
token_info = get_token("/generate/aws")
|
|
56
166
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
# anyways, so it doesn't matter that its a little ugly.
|
|
63
|
-
token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
|
|
64
|
-
|
|
65
|
-
# Write to a temp file then rename to avoid a situation when someone
|
|
66
|
-
# tries to read the file after it was open for writing (and truncated)
|
|
67
|
-
# but before the token was written to it.
|
|
68
|
-
with tempfile.NamedTemporaryFile("w", delete=False) as f:
|
|
69
|
-
f.write(token_info["token"])
|
|
70
|
-
tmp_token_file = f.name
|
|
71
|
-
os.rename(tmp_token_file, token_file)
|
|
72
|
-
|
|
73
|
-
os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
|
|
74
|
-
os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
|
|
75
|
-
|
|
76
|
-
# Enable regional STS endpoints. This is the new recommended way
|
|
77
|
-
# by AWS [1] and is the more performant way.
|
|
78
|
-
# [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
|
|
79
|
-
os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
|
|
80
|
-
if token_info.get("region"):
|
|
81
|
-
os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
|
|
82
|
-
|
|
83
|
-
with hide_access_keys():
|
|
84
|
-
if role_arn:
|
|
85
|
-
session = boto3.session.Session()
|
|
86
|
-
fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
|
|
87
|
-
client_creator=session._session.create_client,
|
|
88
|
-
source_credentials=session._session.get_credentials(),
|
|
89
|
-
role_arn=role_arn,
|
|
90
|
-
extra_args={},
|
|
91
|
-
)
|
|
92
|
-
creds = botocore.credentials.DeferredRefreshableCredentials(
|
|
93
|
-
method="assume-role", refresh_using=fetcher.fetch_credentials
|
|
94
|
-
)
|
|
95
|
-
botocore_session = botocore.session.Session(session_vars=session_vars)
|
|
96
|
-
botocore_session._credentials = creds
|
|
97
|
-
session = boto3.session.Session(botocore_session=botocore_session)
|
|
98
|
-
if with_error:
|
|
99
|
-
return session.client(module, **client_params), ClientError
|
|
100
|
-
else:
|
|
101
|
-
return session.client(module, **client_params)
|
|
102
|
-
if with_error:
|
|
103
|
-
return boto3.client(module, **client_params), ClientError
|
|
104
|
-
else:
|
|
105
|
-
return boto3.client(module, **client_params)
|
|
167
|
+
session = get_boto3_session(role_arn, session_vars)
|
|
168
|
+
if with_error:
|
|
169
|
+
return session.client(module, **client_params), ClientError
|
|
170
|
+
else:
|
|
171
|
+
return session.client(module, **client_params)
|
|
106
172
|
|
|
107
173
|
|
|
108
174
|
AWS_CLIENT_PROVIDERS_DESC = [("obp", ".ObpAuthProvider")]
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
import signal
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
|
|
8
|
+
from metaflow.exception import MetaflowException
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class HeartbeatStore(object):
    """Emit and monitor liveness markers for a remotely executed task.

    Two objects are kept under ``[<folder_name>/]<heartbeat_prefix>/`` in the
    storage backend:

    * ``heartbeat`` -- a UTC epoch timestamp, rewritten every
      ``emit_frequency`` seconds by :meth:`emit_heartbeat`.
    * ``tombstone`` -- written by :meth:`emit_tombstone` to request that the
      task be terminated.

    :meth:`monitor` polls both objects every ``monitor_frequency`` seconds. It
    sends SIGTERM to ``main_pid`` when a tombstone is found. It also does so
    when more than ``max_missed_heartbeats`` consecutive polls find the
    heartbeat missing, unreadable, or older than ``missed_heartbeat_timeout``
    seconds.

    ``storage_backend`` is used through ``save_bytes([(key, fileobj)],
    overwrite=True)`` and a context-manager ``load_bytes(keys)`` yielding
    ``(key, local_path_or_None, _)`` tuples. ``main_pid`` is only needed by
    :meth:`monitor` and :meth:`is_main_process_running`.
    """

    def __init__(
        self,
        heartbeat_prefix,
        main_pid=None,
        storage_backend=None,
        emit_frequency=30,
        missed_heartbeat_timeout=60,
        monitor_frequency=15,
        max_missed_heartbeats=3,
    ) -> None:
        self.heartbeat_prefix = heartbeat_prefix
        self.main_pid = main_pid
        self.storage_backend = storage_backend
        self.emit_frequency = emit_frequency
        self.monitor_frequency = monitor_frequency
        self.missed_heartbeat_timeout = missed_heartbeat_timeout
        self.max_missed_heartbeats = max_missed_heartbeats
        # Consecutive polls that found no fresh heartbeat.
        self.missed_heartbeats = 0

    def _key(self, name, folder_name=None):
        """Return the storage key of marker ``name`` ("heartbeat"/"tombstone")."""
        key = f"{self.heartbeat_prefix}/{name}"
        if folder_name:
            key = f"{folder_name}/{key}"
        return key

    def emit_heartbeat(self, folder_name=None):
        """Write the current UTC timestamp every ``emit_frequency`` seconds, forever.

        Intended to run on a (daemon) thread. A failed write is reported and
        retried on the next tick. Previously the loop was ended with
        ``sys.exit``, which in a thread just stops the thread silently; a
        single transient storage error would then stop all heartbeats, and
        the monitor would eventually kill a healthy task.
        """
        heartbeat_key = self._key("heartbeat", folder_name)
        while True:
            try:
                epoch_bytes = str(datetime.now(timezone.utc).timestamp()).encode(
                    "utf-8"
                )
                self.storage_backend.save_bytes(
                    [(heartbeat_key, BytesIO(epoch_bytes))], overwrite=True
                )
            except Exception as e:
                print(f"Error writing heartbeat: {e}")
            time.sleep(self.emit_frequency)

    def emit_tombstone(self, folder_name=None):
        """Write the tombstone marker; exit the process with status 1 on failure."""
        tombstone_key = self._key("tombstone", folder_name)
        try:
            self.storage_backend.save_bytes(
                [(tombstone_key, BytesIO(b"tombstone"))], overwrite=True
            )
        except Exception as e:
            print(f"Error writing tombstone: {e}")
            sys.exit(1)

    def __handle_tombstone(self, path):
        """Terminate ``main_pid`` and exit if the downloaded tombstone is present."""
        if path is not None:
            with open(path) as f:
                contents = f.read()
            if "tombstone" in contents:
                print("[Outerbounds] Tombstone detected. Terminating the task..")
                os.kill(self.main_pid, signal.SIGTERM)
                sys.exit(1)

    def __handle_heartbeat(self, path):
        """Update the missed-heartbeat counter; terminate ``main_pid`` past the limit.

        ``path`` is the local copy of the heartbeat object, or ``None`` when
        it does not exist. A missing, unparsable, or stale timestamp counts
        as a miss; a fresh one resets the counter.
        """
        if path is None:
            missed = True
        else:
            with open(path) as f:
                contents = f.read()
            try:
                last_beat = float(contents)
            except ValueError:
                # An unreadable timestamp is no evidence of liveness; count it
                # as a miss instead of crashing the monitor.
                missed = True
            else:
                now = datetime.now(timezone.utc).timestamp()
                missed = now - last_beat > self.missed_heartbeat_timeout

        if missed:
            self.missed_heartbeats += 1
        else:
            self.missed_heartbeats = 0

        # Up to max_missed_heartbeats consecutive misses are tolerated.
        if self.missed_heartbeats > self.max_missed_heartbeats:
            print(
                f"[Outerbounds] Missed {self.missed_heartbeats} consecutive heartbeats. Terminating the task.."
            )
            os.kill(self.main_pid, signal.SIGTERM)
            sys.exit(1)

    def is_main_process_running(self):
        """Return True while the process ``main_pid`` exists."""
        try:
            # Signal 0 performs only the existence/permission check.
            os.kill(self.main_pid, 0)
        except ProcessLookupError:
            return False
        except PermissionError:
            # EPERM: the process exists but belongs to another user.
            return True
        return True

    def monitor(self, folder_name=None):
        """Poll heartbeat/tombstone markers until ``main_pid`` exits.

        Termination happens inside the handlers via ``sys.exit``; SystemExit
        is not an ``Exception``, so the error handling below does not swallow
        it. Errors while fetching or reading the markers are reported, and
        monitoring continues on the next poll instead of dying and leaving
        the task unmonitored.
        """
        heartbeat_key = self._key("heartbeat", folder_name)
        tombstone_key = self._key("tombstone", folder_name)

        while self.is_main_process_running():
            try:
                with self.storage_backend.load_bytes(
                    [heartbeat_key, tombstone_key]
                ) as results:
                    for key, path, _ in results:
                        if key == tombstone_key:
                            self.__handle_tombstone(path)
                        elif key == heartbeat_key:
                            self.__handle_heartbeat(path)
            except Exception as e:
                print(f"Error reading heartbeat/tombstone: {e}")

            time.sleep(self.monitor_frequency)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
if __name__ == "__main__":
    # Command-line entry point. It watches the heartbeat/tombstone markers of
    # the task identified by $MF_PATHSPEC / $MF_ATTEMPT and terminates
    # <main_pid> when HeartbeatStore.monitor() decides to.
    from metaflow.plugins import DATASTORES
    from metaflow.metaflow_config import NVIDIA_HEARTBEAT_THRESHOLD

    if len(sys.argv) != 4:
        print("Usage: heartbeat_store.py <main_pid> <datastore_type> <folder_name>")
        sys.exit(1)
    main_pid_arg, ds_type, heartbeat_folder = sys.argv[1:]

    supported_types = ("azure", "gs", "s3")
    if ds_type not in supported_types:
        print(f"Datastore unsupported for type: {ds_type}")
        sys.exit(1)

    datastore_cls = [impl for impl in DATASTORES if impl.TYPE == ds_type][0]
    sysroot = datastore_cls.get_datastore_root_from_config(
        lambda *args, **kwargs: None
    )
    if sysroot is None:
        raise MetaflowException(
            msg=f"METAFLOW_DATASTORE_SYSROOT_{ds_type.upper()} must be set!"
        )

    backend = datastore_cls(sysroot)
    prefix = "%s/%s" % (os.getenv("MF_PATHSPEC"), os.getenv("MF_ATTEMPT"))

    watcher = HeartbeatStore(
        heartbeat_prefix=prefix,
        main_pid=int(main_pid_arg),
        storage_backend=backend,
        max_missed_heartbeats=int(NVIDIA_HEARTBEAT_THRESHOLD),
    )
    watcher.monitor(folder_name=heartbeat_folder)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
|
-
import
|
|
3
|
+
import threading
|
|
4
|
+
from urllib.parse import urlparse
|
|
4
5
|
from urllib.request import HTTPError, Request, URLError, urlopen
|
|
5
6
|
|
|
6
7
|
from metaflow import util
|
|
@@ -14,6 +15,7 @@ from metaflow.mflog import (
|
|
|
14
15
|
)
|
|
15
16
|
import requests
|
|
16
17
|
from metaflow.metaflow_config_funcs import init_config
|
|
18
|
+
from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import HeartbeatStore
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
class NvcfException(MetaflowException):
|
|
@@ -58,10 +60,11 @@ class Nvcf(object):
|
|
|
58
60
|
code_package_url, code_package_ds
|
|
59
61
|
)
|
|
60
62
|
init_expr = " && ".join(init_cmds)
|
|
63
|
+
heartbeat_expr = f'python -m metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store "$MAIN_PID" {code_package_ds} nvcf_heartbeats 1>> $MFLOG_STDOUT 2>> $MFLOG_STDERR'
|
|
61
64
|
step_expr = bash_capture_logs(
|
|
62
65
|
" && ".join(
|
|
63
66
|
self.environment.bootstrap_commands(step_name, code_package_ds)
|
|
64
|
-
+ [step_cli]
|
|
67
|
+
+ [step_cli + " & MAIN_PID=$!; " + heartbeat_expr]
|
|
65
68
|
)
|
|
66
69
|
)
|
|
67
70
|
|
|
@@ -87,7 +90,9 @@ class Nvcf(object):
|
|
|
87
90
|
'${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
|
|
88
91
|
% cmd_str
|
|
89
92
|
)
|
|
90
|
-
self.job = Job(
|
|
93
|
+
self.job = Job(
|
|
94
|
+
'bash -c "%s"' % cmd_str, env, task_spec, self.datastore._storage_impl
|
|
95
|
+
)
|
|
91
96
|
self.job.submit()
|
|
92
97
|
|
|
93
98
|
def wait(self, stdout_location, stderr_location, echo=None):
|
|
@@ -137,8 +142,7 @@ result_endpoint = f"{nvcf_url}/v2/nvcf/pexec/status"
|
|
|
137
142
|
|
|
138
143
|
|
|
139
144
|
class Job(object):
|
|
140
|
-
def __init__(self, command, env):
|
|
141
|
-
|
|
145
|
+
def __init__(self, command, env, task_spec, backend):
|
|
142
146
|
self._payload = {
|
|
143
147
|
"command": command,
|
|
144
148
|
"env": {k: v for k, v in env.items() if v is not None},
|
|
@@ -170,6 +174,27 @@ class Job(object):
|
|
|
170
174
|
if f["model_key"] == "metaflow_task_executor":
|
|
171
175
|
self._function_id = f["id"]
|
|
172
176
|
|
|
177
|
+
flow_name = task_spec.get("flow_name")
|
|
178
|
+
run_id = task_spec.get("run_id")
|
|
179
|
+
step_name = task_spec.get("step_name")
|
|
180
|
+
task_id = task_spec.get("task_id")
|
|
181
|
+
retry_count = task_spec.get("retry_count")
|
|
182
|
+
|
|
183
|
+
heartbeat_prefix = "/".join(
|
|
184
|
+
(flow_name, str(run_id), step_name, str(task_id), str(retry_count))
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
store = HeartbeatStore(
|
|
188
|
+
heartbeat_prefix=heartbeat_prefix,
|
|
189
|
+
main_pid=None,
|
|
190
|
+
storage_backend=backend,
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
self.heartbeat_thread = threading.Thread(
|
|
194
|
+
target=store.emit_heartbeat, args=("nvcf_heartbeats",), daemon=True
|
|
195
|
+
)
|
|
196
|
+
self.heartbeat_thread.start()
|
|
197
|
+
|
|
173
198
|
def submit(self):
|
|
174
199
|
try:
|
|
175
200
|
headers = {
|
|
@@ -224,7 +249,6 @@ class Job(object):
|
|
|
224
249
|
|
|
225
250
|
def _poll(self):
|
|
226
251
|
try:
|
|
227
|
-
invocation_id = self._invocation_id
|
|
228
252
|
headers = {
|
|
229
253
|
"Authorization": f"Bearer {self._ngc_api_key}",
|
|
230
254
|
"Content-Type": "application/json",
|
|
@@ -4,7 +4,7 @@ import sys
|
|
|
4
4
|
import time
|
|
5
5
|
import traceback
|
|
6
6
|
|
|
7
|
-
from metaflow import util
|
|
7
|
+
from metaflow import util, Run
|
|
8
8
|
from metaflow._vendor import click
|
|
9
9
|
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
|
|
10
10
|
from metaflow.metadata.util import sync_local_metadata_from_datastore
|
|
@@ -27,6 +27,7 @@ from metaflow.metaflow_config import (
|
|
|
27
27
|
CARD_GSROOT,
|
|
28
28
|
KUBERNETES_SANDBOX_INIT_SCRIPT,
|
|
29
29
|
OTEL_ENDPOINT,
|
|
30
|
+
NVIDIA_HEARTBEAT_THRESHOLD,
|
|
30
31
|
)
|
|
31
32
|
from metaflow.mflog import TASK_LOG_SOURCE
|
|
32
33
|
|
|
@@ -43,6 +44,65 @@ def nvcf():
|
|
|
43
44
|
pass
|
|
44
45
|
|
|
45
46
|
|
|
47
|
+
@nvcf.command(help="List steps / tasks running as an NVCF job.")
@click.argument("run-id")
@click.pass_context
def list(ctx, run_id):
    """Print every unfinished task of ``run_id`` that carries NVCF metadata.

    Every task of every step is inspected. A foreach step has many tasks, and
    ``step.task`` alone would only ever report one of them.
    """
    # NOTE(review): this function name shadows the builtin ``list`` for the
    # rest of this module (click derives the "list" sub-command name from
    # it); avoid calling ``list(...)`` below this definition.
    flow_name = ctx.obj.flow.name
    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
    running_invocations = []
    for each_step in run_obj:
        for task in each_step:
            if task.finished:
                continue
            # metadata_dict is a computed property; read it once per task.
            metadata = task.metadata_dict
            if "nvcf-function-id" not in metadata:
                continue
            # Distinct names so the `run_id` argument is not shadowed.
            t_flow, t_run, t_step, t_id = task.pathspec.split("/")
            running_invocations.append(
                f"Flow Name: {t_flow}, Run ID: {t_run}, Step Name: {t_step}, Task ID: {t_id}, Retry Count: {metadata.get('attempt')}"
            )

    if running_invocations:
        for each_invocation in running_invocations:
            print(each_invocation)
    else:
        print("No running NVCF invocations for Run ID: %s" % run_id)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@nvcf.command(help="Kill steps / tasks running as an NVCF job.")
@click.argument("run-id")
@click.pass_context
def kill(ctx, run_id):
    """Write a tombstone for every unfinished NVCF task of ``run_id``.

    The remote heartbeat monitor terminates a task once it sees the
    tombstone. The "no running invocations" message is printed once, and
    only when nothing was found. Previously the ``else`` was bound to the
    ``if`` inside the loop, so the message was printed for every finished
    step, even when other tasks were being killed. All tasks of a step are
    inspected, so foreach tasks are covered too.
    """
    from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import (
        HeartbeatStore,
    )

    flow_name = ctx.obj.flow.name
    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)

    storage_backend = None
    found_any = False
    for each_step in run_obj:
        for task in each_step:
            if task.finished:
                continue
            # metadata_dict is a computed property; read it once per task.
            metadata = task.metadata_dict
            if "nvcf-function-id" not in metadata:
                continue

            heartbeat_prefix = "{task_pathspec}/{attempt}".format(
                task_pathspec=task.pathspec, attempt=metadata.get("attempt")
            )
            if storage_backend is None:
                # Built lazily and shared by all tombstones of this run.
                datastore_root = ctx.obj.datastore_impl.datastore_root
                storage_backend = ctx.obj.datastore_impl(datastore_root)
            store = HeartbeatStore(
                heartbeat_prefix=heartbeat_prefix,
                main_pid=None,
                storage_backend=storage_backend,
            )
            store.emit_tombstone(folder_name="nvcf_heartbeats")
            found_any = True

    if not found_any:
        print("No running NVCF invocations for Run ID: %s" % run_id)
|
|
104
|
+
|
|
105
|
+
|
|
46
106
|
@nvcf.command(
|
|
47
107
|
help="Execute a single task using NVCF. This command calls the "
|
|
48
108
|
"top-level step command inside a NVCF job with the given options. "
|
|
@@ -139,6 +199,7 @@ def step(ctx, step_name, code_package_sha, code_package_url, **kwargs):
|
|
|
139
199
|
"METAFLOW_CARD_GSROOT": CARD_GSROOT,
|
|
140
200
|
"METAFLOW_INIT_SCRIPT": KUBERNETES_SANDBOX_INIT_SCRIPT,
|
|
141
201
|
"METAFLOW_OTEL_ENDPOINT": OTEL_ENDPOINT,
|
|
202
|
+
"METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD": str(NVIDIA_HEARTBEAT_THRESHOLD),
|
|
142
203
|
}
|
|
143
204
|
|
|
144
205
|
env_deco = [deco for deco in node.decorators if deco.name == "environment"]
|
|
@@ -5,6 +5,49 @@
|
|
|
5
5
|
__version__ = "v1"
|
|
6
6
|
__mf_extensions__ = "ob"
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
8
|
+
|
|
9
|
+
# Must match the signature of metaflow.plugins.aws.aws_client.get_aws_client.
# "Userland" code inside tasks reaches this wrapper, whereas Metaflow internals
# call metaflow.plugins.aws.aws_client.get_aws_client directly.
#
# The only difference from the wrapped function: when the caller gives no
# role to assume, the CSPR placeholder is passed instead, so the CSPR role is
# used whenever one is set.
def get_aws_client(
    module, with_error=False, role_arn=None, session_vars=None, client_params=None
):
    """Create an AWS client for ``module``, defaulting to the CSPR role if set."""
    import metaflow.plugins.aws.aws_client as metaflow_aws_client

    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET

    effective_role = role_arn if role_arn else USE_CSPR_ROLE_ARN_IF_SET
    return metaflow_aws_client.get_aws_client(
        module,
        with_error=with_error,
        role_arn=effective_role,
        session_vars=session_vars,
        client_params=client_params,
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# This should match the signature of metaflow.plugins.datatools.s3.S3.
#
# This assumes that "userland" code inside tasks will call this, while Metaflow
# internals will call metaflow.plugins.datatools.s3.S3 directly.
#
# This wrapper will make S3() use the CSPR role if it's set and the user didn't
# provide a role to assume explicitly.
def S3(*args, **kwargs):
    """Construct ``metaflow.plugins.datatools.s3.S3``, defaulting ``role`` to CSPR.

    All arguments are forwarded unchanged, except that a missing or ``None``
    ``role`` keyword is replaced by the CSPR placeholder role.
    """
    import metaflow.plugins.datatools.s3
    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET

    # A missing key and an explicit role=None are treated the same.
    if kwargs.get("role") is None:
        kwargs["role"] = USE_CSPR_ROLE_ARN_IF_SET

    return metaflow.plugins.datatools.s3.S3(*args, **kwargs)
|
|
51
|
+
|
|
52
|
+
|
|
10
53
|
from .. import profilers
|
{ob_metaflow_extensions-1.1.83.dist-info → ob_metaflow_extensions-1.1.86.dist-info}/METADATA
RENAMED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ob-metaflow-extensions
|
|
3
|
-
Version: 1.1.
|
|
3
|
+
Version: 1.1.86
|
|
4
4
|
Summary: Outerbounds Platform Extensions for Metaflow
|
|
5
5
|
Author: Outerbounds, Inc.
|
|
6
6
|
License: Commercial
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
Requires-Dist: boto3
|
|
9
9
|
Requires-Dist: kubernetes
|
|
10
|
-
Requires-Dist: ob-metaflow (==2.12.18.
|
|
10
|
+
Requires-Dist: ob-metaflow (==2.12.18.2)
|
|
11
11
|
|
|
12
12
|
# Outerbounds platform package
|
|
13
13
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
metaflow_extensions/outerbounds/__init__.py,sha256=TRGvIUMjkfneWtYUFSWoubu_Kf2ekAL4WLbV3IxOj9k,499
|
|
2
2
|
metaflow_extensions/outerbounds/remote_config.py,sha256=Zpfpjgz68_ZgxlXezjzlsDLo4840rkWuZgwDB_5H57U,4059
|
|
3
|
-
metaflow_extensions/outerbounds/config/__init__.py,sha256=
|
|
4
|
-
metaflow_extensions/outerbounds/plugins/__init__.py,sha256=
|
|
3
|
+
metaflow_extensions/outerbounds/config/__init__.py,sha256=JsQGRuGFz28fQWjUvxUgR8EKBLGRdLUIk_buPLJplJY,1225
|
|
4
|
+
metaflow_extensions/outerbounds/plugins/__init__.py,sha256=g07Xj7YifwLfTTa94cyf3kLebHezeLWNG9HoGnpDyMo,12519
|
|
5
5
|
metaflow_extensions/outerbounds/plugins/auth_server.py,sha256=1v2GBqoMBxp5E7Lejz139w-jxJtPnLDvvHXP0HhEIHI,2361
|
|
6
6
|
metaflow_extensions/outerbounds/plugins/perimeters.py,sha256=QXh3SFP7GQbS-RAIxUOPbhPzQ7KDFVxZkTdKqFKgXjI,2697
|
|
7
7
|
metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -14,8 +14,9 @@ metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py,sha256=g
|
|
|
14
14
|
metaflow_extensions/outerbounds/plugins/nim/__init__.py,sha256=GVnvSTjqYVj5oG2yh8KJFt7iZ33cEadDD5HbdmC9hJ0,1457
|
|
15
15
|
metaflow_extensions/outerbounds/plugins/nim/nim_manager.py,sha256=SWieODDxtIaeZwdMYtObDi57Kjyfw2DUuE6pJtU750w,9206
|
|
16
16
|
metaflow_extensions/outerbounds/plugins/nvcf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
|
-
metaflow_extensions/outerbounds/plugins/nvcf/
|
|
18
|
-
metaflow_extensions/outerbounds/plugins/nvcf/
|
|
17
|
+
metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py,sha256=wIlPBzsTszkHpftK1x7zBgaQ_7d3tNURqh4ez71Ra7A,5416
|
|
18
|
+
metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py,sha256=NIt1kJHuYpnCF7n73A90ZITWsk5QWtsbiHfzvdVjgqk,8997
|
|
19
|
+
metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py,sha256=AhlFa8UVzVKQRUilQwtwp7ZVWJ0WNitPU0vp29i7WuY,9545
|
|
19
20
|
metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py,sha256=0xNA4aRTPJ4SKpRIFKZzlL9a7lf367KGTrVWVXd-uGE,6052
|
|
20
21
|
metaflow_extensions/outerbounds/plugins/snowpark/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
21
22
|
metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py,sha256=vzgpVLCKvHjzHNfJvmH0jcxefYNsVggw_vof_y_U_a8,10643
|
|
@@ -28,11 +29,11 @@ metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py,sha256
|
|
|
28
29
|
metaflow_extensions/outerbounds/profilers/__init__.py,sha256=wa_jhnCBr82TBxoS0e8b6_6sLyZX0fdHicuGJZNTqKw,29
|
|
29
30
|
metaflow_extensions/outerbounds/profilers/gpu.py,sha256=a5YZAepujuP0uDqG9UpXBlZS3wjUt4Yv8CjybXqeT2c,24342
|
|
30
31
|
metaflow_extensions/outerbounds/toplevel/__init__.py,sha256=qWUJSv_r5hXJ7jV_On4nEasKIfUCm6_UjkjXWA_A1Ts,90
|
|
31
|
-
metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,sha256=
|
|
32
|
+
metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,sha256=Zq3OuL1bOod8KJra-Zk8B3gNhSHoWEGteM9T7g0pp6E,1881
|
|
32
33
|
metaflow_extensions/outerbounds/toplevel/plugins/azure/__init__.py,sha256=WUuhz2YQfI4fz7nIcipwwWq781eaoHEk7n4GAn1npDg,63
|
|
33
34
|
metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py,sha256=BbZiaH3uILlEZ6ntBLKeNyqn3If8nIXZFq_Apd7Dhco,70
|
|
34
35
|
metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8m7rgF4xgWBZFuY3GDP5n1T0ktjRpGJLHA,69
|
|
35
|
-
ob_metaflow_extensions-1.1.
|
|
36
|
-
ob_metaflow_extensions-1.1.
|
|
37
|
-
ob_metaflow_extensions-1.1.
|
|
38
|
-
ob_metaflow_extensions-1.1.
|
|
36
|
+
ob_metaflow_extensions-1.1.86.dist-info/METADATA,sha256=LKi6v-1FMvAWE2u9eW2PKoVBi1xl6Wor1TjhwYVzPNk,520
|
|
37
|
+
ob_metaflow_extensions-1.1.86.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
|
38
|
+
ob_metaflow_extensions-1.1.86.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
|
|
39
|
+
ob_metaflow_extensions-1.1.86.dist-info/RECORD,,
|
|
File without changes
|
{ob_metaflow_extensions-1.1.83.dist-info → ob_metaflow_extensions-1.1.86.dist-info}/top_level.txt
RENAMED
|
File without changes
|