metaflow 2.11.15__py2.py3-none-any.whl → 2.11.16__py2.py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- metaflow/__init__.py +3 -0
- metaflow/clone_util.py +6 -0
- metaflow/extension_support/plugins.py +2 -0
- metaflow/metaflow_config.py +24 -0
- metaflow/metaflow_environment.py +2 -2
- metaflow/plugins/__init__.py +19 -0
- metaflow/plugins/airflow/airflow.py +7 -0
- metaflow/plugins/argo/argo_workflows.py +17 -0
- metaflow/plugins/azure/__init__.py +3 -0
- metaflow/plugins/azure/azure_credential.py +53 -0
- metaflow/plugins/azure/azure_exceptions.py +1 -1
- metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
- metaflow/plugins/azure/azure_utils.py +2 -35
- metaflow/plugins/azure/blob_service_client_factory.py +4 -2
- metaflow/plugins/datastores/azure_storage.py +6 -6
- metaflow/plugins/datatools/s3/s3.py +1 -1
- metaflow/plugins/gcp/__init__.py +1 -0
- metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +169 -0
- metaflow/plugins/gcp/gs_storage_client_factory.py +52 -1
- metaflow/plugins/kubernetes/kubernetes.py +85 -8
- metaflow/plugins/kubernetes/kubernetes_cli.py +24 -1
- metaflow/plugins/kubernetes/kubernetes_client.py +4 -1
- metaflow/plugins/kubernetes/kubernetes_decorator.py +49 -4
- metaflow/plugins/kubernetes/kubernetes_job.py +208 -201
- metaflow/plugins/kubernetes/kubernetes_jobsets.py +784 -0
- metaflow/plugins/timeout_decorator.py +2 -1
- metaflow/task.py +1 -12
- metaflow/tuple_util.py +27 -0
- metaflow/util.py +0 -15
- metaflow/version.py +1 -1
- {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/METADATA +2 -2
- {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/RECORD +36 -31
- {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/LICENSE +0 -0
- {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/WHEEL +0 -0
- {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/entry_points.txt +0 -0
- {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/top_level.txt +0 -0
metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py

```diff
@@ -0,0 +1,169 @@
+import base64
+import json
+from json import JSONDecodeError
+
+
+from metaflow.exception import MetaflowException
+from metaflow.plugins.secrets import SecretsProvider
+import re
+from metaflow.plugins.gcp.gs_storage_client_factory import get_credentials
+from metaflow.metaflow_config import GCP_SECRET_MANAGER_PREFIX
+
+
+class MetaflowGcpSecretsManagerBadResponse(MetaflowException):
+    """Raised when the response from GCP Secrets Manager is not valid in some way"""
+
+
+class MetaflowGcpSecretsManagerDuplicateKey(MetaflowException):
+    """Raised when the response from GCP Secrets Manager contains duplicate keys"""
+
+
+class MetaflowGcpSecretsManagerJSONParseError(MetaflowException):
+    """Raised when the SecretString response from GCP Secrets Manager is not valid JSON"""
+
+
+class MetaflowGcpSecretsManagerNotJSONObject(MetaflowException):
+    """Raised when the SecretString response from GCP Secrets Manager is not valid JSON dictionary"""
+
+
+def _sanitize_key_as_env_var(key):
+    """
+    Sanitize a key as an environment variable name.
+    This is purely a convenience trade-off to cover common cases well, vs. introducing
+    ambiguities (e.g. did the final '_' come from '.', or '-' or is original?).
+
+    1/27/2023(jackie):
+
+    We start with few rules and should *sparingly* add more over time.
+    Also, it's TBD whether all possible providers will share the same sanitization logic.
+    Therefore we will keep this function private for now
+    """
+    return key.replace("-", "_").replace(".", "_").replace("/", "_")
+
+
+class GcpSecretManagerSecretsProvider(SecretsProvider):
+    TYPE = "gcp-secret-manager"
+
+    def get_secret_as_dict(self, secret_id, options={}, role=None):
+        """
+        Reads a secret from GCP Secrets Manager and returns it as a dictionary of environment variables.
+
+        If the secret contains a string payload ("SecretString"):
+        - if the `json` option is True:
+            Secret will be parsed as a JSON. If successfully parsed, AND the JSON contains a
+            top-level object, each entry K/V in the object will also be converted to an entry in the result. V will
+            always be casted to a string (if not already a string).
+        - If `json` option is False (default):
+            Will be returned as a single entry in the result, with the key being the last part after / in secret_id.
+
+        On GCP Secrets Manager, the secret payload is a binary blob. However, by default we interpret it as UTF8 encoded
+        string. To disable this, set the `binary` option to True, the binary will be base64 encoded in the result.
+
+        All keys in the result are sanitized to be more valid environment variable names. This is done on a best effort
+        basis. Further validation is expected to be done by the invoking @secrets decorator itself.
+
+        :param secret_id: GCP Secrets Manager secret ID
+        :param options: unused
+        :return: dict of environment variables. All keys and values are strings.
+        """
+        from google.cloud.secretmanager_v1.services.secret_manager_service import (
+            SecretManagerServiceClient,
+        )
+        from google.cloud.secretmanager_v1.services.secret_manager_service.transports import (
+            SecretManagerServiceTransport,
+        )
+
+        # Full secret id looks like projects/1234567890/secrets/mysecret/versions/latest
+        #
+        # We allow these forms of secret_id:
+        #
+        # 1. Full path like projects/1234567890/secrets/mysecret/versions/latest
+        #    This is what you'd specify if you used to GCP SDK.
+        #
+        # 2. Full path but without the version like projects/1234567890/secrets/mysecret.
+        #    This is what you see in the GCP console, makes it easier to copy & paste.
+        #
+        # 3. Simple string like mysecret
+        #
+        # 4. Simple string with /versions/<version> suffix like mysecret/versions/1
+
+        # The latter two forms require METAFLOW_GCP_SECRET_MANAGER_PREFIX to be set.
+
+        match_full = re.match(
+            r"^projects/\d+/secrets/([\w\-]+)(/versions/([\w\-]+))?$", secret_id
+        )
+        match_partial = re.match(r"^([\w\-]+)(/versions/[\w\-]+)?$", secret_id)
+        if match_full:
+            # Full path
+            env_var_name = match_full.group(1)
+            if match_full.group(3):
+                # With version specified
+                full_secret_name = secret_id
+            else:
+                # No version specified, use latest
+                full_secret_name = secret_id + "/versions/latest"
+        elif match_partial:
+            # Partial path, possibly with /versions/<version> suffix
+            env_var_name = secret_id
+            if not GCP_SECRET_MANAGER_PREFIX:
+                raise ValueError(
+                    "Cannot use simple secret_id without setting METAFLOW_GCP_SECRET_MANAGER_PREFIX. %s"
+                    % GCP_SECRET_MANAGER_PREFIX
+                )
+            if match_partial.group(2):
+                # With version specified
+                full_secret_name = "%s%s" % (GCP_SECRET_MANAGER_PREFIX, secret_id)
+                env_var_name = match_partial.group(1)
+            else:
+                # No version specified, use latest
+                full_secret_name = "%s%s/versions/latest" % (
+                    GCP_SECRET_MANAGER_PREFIX,
+                    secret_id,
+                )
+        else:
+            raise ValueError(
+                "Invalid secret_id: %s. Must be either a full path or a simple string."
+                % secret_id
+            )
+
+        result = {}
+
+        def _sanitize_and_add_entry_to_result(k, v):
+            # Two jobs - sanitize, and check for dupes
+            sanitized_k = _sanitize_key_as_env_var(k)
+            if sanitized_k in result:
+                raise MetaflowGcpSecretsManagerDuplicateKey(
+                    "Duplicate key in secret: '%s' (sanitizes to '%s')"
+                    % (k, sanitized_k)
+                )
+            result[sanitized_k] = v
+
+        credentials, _ = get_credentials(
+            scopes=SecretManagerServiceTransport.AUTH_SCOPES
+        )
+        client = SecretManagerServiceClient(credentials=credentials)
+        response = client.access_secret_version(request={"name": full_secret_name})
+        payload_str = response.payload.data.decode("UTF-8")
+        if options.get("json", False):
+            obj = json.loads(payload_str)
+            if type(obj) == dict:
+                for k, v in obj.items():
+                    # We try to make it work here - cast to string always
+                    _sanitize_and_add_entry_to_result(k, str(v))
+            else:
+                raise MetaflowGcpSecretsManagerNotJSONObject(
+                    "Secret string is a JSON, but not an object (dict-like) - actual type %s."
+                    % type(obj)
+                )
+        else:
+            if options.get("env_var_name"):
+                env_var_name = options["env_var_name"]
+
+            if options.get("binary", False):
+                _sanitize_and_add_entry_to_result(
+                    env_var_name, base64.b64encode(response.payload.data)
+                )
+            else:
+                _sanitize_and_add_entry_to_result(env_var_name, payload_str)
+
+        return result
```
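For context, this provider is typically reached from user code via the `@secrets` decorator, which resolves a source of type `gcp-secret-manager` to the class above and exports the returned dict as environment variables. Below is a minimal, illustrative sketch assuming the dict-source form of `@secrets`; the secret id, the `DB_PASSWORD` key, and the flow itself are hypothetical:

```python
import os

from metaflow import FlowSpec, step, secrets


class SecretExampleFlow(FlowSpec):
    # Any of the four secret_id forms accepted above would work here; the
    # simple-string forms additionally require METAFLOW_GCP_SECRET_MANAGER_PREFIX.
    @secrets(
        sources=[
            {
                "type": "gcp-secret-manager",
                "id": "projects/1234567890/secrets/mysecret",  # hypothetical
                "options": {"json": True},
            }
        ]
    )
    @step
    def start(self):
        # With json=True, each top-level key of the JSON payload becomes an
        # environment variable; keys are sanitized ('-', '.', '/' -> '_').
        print(os.environ.get("DB_PASSWORD"))  # hypothetical key in the secret
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    SecretExampleFlow()
```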
metaflow/plugins/gcp/gs_storage_client_factory.py

```diff
@@ -8,7 +8,7 @@ def _get_cache_key():
     return os.getpid(), threading.get_ident()
 
 
-def get_gs_storage_client():
+def _get_gs_storage_client_default():
     cache_key = _get_cache_key()
     if cache_key not in _client_cache:
         from google.cloud import storage
@@ -19,3 +19,54 @@ def get_gs_storage_client():
             credentials=credentials, project=project_id
         )
     return _client_cache[cache_key]
+
+
+class GcpDefaultClientProvider(object):
+    name = "gcp-default"
+
+    @staticmethod
+    def get_gs_storage_client(*args, **kwargs):
+        return _get_gs_storage_client_default()
+
+    @staticmethod
+    def get_credentials(scopes, *args, **kwargs):
+        import google.auth
+
+        return google.auth.default(scopes=scopes)
+
+
+cached_provider_class = None
+
+
+def get_gs_storage_client():
+    global cached_provider_class
+    if cached_provider_class is None:
+        from metaflow.metaflow_config import DEFAULT_GCP_CLIENT_PROVIDER
+        from metaflow.plugins import GCP_CLIENT_PROVIDERS
+
+        for p in GCP_CLIENT_PROVIDERS:
+            if p.name == DEFAULT_GCP_CLIENT_PROVIDER:
+                cached_provider_class = p
+                break
+        else:
+            raise ValueError(
+                "Cannot find GCP Client provider %s" % DEFAULT_GCP_CLIENT_PROVIDER
+            )
+    return cached_provider_class.get_gs_storage_client()
+
+
+def get_credentials(scopes, *args, **kwargs):
+    global cached_provider_class
+    if cached_provider_class is None:
+        from metaflow.metaflow_config import DEFAULT_GCP_CLIENT_PROVIDER
+        from metaflow.plugins import GCP_CLIENT_PROVIDERS
+
+        for p in GCP_CLIENT_PROVIDERS:
+            if p.name == DEFAULT_GCP_CLIENT_PROVIDER:
+                cached_provider_class = p
+                break
+        else:
+            raise ValueError(
+                "Cannot find GCP Client provider %s" % DEFAULT_GCP_CLIENT_PROVIDER
+            )
+    return cached_provider_class.get_credentials(scopes, *args, **kwargs)
```
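The factory is now pluggable: `get_gs_storage_client()` and `get_credentials()` dispatch to whichever registered provider's `name` matches `METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER` (default `gcp-default`). A hedged sketch of what an alternative provider could look like; the class name and `name` value are hypothetical, and registering it in `GCP_CLIENT_PROVIDERS` happens through Metaflow's extension mechanism, which is not shown here:

```python
class MyCustomClientProvider(object):
    # Matched against METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER at dispatch time.
    name = "gcp-custom"  # hypothetical

    @staticmethod
    def get_credentials(scopes, *args, **kwargs):
        import google.auth

        # Return (credentials, project_id); the default provider simply does
        # google.auth.default(scopes=scopes). A custom provider could return
        # impersonated or otherwise specialized credentials instead.
        return google.auth.default(scopes=scopes)

    @staticmethod
    def get_gs_storage_client(*args, **kwargs):
        from google.cloud import storage

        credentials, project_id = MyCustomClientProvider.get_credentials(
            scopes=["https://www.googleapis.com/auth/devstorage.read_write"]
        )
        return storage.Client(credentials=credentials, project=project_id)
```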
metaflow/plugins/kubernetes/kubernetes.py

```diff
@@ -3,9 +3,9 @@ import math
 import os
 import re
 import shlex
+import copy
 import time
 from typing import Dict, List, Optional
-import uuid
 from uuid import uuid4
 
 from metaflow import current, util
@@ -27,8 +27,11 @@ from metaflow.metaflow_config import (
     DATASTORE_SYSROOT_S3,
     DATATOOLS_S3ROOT,
     DEFAULT_AWS_CLIENT_PROVIDER,
+    DEFAULT_GCP_CLIENT_PROVIDER,
     DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
+    GCP_SECRET_MANAGER_PREFIX,
+    AZURE_KEY_VAULT_PREFIX,
     KUBERNETES_FETCH_EC2_METADATA,
     KUBERNETES_LABELS,
     KUBERNETES_SANDBOX_INIT_SCRIPT,
@@ -66,6 +69,12 @@ class KubernetesKilledException(MetaflowException):
     headline = "Kubernetes Batch job killed"
 
 
+def _extract_labels_and_annotations_from_job_spec(job_spec):
+    annotations = job_spec.template.metadata.annotations
+    labels = job_spec.template.metadata.labels
+    return copy.copy(annotations), copy.copy(labels)
+
+
 class Kubernetes(object):
     def __init__(
         self,
@@ -140,9 +149,64 @@ class Kubernetes(object):
         return shlex.split('bash -c "%s"' % cmd_str)
 
     def launch_job(self, **kwargs):
-        self._job = self.create_job(**kwargs).execute()
+        if (
+            "num_parallel" in kwargs
+            and kwargs["num_parallel"]
+            and int(kwargs["num_parallel"]) > 0
+        ):
+            job = self.create_job_object(**kwargs)
+            spec = job.create_job_spec()
+            # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
+            # This will be modified by the KubernetesJobSet object
+            annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
+            self._job = self.create_jobset(
+                job_spec=spec,
+                run_id=kwargs["run_id"],
+                step_name=kwargs["step_name"],
+                task_id=kwargs["task_id"],
+                namespace=kwargs["namespace"],
+                env=kwargs["env"],
+                num_parallel=kwargs["num_parallel"],
+                port=kwargs["port"],
+                annotations=annotations,
+                labels=labels,
+            ).execute()
+        else:
+            kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
+            self._job = self.create_job_object(**kwargs).create().execute()
+
+    def create_jobset(
+        self,
+        job_spec=None,
+        run_id=None,
+        step_name=None,
+        task_id=None,
+        namespace=None,
+        env=None,
+        num_parallel=None,
+        port=None,
+        annotations=None,
+        labels=None,
+    ):
+        if env is None:
+            env = {}
+
+        _prefix = str(uuid4())[:6]
+        js = KubernetesClient().jobset(
+            name="js-%s" % _prefix,
+            run_id=run_id,
+            task_id=task_id,
+            step_name=step_name,
+            namespace=namespace,
+            labels=self._get_labels(labels),
+            annotations=annotations,
+            num_parallel=num_parallel,
+            job_spec=job_spec,
+            port=port,
+        )
+        return js
 
-    def create_job(
+    def create_job_object(
         self,
         flow_name,
         run_id,
@@ -176,14 +240,15 @@ class Kubernetes(object):
         labels=None,
         shared_memory=None,
         port=None,
+        name_pattern=None,
+        num_parallel=None,
     ):
         if env is None:
             env = {}
-
         job = (
             KubernetesClient()
             .job(
-                generate_name="t-{uid}-".format(uid=str(uuid4())[:8]),
+                generate_name=name_pattern,
                 namespace=namespace,
                 service_account=service_account,
                 secrets=secrets,
@@ -217,6 +282,7 @@ class Kubernetes(object):
                 persistent_volume_claims=persistent_volume_claims,
                 shared_memory=shared_memory,
                 port=port,
+                num_parallel=num_parallel,
             )
             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
             .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -243,10 +309,19 @@ class Kubernetes(object):
             .environment_variable(
                 "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER", DEFAULT_AWS_CLIENT_PROVIDER
             )
+            .environment_variable(
+                "METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER", DEFAULT_GCP_CLIENT_PROVIDER
+            )
             .environment_variable(
                 "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
                 AWS_SECRETS_MANAGER_DEFAULT_REGION,
             )
+            .environment_variable(
+                "METAFLOW_GCP_SECRET_MANAGER_PREFIX", GCP_SECRET_MANAGER_PREFIX
+            )
+            .environment_variable(
+                "METAFLOW_AZURE_KEY_VAULT_PREFIX", AZURE_KEY_VAULT_PREFIX
+            )
             .environment_variable("METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL)
             .environment_variable(
                 "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT",
@@ -332,6 +407,9 @@ class Kubernetes(object):
             .label("app.kubernetes.io/part-of", "metaflow")
         )
 
+        return job
+
+    def create_k8sjob(self, job):
         return job.create()
 
     def wait(self, stdout_location, stderr_location, echo=None):
@@ -366,7 +444,7 @@ class Kubernetes(object):
             t = time.time()
             time.sleep(update_delay(time.time() - start_time))
 
-        prefix = b"[%s] " % util.to_bytes(self._job.id)
+        _make_prefix = lambda: b"[%s] " % util.to_bytes(self._job.id)
 
         stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE)
         stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE)
@@ -376,7 +454,7 @@ class Kubernetes(object):
 
         # 2) Tail logs until the job has finished
         tail_logs(
-            prefix=prefix,
+            prefix=_make_prefix(),
            stdout_tail=stdout_tail,
             stderr_tail=stderr_tail,
             echo=echo,
@@ -392,7 +470,6 @@ class Kubernetes(object):
         # exists prior to calling S3Tail and note the user about
         # truncated logs if it doesn't.
         # TODO : For hard crashes, we can fetch logs from the pod.
-
         if self._job.has_failed:
             exit_code, reason = self._job.reason
             msg = next(
```
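The practical effect of the new `launch_job` branch: a step that carries `num_parallel` is launched as a JobSet (one control pod plus workers) instead of a single Job. A sketch of a flow that would take that path, assuming Metaflow's usual `@parallel` idiom of passing `num_parallel` to `self.next`; the resource values are arbitrary:

```python
import os

from metaflow import FlowSpec, step, kubernetes, parallel


class MultiNodeFlow(FlowSpec):
    @step
    def start(self):
        # num_parallel > 1 routes launch_job() down the new JobSet path
        self.next(self.train, num_parallel=4)

    @kubernetes(cpu=2, memory=4096)
    @parallel
    @step
    def train(self):
        # Each pod sees the MF_PARALLEL_* variables populated by
        # _setup_multinode_environment (see kubernetes_decorator.py below).
        print(
            "node %s of %s"
            % (
                os.environ.get("MF_PARALLEL_NODE_INDEX"),
                os.environ.get("MF_PARALLEL_NUM_NODES"),
            )
        )
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    MultiNodeFlow()
```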
metaflow/plugins/kubernetes/kubernetes_cli.py

```diff
@@ -7,11 +7,17 @@ from metaflow import JSONTypeClass, util
 from metaflow._vendor import click
 from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
 from metaflow.mflog import TASK_LOG_SOURCE
 import metaflow.tracing as tracing
 
-from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list
+from .kubernetes import (
+    Kubernetes,
+    KubernetesKilledException,
+    parse_kube_keyvalue_list,
+    KubernetesException,
+)
 from .kubernetes_decorator import KubernetesDecorator
 
 
@@ -109,6 +115,15 @@ def kubernetes():
 )
 @click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
 @click.option("--port", default=None, help="Port number to expose from the container")
+@click.option(
+    "--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK])
+)
+@click.option(
+    "--num-parallel",
+    default=None,
+    type=int,
+    help="Number of parallel nodes to run as a multi-node job.",
+)
 @click.pass_context
 def step(
     ctx,
@@ -136,6 +151,7 @@ def step(
     tolerations=None,
     shared_memory=None,
     port=None,
+    num_parallel=None,
     **kwargs
 ):
     def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -167,6 +183,12 @@ def step(
         kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
         env.update(split_vars)
 
+    if num_parallel is not None and num_parallel <= 1:
+        raise KubernetesException(
+            "Using @parallel with `num_parallel` <= 1 is not supported with Kubernetes. "
+            "Please set the value of `num_parallel` to be greater than 1."
+        )
+
     # Set retry policy.
     retry_count = int(kwargs.get("retry_count", 0))
     retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
@@ -251,6 +273,7 @@ def step(
             tolerations=tolerations,
             shared_memory=shared_memory,
             port=port,
+            num_parallel=num_parallel,
         )
     except Exception as e:
         traceback.print_exc(chain=False)
```
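The new `--num-parallel` option comes with a guard: explicit values of 1 or less are rejected before any job is created, since a one-node `@parallel` run is not supported on Kubernetes. An illustrative, standalone restatement of that check (not Metaflow API):

```python
def validate_num_parallel(num_parallel):
    # Mirrors the CLI guard above: None means a plain single-pod step;
    # any explicit value must be >= 2 to make sense as a multi-node job.
    if num_parallel is not None and num_parallel <= 1:
        raise ValueError(
            "Using @parallel with `num_parallel` <= 1 is not supported with "
            "Kubernetes. Please set the value of `num_parallel` to be greater than 1."
        )


validate_num_parallel(None)  # ok: regular step
validate_num_parallel(4)     # ok: launched as a 4-pod JobSet
# validate_num_parallel(1)   # raises
```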
metaflow/plugins/kubernetes/kubernetes_client.py

```diff
@@ -4,7 +4,7 @@ import time
 
 from metaflow.exception import MetaflowException
 
-from .kubernetes_job import KubernetesJob
+from .kubernetes_job import KubernetesJob, KubernetesJobSet
 
 
 CLIENT_REFRESH_INTERVAL_SECONDS = 300
@@ -61,5 +61,8 @@ class KubernetesClient(object):
 
         return self._client
 
+    def jobset(self, **kwargs):
+        return KubernetesJobSet(self, **kwargs)
+
     def job(self, **kwargs):
         return KubernetesJob(self, **kwargs)
```
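`KubernetesClient` now exposes a `jobset()` factory alongside `job()`. A hypothetical sketch of how `kubernetes.py` above wires it up; the argument values are made up, and the `KubernetesJobSet` keyword signature is assumed from the `create_jobset` call shown earlier:

```python
from metaflow.plugins.kubernetes.kubernetes_client import KubernetesClient

client = KubernetesClient()
jobset = client.jobset(
    name="js-abc123",   # hypothetical; kubernetes.py uses "js-" + a uuid prefix
    run_id="1234",
    step_name="train",
    task_id="5",
    namespace="default",
    num_parallel=4,
    job_spec=None,      # normally the spec built by create_job_object()
)
# jobset.execute() submits the JobSet: one control replica plus
# (num_parallel - 1) worker replicas, per the env handling below.
```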
metaflow/plugins/kubernetes/kubernetes_decorator.py

```diff
@@ -32,6 +32,8 @@ from metaflow.sidecar import Sidecar
 
 from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
 from .kubernetes import KubernetesException, parse_kube_keyvalue_list
+from metaflow.unbounded_foreach import UBF_CONTROL
+from .kubernetes_jobsets import TaskIdConstructor
 
 try:
     unicode
@@ -239,11 +241,15 @@ class KubernetesDecorator(StepDecorator):
                 "Kubernetes. Please use one or the other.".format(step=step)
             )
 
-        for deco in decos
-
-
-
+        if any([deco.name == "parallel" for deco in decos]) and any(
+            [deco.name == "catch" for deco in decos]
+        ):
+            raise MetaflowException(
+                "Step *{step}* contains a @parallel decorator "
+                "with the @catch decorator. @catch is not supported with @parallel on Kubernetes.".format(
+                    step=step
                 )
+            )
 
         # Set run time limit for the Kubernetes job.
         self.run_time_limit = get_run_time_limit_for_task(decos)
@@ -421,6 +427,10 @@ class KubernetesDecorator(StepDecorator):
                 "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME"
             ]
             meta["kubernetes-node-ip"] = os.environ["METAFLOW_KUBERNETES_NODE_IP"]
+            if os.environ.get("METAFLOW_KUBERNETES_JOBSET_NAME"):
+                meta["kubernetes-jobset-name"] = os.environ[
+                    "METAFLOW_KUBERNETES_JOBSET_NAME"
+                ]
 
             # TODO (savin): Introduce equivalent support for Microsoft Azure and
             # Google Cloud Platform
@@ -453,6 +463,24 @@ class KubernetesDecorator(StepDecorator):
         self._save_logs_sidecar = Sidecar("save_logs_periodically")
         self._save_logs_sidecar.start()
 
+        num_parallel = None
+        if hasattr(flow, "_parallel_ubf_iter"):
+            num_parallel = flow._parallel_ubf_iter.num_parallel
+
+        if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
+            control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(
+                num_parallel
+            )
+            mapper_task_ids = [control_task_id] + worker_task_ids
+            flow._control_mapper_tasks = [
+                "%s/%s/%s" % (run_id, step_name, mapper_task_id)
+                for mapper_task_id in mapper_task_ids
+            ]
+            flow._control_task_is_mapper_zero = True
+
+        if num_parallel and num_parallel > 1:
+            _setup_multinode_environment()
+
     def task_finished(
         self, step_name, flow, graph, is_task_ok, retry_count, max_retries
     ):
@@ -486,3 +514,20 @@ class KubernetesDecorator(StepDecorator):
         cls.package_url, cls.package_sha = flow_datastore.save_data(
             [package.blob], len_hint=1
         )[0]
+
+
+def _setup_multinode_environment():
+    import socket
+
+    os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
+    os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
+    if os.environ.get("CONTROL_INDEX") is not None:
+        os.environ["MF_PARALLEL_NODE_INDEX"] = str(0)
+    elif os.environ.get("WORKER_REPLICA_INDEX") is not None:
+        os.environ["MF_PARALLEL_NODE_INDEX"] = str(
+            int(os.environ["WORKER_REPLICA_INDEX"]) + 1
+        )
+    else:
+        raise MetaflowException(
+            "Jobset related ENV vars called $CONTROL_INDEX or $WORKER_REPLICA_INDEX not found"
+        )
```