ob-metaflow-extensions 1.1.83__py2.py3-none-any.whl → 1.1.86__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

@@ -14,6 +14,13 @@ DEFAULT_GCP_CLIENT_PROVIDER = "obp"
14
14
  FAST_BAKERY_URL = from_conf("FAST_BAKERY_URL", None)
15
15
 
16
16
 
17
+ ###
18
+ # NVCF configuration
19
+ ###
20
+ # Maximum number of consecutive heartbeats that can be missed.
21
+ NVIDIA_HEARTBEAT_THRESHOLD = from_conf("NVIDIA_HEARTBEAT_THRESHOLD", "3")
22
+
23
+
17
24
  ###
18
25
  # Snowpark configuration
19
26
  ###
@@ -32,6 +32,126 @@ def hide_access_keys(*args, **kwds):
32
32
  os.environ["AWS_SESSION_TOKEN"] = AWS_SESSION_TOKEN
33
33
 
34
34
 
35
+ # This is a special placeholder value that can be passed as role_arn to
36
+ # get_boto3_session() which makes it use the CSPR role, if it's set.
37
+ USE_CSPR_ROLE_ARN_IF_SET = "__cspr__"
38
+
39
+
40
def _write_file_atomically(path, contents):
    """Write ``contents`` to ``path`` through a temp file and a rename.

    The temp file is created in the destination's own directory so the
    final ``os.rename`` never has to cross a filesystem boundary. A reader
    therefore sees either the previous file or the complete new one, never
    a file that was truncated but not yet written.
    """
    with tempfile.NamedTemporaryFile(
        "w", delete=False, dir=os.path.dirname(path)
    ) as f:
        f.write(contents)
        tmp_path = f.name
    os.rename(tmp_path, path)


def get_boto3_session(role_arn=None, session_vars=None):
    """Build a boto3 session from the token issued by the auth server.

    The token returned by ``get_token("/generate/aws")`` is written to a
    per-user file under ``/tmp`` and the AWS SDK is pointed at it through
    environment variables. If the token info carries ``cspr_role_arn``, an
    AWS config file with two profiles is generated instead: ``task``
    (exchanges the token for the task role) and ``cspr`` (assumes the CSPR
    role using the ``task`` profile as its source).

    Args:
        role_arn: ``None`` to use the task role; ``USE_CSPR_ROLE_ARN_IF_SET``
            to use the CSPR role when one is set (task role otherwise); any
            other value is a role to assume with the task role credentials.
        session_vars: forwarded to ``botocore.session.Session`` only when an
            explicit ``role_arn`` is assumed.

    Returns:
        A ``boto3.session.Session``.

    Side effects:
        Writes files under ``/tmp`` and mutates ``os.environ``
        (``AWS_CONFIG_FILE``/``AWS_DEFAULT_PROFILE`` or
        ``AWS_WEB_IDENTITY_TOKEN_FILE``/``AWS_ROLE_ARN``, plus
        ``AWS_STS_REGIONAL_ENDPOINTS`` and optionally ``AWS_DEFAULT_REGION``).
    """
    import boto3
    import botocore
    from metaflow_extensions.outerbounds.plugins.auth_server import get_token

    from hashlib import sha256
    from metaflow.util import get_username

    user = get_username()

    token_info = get_token("/generate/aws")

    # File names are derived from the user name so this works with multiple
    # users on the same machine. The name is hashed so special characters
    # never end up in a path; the files are not user-facing anyway.
    user_suffix = sha256(user.encode("utf-8")).hexdigest()[:16]
    token_file = "/tmp/obp_token." + user_suffix

    # Written atomically, and from a temp file that lives next to the
    # destination: a temp file in the default temp dir could sit on another
    # filesystem, in which case the rename would fail.
    _write_file_atomically(token_file, token_info["token"])

    cspr_role = token_info.get("cspr_role_arn") or None

    if cspr_role:
        # If CSPR role is set, we set it as the default role to assume
        # for the AWS SDK. We do this by writing an AWS config file
        # with two profiles. One to get credentials for the task role
        # in exchange for the OIDC token, and second to assume the
        # CSPR role using the task role credentials.
        import configparser
        from io import StringIO

        aws_config = configparser.ConfigParser()

        # Task role profile
        aws_config["profile task"] = {
            "role_arn": token_info["role_arn"],
            "web_identity_token_file": token_file,
        }

        # CSPR role profile (selected through AWS_DEFAULT_PROFILE below)
        aws_config["profile cspr"] = {
            "role_arn": cspr_role,
            "source_profile": "task",
        }

        aws_config_string = StringIO()
        aws_config.write(aws_config_string)
        aws_config_file = "/tmp/aws_config." + user_suffix
        _write_file_atomically(aws_config_file, aws_config_string.getvalue())
        os.environ["AWS_CONFIG_FILE"] = aws_config_file
        os.environ["AWS_DEFAULT_PROFILE"] = "cspr"
    else:
        os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
        os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]

    # Enable regional STS endpoints. This is the new recommended way
    # by AWS [1] and is the more performant way.
    # [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
    os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
    if token_info.get("region"):
        os.environ["AWS_DEFAULT_REGION"] = token_info["region"]

    with hide_access_keys():
        if cspr_role:
            # The generated AWS config is picked up here because
            # AWS_CONFIG_FILE was set above.
            if role_arn == USE_CSPR_ROLE_ARN_IF_SET:
                # The generated config defines only the "task" and "cspr"
                # profiles, so the CSPR role has to be requested by its
                # real profile name rather than "default".
                session = boto3.session.Session(profile_name="cspr")
            else:
                session = boto3.session.Session(profile_name="task")
        else:
            # Not using AWS config, just AWS_WEB_IDENTITY_TOKEN_FILE + AWS_ROLE_ARN
            session = boto3.session.Session()

        if role_arn and role_arn != USE_CSPR_ROLE_ARN_IF_SET:
            # If the user provided a role_arn, we assume that role
            # using the task role credentials. CSPR role is not used.
            fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
                client_creator=session._session.create_client,
                source_credentials=session._session.get_credentials(),
                role_arn=role_arn,
                extra_args={},
            )
            creds = botocore.credentials.DeferredRefreshableCredentials(
                method="assume-role", refresh_using=fetcher.fetch_credentials
            )
            botocore_session = botocore.session.Session(session_vars=session_vars)
            botocore_session._credentials = creds
            return boto3.session.Session(botocore_session=botocore_session)

        # No explicit role requested: hand back the session selected above
        # (CSPR profile when asked for and set, task role otherwise).
        return session
153
+
154
+
35
155
  class ObpAuthProvider(object):
36
156
  name = "obp"
37
157
 
@@ -42,67 +162,13 @@ class ObpAuthProvider(object):
42
162
  if client_params is None:
43
163
  client_params = {}
44
164
 
45
- import boto3
46
- import botocore
47
165
  from botocore.exceptions import ClientError
48
- from metaflow_extensions.outerbounds.plugins.auth_server import get_token
49
-
50
- from hashlib import sha256
51
- from metaflow.util import get_username
52
-
53
- user = get_username()
54
-
55
- token_info = get_token("/generate/aws")
56
166
 
57
- # Write token to a file. The file name is derived from the user name
58
- # so it works with multiple users on the same machine.
59
- #
60
- # We hash the user name so we don't have to deal with special characters
61
- # in the file name and the file name is not exposed to the user
62
- # anyways, so it doesn't matter that its a little ugly.
63
- token_file = "/tmp/obp_token." + sha256(user.encode("utf-8")).hexdigest()[:16]
64
-
65
- # Write to a temp file then rename to avoid a situation when someone
66
- # tries to read the file after it was open for writing (and truncated)
67
- # but before the token was written to it.
68
- with tempfile.NamedTemporaryFile("w", delete=False) as f:
69
- f.write(token_info["token"])
70
- tmp_token_file = f.name
71
- os.rename(tmp_token_file, token_file)
72
-
73
- os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
74
- os.environ["AWS_ROLE_ARN"] = token_info["role_arn"]
75
-
76
- # Enable regional STS endpoints. This is the new recommended way
77
- # by AWS [1] and is the more performant way.
78
- # [1] https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
79
- os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
80
- if token_info.get("region"):
81
- os.environ["AWS_DEFAULT_REGION"] = token_info["region"]
82
-
83
- with hide_access_keys():
84
- if role_arn:
85
- session = boto3.session.Session()
86
- fetcher = botocore.credentials.AssumeRoleCredentialFetcher(
87
- client_creator=session._session.create_client,
88
- source_credentials=session._session.get_credentials(),
89
- role_arn=role_arn,
90
- extra_args={},
91
- )
92
- creds = botocore.credentials.DeferredRefreshableCredentials(
93
- method="assume-role", refresh_using=fetcher.fetch_credentials
94
- )
95
- botocore_session = botocore.session.Session(session_vars=session_vars)
96
- botocore_session._credentials = creds
97
- session = boto3.session.Session(botocore_session=botocore_session)
98
- if with_error:
99
- return session.client(module, **client_params), ClientError
100
- else:
101
- return session.client(module, **client_params)
102
- if with_error:
103
- return boto3.client(module, **client_params), ClientError
104
- else:
105
- return boto3.client(module, **client_params)
167
+ session = get_boto3_session(role_arn, session_vars)
168
+ if with_error:
169
+ return session.client(module, **client_params), ClientError
170
+ else:
171
+ return session.client(module, **client_params)
106
172
 
107
173
 
108
174
  AWS_CLIENT_PROVIDERS_DESC = [("obp", ".ObpAuthProvider")]
@@ -0,0 +1,157 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import signal
5
+ from io import BytesIO
6
+ from datetime import datetime, timezone
7
+
8
+ from metaflow.exception import MetaflowException
9
+
10
+
11
class HeartbeatStore(object):
    """Exchange heartbeat and tombstone markers through a storage backend.

    One side periodically writes a timestamp under
    ``<prefix>/heartbeat`` (``emit_heartbeat``) or a kill marker under
    ``<prefix>/tombstone`` (``emit_tombstone``). The other side polls both
    keys (``monitor``) and sends ``SIGTERM`` to ``main_pid`` when a
    tombstone shows up or too many consecutive heartbeats are missed.

    Args:
        heartbeat_prefix: key prefix shared by the heartbeat and tombstone.
        main_pid: pid of the process that ``monitor`` watches and terminates.
        storage_backend: object providing ``save_bytes`` and ``load_bytes``.
        emit_frequency: seconds between two heartbeat writes.
        missed_heartbeat_timeout: heartbeat age, in seconds, beyond which a
            poll counts as a miss.
        monitor_frequency: seconds between two polls in ``monitor``.
        max_missed_heartbeats: number of consecutive misses tolerated; the
            task is terminated once this number is exceeded.
    """

    def __init__(
        self,
        heartbeat_prefix,
        main_pid=None,
        storage_backend=None,
        emit_frequency=30,
        missed_heartbeat_timeout=60,
        monitor_frequency=15,
        max_missed_heartbeats=3,
    ) -> None:
        self.heartbeat_prefix = heartbeat_prefix
        self.main_pid = main_pid
        self.storage_backend = storage_backend
        self.emit_frequency = emit_frequency
        self.monitor_frequency = monitor_frequency
        self.missed_heartbeat_timeout = missed_heartbeat_timeout
        self.max_missed_heartbeats = max_missed_heartbeats
        # Consecutive misses seen so far; reset by every fresh heartbeat.
        self.missed_heartbeats = 0

    def _key(self, name, folder_name=None):
        """Return the storage key for ``name``, optionally under ``folder_name``."""
        key = f"{self.heartbeat_prefix}/{name}"
        if folder_name:
            key = f"{folder_name}/{key}"
        return key

    def emit_heartbeat(self, folder_name=None):
        """Write the current UTC epoch to the heartbeat key, forever.

        Sleeps ``emit_frequency`` seconds between writes. Any write error is
        printed and ends the caller through ``sys.exit(1)``.
        """
        heartbeat_key = self._key("heartbeat", folder_name)

        while True:
            try:
                epoch_string = str(datetime.now(timezone.utc).timestamp()).encode(
                    "utf-8"
                )
                self.storage_backend.save_bytes(
                    [(heartbeat_key, BytesIO(epoch_string))], overwrite=True
                )
            except Exception as e:
                print(f"Error writing heartbeat: {e}")
                sys.exit(1)

            time.sleep(self.emit_frequency)

    def emit_tombstone(self, folder_name=None):
        """Write the tombstone marker; print the error and exit(1) on failure."""
        tombstone_key = self._key("tombstone", folder_name)

        tombstone_string = "tombstone".encode("utf-8")
        try:
            self.storage_backend.save_bytes(
                [(tombstone_key, BytesIO(tombstone_string))], overwrite=True
            )
        except Exception as e:
            print(f"Error writing tombstone: {e}")
            sys.exit(1)

    def __handle_tombstone(self, path):
        """Terminate the main process if the file at ``path`` is a tombstone."""
        if path is None:
            return
        with open(path) as f:
            contents = f.read()
        if "tombstone" in contents:
            print("[Outerbounds] Tombstone detected. Terminating the task..")
            os.kill(self.main_pid, signal.SIGTERM)
            sys.exit(1)

    def __handle_heartbeat(self, path):
        """Update the miss counter from the heartbeat file at ``path``.

        A missing file, an unparsable timestamp, or a timestamp older than
        ``missed_heartbeat_timeout`` counts as one miss; a fresh timestamp
        resets the counter. Once the counter exceeds
        ``max_missed_heartbeats`` the main process gets ``SIGTERM`` and this
        process exits with status 1.
        """
        if path is None:
            self.missed_heartbeats += 1
        else:
            with open(path) as f:
                contents = f.read()
            try:
                last_beat = float(contents)
            except ValueError:
                # An empty or corrupt heartbeat carries no liveness
                # information, so count it as a miss instead of crashing
                # the monitor.
                last_beat = None
            current_timestamp = datetime.now(timezone.utc).timestamp()
            if (
                last_beat is None
                or current_timestamp - last_beat > self.missed_heartbeat_timeout
            ):
                self.missed_heartbeats += 1
            else:
                self.missed_heartbeats = 0

        if self.missed_heartbeats > self.max_missed_heartbeats:
            # Report the real count: the check is strict, so this is one
            # more than the configured threshold.
            print(
                f"[Outerbounds] Missed {self.missed_heartbeats} consecutive heartbeats. Terminating the task.."
            )
            os.kill(self.main_pid, signal.SIGTERM)
            sys.exit(1)

    def is_main_process_running(self):
        """Return True while a process with pid ``main_pid`` exists."""
        try:
            # Signal 0 performs the existence/permission check only.
            os.kill(self.main_pid, 0)
        except ProcessLookupError:
            return False
        except PermissionError:
            # The process exists but is owned by someone else.
            return True
        return True

    def monitor(self, folder_name=None):
        """Poll heartbeat and tombstone keys until the main process is gone.

        Sleeps ``monitor_frequency`` seconds between polls.
        """
        heartbeat_key = self._key("heartbeat", folder_name)
        tombstone_key = self._key("tombstone", folder_name)

        while self.is_main_process_running():
            with self.storage_backend.load_bytes(
                [heartbeat_key, tombstone_key]
            ) as results:
                for key, path, _ in results:
                    if key == tombstone_key:
                        self.__handle_tombstone(path)
                    elif key == heartbeat_key:
                        self.__handle_heartbeat(path)

            time.sleep(self.monitor_frequency)
120
+
121
+
122
if __name__ == "__main__":
    # Entry point of the monitor process:
    #   heartbeat_store.py <main_pid> <datastore_type> <folder_name>
    # Builds a HeartbeatStore on the configured datastore and polls it until
    # the process with pid <main_pid> is gone.
    from metaflow.plugins import DATASTORES
    from metaflow.metaflow_config import NVIDIA_HEARTBEAT_THRESHOLD

    if len(sys.argv) != 4:
        print("Usage: heartbeat_store.py <main_pid> <datastore_type> <folder_name>")
        sys.exit(1)
    main_pid = sys.argv[1]
    datastore_type = sys.argv[2]
    folder_name = sys.argv[3]

    if datastore_type not in ("azure", "gs", "s3"):
        print(f"Datastore unsupported for type: {datastore_type}")
        sys.exit(1)

    # Pick the datastore implementation registered for the requested type.
    datastores = [d for d in DATASTORES if d.TYPE == datastore_type]
    datastore_impl = datastores[0]
    datastore_sysroot = datastore_impl.get_datastore_root_from_config(
        lambda *args, **kwargs: None
    )
    if datastore_sysroot is None:
        raise MetaflowException(
            msg="METAFLOW_DATASTORE_SYSROOT_{datastore_type} must be set!".format(
                datastore_type=datastore_type.upper()
            )
        )

    storage = datastore_impl(datastore_sysroot)

    # The prefix is built from the MF_PATHSPEC and MF_ATTEMPT environment
    # variables.
    heartbeat_prefix = f"{os.getenv('MF_PATHSPEC')}/{os.getenv('MF_ATTEMPT')}"

    store = HeartbeatStore(
        heartbeat_prefix=heartbeat_prefix,
        main_pid=int(main_pid),
        storage_backend=storage,
        max_missed_heartbeats=int(NVIDIA_HEARTBEAT_THRESHOLD),
    )

    store.monitor(folder_name=folder_name)
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import os
3
- import time
3
+ import threading
4
+ from urllib.parse import urlparse
4
5
  from urllib.request import HTTPError, Request, URLError, urlopen
5
6
 
6
7
  from metaflow import util
@@ -14,6 +15,7 @@ from metaflow.mflog import (
14
15
  )
15
16
  import requests
16
17
  from metaflow.metaflow_config_funcs import init_config
18
+ from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import HeartbeatStore
17
19
 
18
20
 
19
21
  class NvcfException(MetaflowException):
@@ -58,10 +60,11 @@ class Nvcf(object):
58
60
  code_package_url, code_package_ds
59
61
  )
60
62
  init_expr = " && ".join(init_cmds)
63
+ heartbeat_expr = f'python -m metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store "$MAIN_PID" {code_package_ds} nvcf_heartbeats 1>> $MFLOG_STDOUT 2>> $MFLOG_STDERR'
61
64
  step_expr = bash_capture_logs(
62
65
  " && ".join(
63
66
  self.environment.bootstrap_commands(step_name, code_package_ds)
64
- + [step_cli]
67
+ + [step_cli + " & MAIN_PID=$!; " + heartbeat_expr]
65
68
  )
66
69
  )
67
70
 
@@ -87,7 +90,9 @@ class Nvcf(object):
87
90
  '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
88
91
  % cmd_str
89
92
  )
90
- self.job = Job('bash -c "%s"' % cmd_str, env)
93
+ self.job = Job(
94
+ 'bash -c "%s"' % cmd_str, env, task_spec, self.datastore._storage_impl
95
+ )
91
96
  self.job.submit()
92
97
 
93
98
  def wait(self, stdout_location, stderr_location, echo=None):
@@ -137,8 +142,7 @@ result_endpoint = f"{nvcf_url}/v2/nvcf/pexec/status"
137
142
 
138
143
 
139
144
  class Job(object):
140
- def __init__(self, command, env):
141
-
145
+ def __init__(self, command, env, task_spec, backend):
142
146
  self._payload = {
143
147
  "command": command,
144
148
  "env": {k: v for k, v in env.items() if v is not None},
@@ -170,6 +174,27 @@ class Job(object):
170
174
  if f["model_key"] == "metaflow_task_executor":
171
175
  self._function_id = f["id"]
172
176
 
177
+ flow_name = task_spec.get("flow_name")
178
+ run_id = task_spec.get("run_id")
179
+ step_name = task_spec.get("step_name")
180
+ task_id = task_spec.get("task_id")
181
+ retry_count = task_spec.get("retry_count")
182
+
183
+ heartbeat_prefix = "/".join(
184
+ (flow_name, str(run_id), step_name, str(task_id), str(retry_count))
185
+ )
186
+
187
+ store = HeartbeatStore(
188
+ heartbeat_prefix=heartbeat_prefix,
189
+ main_pid=None,
190
+ storage_backend=backend,
191
+ )
192
+
193
+ self.heartbeat_thread = threading.Thread(
194
+ target=store.emit_heartbeat, args=("nvcf_heartbeats",), daemon=True
195
+ )
196
+ self.heartbeat_thread.start()
197
+
173
198
  def submit(self):
174
199
  try:
175
200
  headers = {
@@ -224,7 +249,6 @@ class Job(object):
224
249
 
225
250
  def _poll(self):
226
251
  try:
227
- invocation_id = self._invocation_id
228
252
  headers = {
229
253
  "Authorization": f"Bearer {self._ngc_api_key}",
230
254
  "Content-Type": "application/json",
@@ -4,7 +4,7 @@ import sys
4
4
  import time
5
5
  import traceback
6
6
 
7
- from metaflow import util
7
+ from metaflow import util, Run
8
8
  from metaflow._vendor import click
9
9
  from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
10
10
  from metaflow.metadata.util import sync_local_metadata_from_datastore
@@ -27,6 +27,7 @@ from metaflow.metaflow_config import (
27
27
  CARD_GSROOT,
28
28
  KUBERNETES_SANDBOX_INIT_SCRIPT,
29
29
  OTEL_ENDPOINT,
30
+ NVIDIA_HEARTBEAT_THRESHOLD,
30
31
  )
31
32
  from metaflow.mflog import TASK_LOG_SOURCE
32
33
 
@@ -43,6 +44,65 @@ def nvcf():
43
44
  pass
44
45
 
45
46
 
47
@nvcf.command(help="List steps / tasks running as an NVCF job.")
@click.argument("run-id")
@click.pass_context
def list(ctx, run_id):
    """Print one line per unfinished NVCF task of the given run.

    A task is reported when it is not finished and its metadata contains
    ``nvcf-function-id``. If nothing qualifies, a single "no running
    invocations" line is printed instead.
    """
    # NOTE(review): this function name shadows the built-in ``list`` for the
    # rest of the module; it is kept because it is the visible command name.
    flow_name = ctx.obj.flow.name
    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
    running_invocations = []
    for each_step in run_obj:
        # Resolve the step's task once instead of on every attribute access.
        task = each_step.task
        if task is None or task.finished:
            continue
        metadata = task.metadata_dict
        if "nvcf-function-id" not in metadata:
            continue

        attempt = metadata.get("attempt")
        # Separate names so the command's own arguments are not rebound.
        task_flow, task_run, step_name, task_id = task.pathspec.split("/")
        running_invocations.append(
            f"Flow Name: {task_flow}, Run ID: {task_run}, Step Name: {step_name}, Task ID: {task_id}, Retry Count: {attempt}"
        )

    if running_invocations:
        for each_invocation in running_invocations:
            print(each_invocation)
    else:
        print("No running NVCF invocations for Run ID: %s" % run_id)
71
+
72
+
73
@nvcf.command(help="Kill steps / tasks running as an NVCF job.")
@click.argument("run-id")
@click.pass_context
def kill(ctx, run_id):
    """Write a tombstone for every unfinished NVCF task of the given run.

    A task qualifies when it is not finished and its metadata contains
    ``nvcf-function-id``. The tombstone is written under the
    ``nvcf_heartbeats`` folder with the prefix ``<task pathspec>/<attempt>``.
    If no task qualifies, a single "no running invocations" line is printed.
    """
    from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import (
        HeartbeatStore,
    )

    flow_name = ctx.obj.flow.name
    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)

    # Tracked explicitly so the "nothing to kill" message is printed once,
    # and only when no tombstone was written at all.
    tombstone_written = False
    for each_step in run_obj:
        # Resolve the step's task once instead of on every attribute access.
        task = each_step.task
        if task is None or task.finished:
            continue
        metadata = task.metadata_dict
        if "nvcf-function-id" not in metadata:
            continue

        task_pathspec = task.pathspec
        attempt = metadata.get("attempt")
        heartbeat_prefix = "{task_pathspec}/{attempt}".format(
            task_pathspec=task_pathspec, attempt=attempt
        )

        datastore_root = ctx.obj.datastore_impl.datastore_root
        store = HeartbeatStore(
            heartbeat_prefix=heartbeat_prefix,
            main_pid=None,
            storage_backend=ctx.obj.datastore_impl(datastore_root),
        )
        store.emit_tombstone(folder_name="nvcf_heartbeats")
        tombstone_written = True

    if not tombstone_written:
        print("No running NVCF invocations for Run ID: %s" % run_id)
104
+
105
+
46
106
  @nvcf.command(
47
107
  help="Execute a single task using NVCF. This command calls the "
48
108
  "top-level step command inside a NVCF job with the given options. "
@@ -139,6 +199,7 @@ def step(ctx, step_name, code_package_sha, code_package_url, **kwargs):
139
199
  "METAFLOW_CARD_GSROOT": CARD_GSROOT,
140
200
  "METAFLOW_INIT_SCRIPT": KUBERNETES_SANDBOX_INIT_SCRIPT,
141
201
  "METAFLOW_OTEL_ENDPOINT": OTEL_ENDPOINT,
202
+ "METAFLOW_NVIDIA_HEARTBEAT_THRESHOLD": str(NVIDIA_HEARTBEAT_THRESHOLD),
142
203
  }
143
204
 
144
205
  env_deco = [deco for deco in node.decorators if deco.name == "environment"]
@@ -5,6 +5,49 @@
5
5
  __version__ = "v1"
6
6
  __mf_extensions__ = "ob"
7
7
 
8
- # To support "from metaflow import get_aws_client"
9
- from metaflow.plugins.aws.aws_client import get_aws_client
8
+
9
+ # Must match the signature of metaflow.plugins.aws.aws_client.get_aws_client
10
+ # This function is called by the "userland" code inside tasks. Metaflow internals
11
+ # will call the function in metaflow.plugins.aws.aws_client.get_aws_client directly.
12
+ #
13
+ # Unlike the original function, this wrapper will use the CSPR role if both of the following
14
+ # conditions are met:
15
+ #
16
+ # 1. CSPR is set
17
+ # 2. user didn't provide a role to assume explicitly.
18
+ #
19
def get_aws_client(
    module, with_error=False, role_arn=None, session_vars=None, client_params=None
):
    """Forward to ``metaflow.plugins.aws.aws_client.get_aws_client``.

    All arguments are passed through unchanged, except that a falsy
    ``role_arn`` is replaced by ``USE_CSPR_ROLE_ARN_IF_SET``.
    """
    import metaflow.plugins.aws.aws_client

    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET

    # An explicitly supplied role wins; otherwise pass the CSPR placeholder.
    effective_role_arn = role_arn or USE_CSPR_ROLE_ARN_IF_SET

    return metaflow.plugins.aws.aws_client.get_aws_client(
        module,
        with_error=with_error,
        role_arn=effective_role_arn,
        session_vars=session_vars,
        client_params=client_params,
    )
33
+
34
+
35
+ # This should match the signature of metaflow.plugins.datatools.s3.S3.
36
+ #
37
+ # This assumes that "userland" code inside tasks will call this, while Metaflow
38
+ # internals will call metaflow.plugins.datatools.s3.S3 directly.
39
+ #
40
+ # This wrapper will make S3() use the CSPR role if it's set, and the user didn't provide
41
+ # a role to assume explicitly.
42
def S3(*args, **kwargs):
    """Forward to ``metaflow.plugins.datatools.s3.S3``.

    When ``role`` is absent from the keyword arguments or is ``None``, it is
    set to ``USE_CSPR_ROLE_ARN_IF_SET`` before the call; everything else is
    passed through untouched.
    """
    import metaflow.plugins.datatools.s3
    from metaflow_extensions.outerbounds.plugins import USE_CSPR_ROLE_ARN_IF_SET

    if kwargs.get("role") is None:
        kwargs["role"] = USE_CSPR_ROLE_ARN_IF_SET

    return metaflow.plugins.datatools.s3.S3(*args, **kwargs)
51
+
52
+
10
53
  from .. import profilers
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ob-metaflow-extensions
3
- Version: 1.1.83
3
+ Version: 1.1.86
4
4
  Summary: Outerbounds Platform Extensions for Metaflow
5
5
  Author: Outerbounds, Inc.
6
6
  License: Commercial
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: boto3
9
9
  Requires-Dist: kubernetes
10
- Requires-Dist: ob-metaflow (==2.12.18.1)
10
+ Requires-Dist: ob-metaflow (==2.12.18.2)
11
11
 
12
12
  # Outerbounds platform package
13
13
 
@@ -1,7 +1,7 @@
1
1
  metaflow_extensions/outerbounds/__init__.py,sha256=TRGvIUMjkfneWtYUFSWoubu_Kf2ekAL4WLbV3IxOj9k,499
2
2
  metaflow_extensions/outerbounds/remote_config.py,sha256=Zpfpjgz68_ZgxlXezjzlsDLo4840rkWuZgwDB_5H57U,4059
3
- metaflow_extensions/outerbounds/config/__init__.py,sha256=MwC9dK3A5waKt-DOdHJMw-7sA5Zrl89uLmYJiM3mucc,1057
4
- metaflow_extensions/outerbounds/plugins/__init__.py,sha256=Y6Y2RlZFW5RwZjXa5QrKptht-u1p8faJxFFaA2n9Jy8,10074
3
+ metaflow_extensions/outerbounds/config/__init__.py,sha256=JsQGRuGFz28fQWjUvxUgR8EKBLGRdLUIk_buPLJplJY,1225
4
+ metaflow_extensions/outerbounds/plugins/__init__.py,sha256=g07Xj7YifwLfTTa94cyf3kLebHezeLWNG9HoGnpDyMo,12519
5
5
  metaflow_extensions/outerbounds/plugins/auth_server.py,sha256=1v2GBqoMBxp5E7Lejz139w-jxJtPnLDvvHXP0HhEIHI,2361
6
6
  metaflow_extensions/outerbounds/plugins/perimeters.py,sha256=QXh3SFP7GQbS-RAIxUOPbhPzQ7KDFVxZkTdKqFKgXjI,2697
7
7
  metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -14,8 +14,9 @@ metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py,sha256=g
14
14
  metaflow_extensions/outerbounds/plugins/nim/__init__.py,sha256=GVnvSTjqYVj5oG2yh8KJFt7iZ33cEadDD5HbdmC9hJ0,1457
15
15
  metaflow_extensions/outerbounds/plugins/nim/nim_manager.py,sha256=SWieODDxtIaeZwdMYtObDi57Kjyfw2DUuE6pJtU750w,9206
16
16
  metaflow_extensions/outerbounds/plugins/nvcf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py,sha256=ftxC5SCo64P5Ycpv5vudluTnQi3-VCZW0umdsPP326A,7926
18
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py,sha256=ow3lonclEDoZEUQCDV_L8lEr6HopXqjNXzubRrfdIm4,7219
17
+ metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py,sha256=wIlPBzsTszkHpftK1x7zBgaQ_7d3tNURqh4ez71Ra7A,5416
18
+ metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py,sha256=NIt1kJHuYpnCF7n73A90ZITWsk5QWtsbiHfzvdVjgqk,8997
19
+ metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py,sha256=AhlFa8UVzVKQRUilQwtwp7ZVWJ0WNitPU0vp29i7WuY,9545
19
20
  metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py,sha256=0xNA4aRTPJ4SKpRIFKZzlL9a7lf367KGTrVWVXd-uGE,6052
20
21
  metaflow_extensions/outerbounds/plugins/snowpark/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
22
  metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py,sha256=vzgpVLCKvHjzHNfJvmH0jcxefYNsVggw_vof_y_U_a8,10643
@@ -28,11 +29,11 @@ metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py,sha256
28
29
  metaflow_extensions/outerbounds/profilers/__init__.py,sha256=wa_jhnCBr82TBxoS0e8b6_6sLyZX0fdHicuGJZNTqKw,29
29
30
  metaflow_extensions/outerbounds/profilers/gpu.py,sha256=a5YZAepujuP0uDqG9UpXBlZS3wjUt4Yv8CjybXqeT2c,24342
30
31
  metaflow_extensions/outerbounds/toplevel/__init__.py,sha256=qWUJSv_r5hXJ7jV_On4nEasKIfUCm6_UjkjXWA_A1Ts,90
31
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,sha256=efl01b4O7mRXerjOtLUD-CQ2l7ZGG78iyEXRGMAJYsU,412
32
+ metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py,sha256=Zq3OuL1bOod8KJra-Zk8B3gNhSHoWEGteM9T7g0pp6E,1881
32
33
  metaflow_extensions/outerbounds/toplevel/plugins/azure/__init__.py,sha256=WUuhz2YQfI4fz7nIcipwwWq781eaoHEk7n4GAn1npDg,63
33
34
  metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py,sha256=BbZiaH3uILlEZ6ntBLKeNyqn3If8nIXZFq_Apd7Dhco,70
34
35
  metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8m7rgF4xgWBZFuY3GDP5n1T0ktjRpGJLHA,69
35
- ob_metaflow_extensions-1.1.83.dist-info/METADATA,sha256=jjlyH-CO-VHAH8Q3T8THUvoPPfRf6OIreJNk8Otjj1c,520
36
- ob_metaflow_extensions-1.1.83.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
37
- ob_metaflow_extensions-1.1.83.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
38
- ob_metaflow_extensions-1.1.83.dist-info/RECORD,,
36
+ ob_metaflow_extensions-1.1.86.dist-info/METADATA,sha256=LKi6v-1FMvAWE2u9eW2PKoVBi1xl6Wor1TjhwYVzPNk,520
37
+ ob_metaflow_extensions-1.1.86.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
38
+ ob_metaflow_extensions-1.1.86.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
39
+ ob_metaflow_extensions-1.1.86.dist-info/RECORD,,