ob-metaflow 2.11.15.3__py2.py3-none-any.whl → 2.11.16.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ob-metaflow might be problematic.
- metaflow/__init__.py +3 -0
- metaflow/clone_util.py +6 -0
- metaflow/extension_support/plugins.py +1 -1
- metaflow/metaflow_config.py +5 -3
- metaflow/metaflow_environment.py +3 -3
- metaflow/plugins/__init__.py +4 -4
- metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +18 -14
- metaflow/plugins/datatools/s3/s3.py +1 -1
- metaflow/plugins/gcp/__init__.py +1 -1
- metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +11 -6
- metaflow/plugins/kubernetes/kubernetes.py +79 -49
- metaflow/plugins/kubernetes/kubernetes_cli.py +20 -33
- metaflow/plugins/kubernetes/kubernetes_client.py +4 -1
- metaflow/plugins/kubernetes/kubernetes_decorator.py +44 -61
- metaflow/plugins/kubernetes/kubernetes_job.py +217 -584
- metaflow/plugins/kubernetes/kubernetes_jobsets.py +784 -0
- metaflow/plugins/timeout_decorator.py +2 -1
- metaflow/task.py +1 -12
- metaflow/tuple_util.py +27 -0
- metaflow/util.py +0 -15
- metaflow/version.py +1 -1
- {ob_metaflow-2.11.15.3.dist-info → ob_metaflow-2.11.16.1.dist-info}/METADATA +2 -2
- {ob_metaflow-2.11.15.3.dist-info → ob_metaflow-2.11.16.1.dist-info}/RECORD +27 -25
- {ob_metaflow-2.11.15.3.dist-info → ob_metaflow-2.11.16.1.dist-info}/LICENSE +0 -0
- {ob_metaflow-2.11.15.3.dist-info → ob_metaflow-2.11.16.1.dist-info}/WHEEL +0 -0
- {ob_metaflow-2.11.15.3.dist-info → ob_metaflow-2.11.16.1.dist-info}/entry_points.txt +0 -0
- {ob_metaflow-2.11.15.3.dist-info → ob_metaflow-2.11.16.1.dist-info}/top_level.txt +0 -0
metaflow/__init__.py
CHANGED
@@ -143,6 +143,9 @@ from .client import (
     DataArtifact,
 )
 
+# Import data class within tuple_util but not introduce new symbols.
+from . import tuple_util
+
 __version_addl__ = []
 _ext_debug("Loading top-level modules")
 for m in _tl_modules:
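The import above pulls in the new metaflow/tuple_util.py module (added in this release) purely for its import-time side effect of defining a named-tuple helper, without exposing new top-level symbols. As a hedged sketch only (the module's actual contents are not shown in this diff), such a helper usually looks like:

    # Sketch of a namedtuple_with_defaults helper; the real tuple_util.py may differ.
    from collections import namedtuple

    def namedtuple_with_defaults(typename, field_names, defaults=()):
        T = namedtuple(typename, field_names)
        T.__new__.__defaults__ = tuple(defaults)  # trailing fields get default values
        return T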
metaflow/clone_util.py
CHANGED
@@ -66,6 +66,12 @@ def clone_task_helper(
                 type="attempt",
                 tags=metadata_tags,
             ),
+            MetaDatum(
+                field="attempt_ok",
+                value="True",  # During clone, the task is always considered successful.
+                type="internal_attempt_status",
+                tags=metadata_tags,
+            ),
         ],
     )
     output.done()
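The new MetaDatum marks a cloned task as successful without re-executing it. A hedged sketch of reading that marker back through the Metaflow client (the pathspec "HelloFlow/3/start/7" is purely illustrative):

    from metaflow import Task

    # metadata_dict maps registered metadata fields to their values; for a task
    # produced by clone/resume, the new "attempt_ok" entry should read "True".
    task = Task("HelloFlow/3/start/7")
    print(task.metadata_dict.get("attempt_ok"))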
metaflow/extension_support/plugins.py
CHANGED
@@ -179,8 +179,8 @@ _plugin_categories = {
     "metadata_provider": lambda x: x.TYPE,
     "datastore": lambda x: x.TYPE,
     "secrets_provider": lambda x: x.TYPE,
-    "azure_client_provider": lambda x: x.name,
     "gcp_client_provider": lambda x: x.name,
+    "azure_client_provider": lambda x: x.name,
     "sidecar": None,
     "logging_sidecar": None,
     "monitor_sidecar": None,
metaflow/metaflow_config.py
CHANGED
@@ -227,6 +227,8 @@ DEFAULT_CONTAINER_REGISTRY = from_conf("DEFAULT_CONTAINER_REGISTRY")
 INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", False)
 # Maximum length of the foreach value string to be stored in each ForeachFrame.
 MAXIMUM_FOREACH_VALUE_CHARS = from_conf("MAXIMUM_FOREACH_VALUE_CHARS", 30)
+# The default runtime limit (In seconds) of jobs launched by any compute provider. Default of 5 days.
+DEFAULT_RUNTIME_LIMIT = from_conf("DEFAULT_RUNTIME_LIMIT", 5 * 24 * 60 * 60)
 
 ###
 # Organization customizations

@@ -327,8 +329,6 @@ KUBERNETES_CONTAINER_REGISTRY = from_conf(
 )
 # Toggle for trying to fetch EC2 instance metadata
 KUBERNETES_FETCH_EC2_METADATA = from_conf("KUBERNETES_FETCH_EC2_METADATA", False)
-# Default port number to open on the pods
-KUBERNETES_PORT = from_conf("KUBERNETES_PORT", None)
 # Shared memory in MB to use for this step
 KUBERNETES_SHARED_MEMORY = from_conf("KUBERNETES_SHARED_MEMORY", None)
 # Default port number to open on the pods

@@ -338,10 +338,12 @@ KUBERNETES_CPU = from_conf("KUBERNETES_CPU", None)
 KUBERNETES_MEMORY = from_conf("KUBERNETES_MEMORY", None)
 KUBERNETES_DISK = from_conf("KUBERNETES_DISK", None)
 
-
 ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
 ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")
 
+KUBERNETES_JOBSET_GROUP = from_conf("KUBERNETES_JOBSET_GROUP", "jobset.x-k8s.io")
+KUBERNETES_JOBSET_VERSION = from_conf("KUBERNETES_JOBSET_VERSION", "v1alpha2")
+
 ##
 # Argo Events Configuration
 ##
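Like the surrounding settings, the new values go through from_conf, so they should be overridable via the Metaflow config file or a METAFLOW_-prefixed environment variable; the exact variable names below are an assumption based on that convention, not taken from this diff:

    import os

    # Assumed overrides following the METAFLOW_<NAME> convention used by from_conf;
    # they must be set before metaflow is imported to take effect.
    os.environ["METAFLOW_DEFAULT_RUNTIME_LIMIT"] = str(3 * 24 * 60 * 60)  # 3 days
    os.environ["METAFLOW_KUBERNETES_JOBSET_GROUP"] = "jobset.x-k8s.io"
    os.environ["METAFLOW_KUBERNETES_JOBSET_VERSION"] = "v1alpha2"

    from metaflow.metaflow_config import DEFAULT_RUNTIME_LIMIT
    print(DEFAULT_RUNTIME_LIMIT)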
metaflow/metaflow_environment.py
CHANGED
@@ -91,7 +91,7 @@ class MetaflowEnvironment(object):
         if datastore_type == "s3":
             return (
                 '%s -m awscli ${METAFLOW_S3_ENDPOINT_URL:+--endpoint-url=\\"${METAFLOW_S3_ENDPOINT_URL}\\"} '
-                + "s3 cp %s job.tar"
+                + "s3 cp %s job.tar >/dev/null"
             ) % (self._python(), code_package_url)
         elif datastore_type == "azure":
             from .plugins.azure.azure_utils import parse_azure_full_path

@@ -119,9 +119,9 @@ class MetaflowEnvironment(object):
         )
 
     def _get_install_dependencies_cmd(self, datastore_type):
-        cmds = ["%s -m pip install requests" % self._python()]
+        cmds = ["%s -m pip install requests -qqq" % self._python()]
         if datastore_type == "s3":
-            cmds.append("%s -m pip install awscli boto3" % self._python())
+            cmds.append("%s -m pip install awscli boto3 -qqq" % self._python())
         elif datastore_type == "azure":
             cmds.append(
                 "%s -m pip install azure-identity azure-storage-blob azure-keyvault-secrets simple-azure-blob-downloader -qqq"
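Both hunks only make the task bootstrap quieter: pip runs with -qqq and the code-package download discards awscli's progress output. A rough sketch of the resulting S3 bootstrap commands (the interpreter and package URL are illustrative):

    python = "python"
    code_package_url = "s3://example-bucket/metaflow/job.tar"  # hypothetical URL

    install_cmds = [
        "%s -m pip install requests -qqq" % python,
        "%s -m pip install awscli boto3 -qqq" % python,
    ]
    download_cmd = "%s -m awscli s3 cp %s job.tar >/dev/null" % (python, code_package_url)
    print(" && ".join(install_cmds + [download_cmd]))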
metaflow/plugins/__init__.py
CHANGED
@@ -131,14 +131,14 @@ SECRETS_PROVIDERS_DESC = [
     ),
 ]
 
-AZURE_CLIENT_PROVIDERS_DESC = [
-    ("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
-]
-
 GCP_CLIENT_PROVIDERS_DESC = [
     ("gcp-default", ".gcp.gs_storage_client_factory.GcpDefaultClientProvider")
 ]
 
+AZURE_CLIENT_PROVIDERS_DESC = [
+    ("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
+]
+
 
 process_plugins(globals())
 
metaflow/plugins/azure/azure_secret_manager_secrets_provider.py
CHANGED
@@ -80,7 +80,7 @@ class AzureKeyVaultSecretsProvider(SecretsProvider):
        try:
            parsed_vault_url = urlparse(secret_id)
        except ValueError:
-            print(
+            print("invalid vault url", file=sys.stderr)
            return False
        hostname = parsed_vault_url.netloc
 

@@ -94,34 +94,34 @@ class AzureKeyVaultSecretsProvider(SecretsProvider):
        if not k_v_domain_found:
            # the secret_id started with https:// however the key_vault_domains
            # were not present in the secret_id which means
-            raise MetaflowAzureKeyVaultBadVault(
+            raise MetaflowAzureKeyVaultBadVault("bad key vault domain %s" % secret_id)
 
        # given the secret_id seems to have a valid key vault domain
        # lets verify that the vault name corresponds to its regex.
        vault_name = hostname[: -len(actual_k_v_domain)]
        # verify the vault name pattern
        if not self._is_valid_vault_name(vault_name):
-            raise MetaflowAzureKeyVaultBadVault(
+            raise MetaflowAzureKeyVaultBadVault("bad key vault name %s" % vault_name)
 
        path_parts = parsed_vault_url.path.strip("/").split("/")
        total_path_parts = len(path_parts)
        if total_path_parts < 2 or total_path_parts > 3:
            raise MetaflowAzureKeyVaultBadSecretPath(
-
+                "bad secret uri path %s" % path_parts
            )
 
        object_type = path_parts[0]
        if not self._is_valid_object_type(object_type):
-            raise MetaflowAzureKeyVaultBadSecretType(
+            raise MetaflowAzureKeyVaultBadSecretType("bad secret type %s" % object_type)
 
        secret_name = path_parts[1]
        if not self._is_valid_secret_name(secret_name=secret_name):
-            raise MetaflowAzureKeyVaultBadSecretName(
+            raise MetaflowAzureKeyVaultBadSecretName("bad secret name %s" % secret_name)
 
        if total_path_parts == 3:
            if not self._is_valid_object_version(path_parts[2]):
                raise MetaflowAzureKeyVaultBadSecretVersion(
-
+                    "bad secret version %s" % path_parts[2]
                )
 
        return True

@@ -139,10 +139,11 @@ class AzureKeyVaultSecretsProvider(SecretsProvider):
        # must be set.
        if not AZURE_KEY_VAULT_PREFIX:
            raise ValueError(
-
+                "cannot use simple secret id without setting METAFLOW_AZURE_KEY_VAULT_PREFIX. %s"
+                % AZURE_KEY_VAULT_PREFIX
            )
        domain = AZURE_KEY_VAULT_PREFIX.rstrip("/")
-        full_secret =
+        full_secret = "%s/secrets/%s" % (domain, secret_id)
        if not self._is_secret_id_fully_qualified_url(full_secret):
            return False
 

@@ -186,29 +187,32 @@ class AzureKeyVaultSecretsProvider(SecretsProvider):
 
        # if the secret_id is None/empty/does not start with https then return false
        if secret_id is None or secret_id == "":
-            raise MetaflowAzureKeyVaultBadSecret(
+            raise MetaflowAzureKeyVaultBadSecret("empty secret id is not supported")
 
        # check if the passed in secret is a short-form ( #3/#4 in the above comment)
        if not secret_id.startswith("https://"):
            # check if the secret_id is of form `secret_name` OR `secret_name/secret_version`
            if not self._is_partial_secret_valid(secret_id=secret_id):
                raise MetaflowAzureKeyVaultBadSecret(
-
+                    "unsupported partial secret %s" % secret_id
                )
 
            domain = AZURE_KEY_VAULT_PREFIX.rstrip("/")
-            full_secret =
+            full_secret = "%s/secrets/%s" % (domain, secret_id)
 
        # if the secret id is passed as a URL - then check if the url is fully qualified
        if secret_id.startswith("https://"):
            if not self._is_secret_id_fully_qualified_url(secret_id=secret_id):
-                raise MetaflowException(
+                raise MetaflowException("unsupported secret %s" % secret_id)
            full_secret = secret_id
 
        # at this point I know that the secret URL is good so we can start creating the Secret Client
        az_credentials = create_cacheable_azure_credential()
        res = urlparse(full_secret)
-        az_vault_url =
+        az_vault_url = "%s://%s" % (
+            res.scheme,
+            res.netloc,
+        )  # https://myvault.vault.azure.net
        secret_data = res.path.strip("/").split("/")[1:]
        secret_name = secret_data[0]
        secret_version = None
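The reworked validation now raises a specific message for each failure mode and accepts either a fully qualified vault URL or a short name (optionally with a version) resolved against METAFLOW_AZURE_KEY_VAULT_PREFIX as "<prefix>/secrets/<secret_id>". A hedged sketch of consuming such a secret in a flow; the "az-key-vault" type string, the vault URL, and the secret name are assumptions, not taken from this diff:

    from metaflow import FlowSpec, step, secrets

    class KeyVaultFlow(FlowSpec):
        @secrets(sources=[{"type": "az-key-vault",  # assumed provider type string
                           "id": "https://myvault.vault.azure.net/secrets/dbpassword"}])
        @step
        def start(self):
            import os
            # The provider injects the secret value as an environment variable.
            print("secret length:", len(os.environ["dbpassword"]))
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        KeyVaultFlow()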
metaflow/plugins/datatools/s3/s3.py
CHANGED
@@ -21,7 +21,6 @@ from metaflow.metaflow_config import (
     TEMPDIR,
 )
 from metaflow.util import (
-    namedtuple_with_defaults,
     is_stringish,
     to_bytes,
     to_unicode,

@@ -29,6 +28,7 @@ from metaflow.util import (
     url_quote,
     url_unquote,
 )
+from metaflow.tuple_util import namedtuple_with_defaults
 from metaflow.exception import MetaflowException
 from metaflow.debug import debug
 import metaflow.tracing as tracing
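The only change here is an import relocation: namedtuple_with_defaults now comes from the dependency-light metaflow.tuple_util instead of metaflow.util. A brief usage sketch of the relocated helper (the tuple name and fields are made up for illustration):

    from metaflow.tuple_util import namedtuple_with_defaults

    # Assumes the helper mirrors collections.namedtuple with a defaults argument.
    Range = namedtuple_with_defaults("Range", ["offset", "length"], defaults=(0, None))
    print(Range())            # Range(offset=0, length=None)
    print(Range(offset=128))  # Range(offset=128, length=None)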
metaflow/plugins/gcp/__init__.py
CHANGED
@@ -1 +1 @@
-from .gs_storage_client_factory import get_credentials
+from .gs_storage_client_factory import get_credentials
metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py
CHANGED
@@ -89,7 +89,9 @@ class GcpSecretManagerSecretsProvider(SecretsProvider):
 
        # The latter two forms require METAFLOW_GCP_SECRET_MANAGER_PREFIX to be set.
 
-        match_full = re.match(
+        match_full = re.match(
+            r"^projects/\d+/secrets/([\w\-]+)(/versions/([\w\-]+))?$", secret_id
+        )
        match_partial = re.match(r"^([\w\-]+)(/versions/[\w\-]+)?$", secret_id)
        if match_full:
            # Full path

@@ -105,20 +107,23 @@ class GcpSecretManagerSecretsProvider(SecretsProvider):
            env_var_name = secret_id
            if not GCP_SECRET_MANAGER_PREFIX:
                raise ValueError(
-
+                    "Cannot use simple secret_id without setting METAFLOW_GCP_SECRET_MANAGER_PREFIX. %s"
+                    % GCP_SECRET_MANAGER_PREFIX
                )
            if match_partial.group(2):
                # With version specified
-                full_secret_name =
+                full_secret_name = "%s%s" % (GCP_SECRET_MANAGER_PREFIX, secret_id)
                env_var_name = match_partial.group(1)
            else:
                # No version specified, use latest
-                full_secret_name = (
-
+                full_secret_name = "%s%s/versions/latest" % (
+                    GCP_SECRET_MANAGER_PREFIX,
+                    secret_id,
                )
        else:
            raise ValueError(
-
+                "Invalid secret_id: %s. Must be either a full path or a simple string."
+                % secret_id
            )
 
        result = {}
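The two regexes define the accepted secret_id shapes: a full projects/<number>/secrets/<name>[/versions/<v>] resource path, or a short <name>[/versions/<v>] form that is prefixed with METAFLOW_GCP_SECRET_MANAGER_PREFIX. A standalone check of those patterns (the ids are illustrative):

    import re

    full_re = r"^projects/\d+/secrets/([\w\-]+)(/versions/([\w\-]+))?$"
    partial_re = r"^([\w\-]+)(/versions/[\w\-]+)?$"

    print(bool(re.match(full_re, "projects/123456/secrets/db-password/versions/3")))  # True
    print(bool(re.match(partial_re, "db-password")))                                  # True
    print(bool(re.match(partial_re, "db-password/versions/latest")))                  # True
    print(bool(re.match(full_re, "projects/not-a-number/secrets/db-password")))       # False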
metaflow/plugins/kubernetes/kubernetes.py
CHANGED
@@ -3,10 +3,9 @@ import math
 import os
 import re
 import shlex
-import time
 import copy
+import time
 from typing import Dict, List, Optional
-import uuid
 from uuid import uuid4
 
 from metaflow import current, util

@@ -70,6 +69,12 @@ class KubernetesKilledException(MetaflowException):
     headline = "Kubernetes Batch job killed"
 
 
+def _extract_labels_and_annotations_from_job_spec(job_spec):
+    annotations = job_spec.template.metadata.annotations
+    labels = job_spec.template.metadata.labels
+    return copy.copy(annotations), copy.copy(labels)
+
+
 class Kubernetes(object):
     def __init__(
         self,

@@ -144,9 +149,64 @@ class Kubernetes(object):
         return shlex.split('bash -c "%s"' % cmd_str)
 
     def launch_job(self, **kwargs):
-
+        if (
+            "num_parallel" in kwargs
+            and kwargs["num_parallel"]
+            and int(kwargs["num_parallel"]) > 0
+        ):
+            job = self.create_job_object(**kwargs)
+            spec = job.create_job_spec()
+            # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
+            # This will be modified by the KubernetesJobSet object
+            annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
+            self._job = self.create_jobset(
+                job_spec=spec,
+                run_id=kwargs["run_id"],
+                step_name=kwargs["step_name"],
+                task_id=kwargs["task_id"],
+                namespace=kwargs["namespace"],
+                env=kwargs["env"],
+                num_parallel=kwargs["num_parallel"],
+                port=kwargs["port"],
+                annotations=annotations,
+                labels=labels,
+            ).execute()
+        else:
+            kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
+            self._job = self.create_job_object(**kwargs).create().execute()
+
+    def create_jobset(
+        self,
+        job_spec=None,
+        run_id=None,
+        step_name=None,
+        task_id=None,
+        namespace=None,
+        env=None,
+        num_parallel=None,
+        port=None,
+        annotations=None,
+        labels=None,
+    ):
+        if env is None:
+            env = {}
+
+        _prefix = str(uuid4())[:6]
+        js = KubernetesClient().jobset(
+            name="js-%s" % _prefix,
+            run_id=run_id,
+            task_id=task_id,
+            step_name=step_name,
+            namespace=namespace,
+            labels=self._get_labels(labels),
+            annotations=annotations,
+            num_parallel=num_parallel,
+            job_spec=job_spec,
+            port=port,
+        )
+        return js
 
-    def
+    def create_job_object(
         self,
         flow_name,
         run_id,

@@ -178,19 +238,17 @@ class Kubernetes(object):
         persistent_volume_claims=None,
         tolerations=None,
         labels=None,
-        annotations=None,
-        num_parallel=0,
-        attrs={},
         shared_memory=None,
         port=None,
+        name_pattern=None,
+        num_parallel=None,
     ):
         if env is None:
             env = {}
-
         job = (
             KubernetesClient()
             .job(
-                generate_name=
+                generate_name=name_pattern,
                 namespace=namespace,
                 service_account=service_account,
                 secrets=secrets,

@@ -222,10 +280,9 @@ class Kubernetes(object):
                 tmpfs_size=tmpfs_size,
                 tmpfs_path=tmpfs_path,
                 persistent_volume_claims=persistent_volume_claims,
-                num_parallel=num_parallel,
-                attrs=attrs,
                 shared_memory=shared_memory,
                 port=port,
+                num_parallel=num_parallel,
             )
             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
             .environment_variable("METAFLOW_CODE_URL", code_package_url)

@@ -288,7 +345,6 @@ class Kubernetes(object):
                 # see get_datastore_root_from_config in datastore/local.py).
             )
 
-        self.num_parallel = num_parallel
         # Temporary passing of *some* environment variables. Do not rely on this
         # mechanism as it will be removed in the near future
         for k, v in config_values():

@@ -351,6 +407,9 @@ class Kubernetes(object):
             .label("app.kubernetes.io/part-of", "metaflow")
         )
 
+        return job
+
+    def create_k8sjob(self, job):
         return job.create()
 
     def wait(self, stdout_location, stderr_location, echo=None):

@@ -364,7 +423,7 @@ class Kubernetes(object):
             sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
             return 0.5 + sigmoid * 30.0
 
-        def wait_for_launch(job
+        def wait_for_launch(job):
             status = job.status
             echo(
                 "Task is starting (%s)..." % status,

@@ -374,56 +433,28 @@ class Kubernetes(object):
             t = time.time()
             start_time = time.time()
             while job.is_waiting:
-
-                if status !=
-
-                    child_statuses = ""
-                else:
-                    status_keys = set(
-                        [child_job.status for child_job in child_jobs]
-                    )
-                    status_counts = [
-                        (
-                            status,
-                            len(
-                                [
-                                    child_job.status == status
-                                    for child_job in child_jobs
-                                ]
-                            ),
-                        )
-                        for status in status_keys
-                    ]
-                    child_statuses = " (parallel node status: [{}])".format(
-                        ", ".join(
-                            [
-                                "{}:{}".format(status, num)
-                                for (status, num) in sorted(status_counts)
-                            ]
-                        )
-                    )
-
-                status = job.status
+                new_status = job.status
+                if status != new_status or (time.time() - t) > 30:
+                    status = new_status
                     echo(
-                        "Task is starting (
+                        "Task is starting (%s)..." % status,
                        "stderr",
                        job_id=job.id,
                    )
                    t = time.time()
                time.sleep(update_delay(time.time() - start_time))
 
-
+        _make_prefix = lambda: b"[%s] " % util.to_bytes(self._job.id)
 
         stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE)
         stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE)
 
-        child_jobs = []
         # 1) Loop until the job has started
-        wait_for_launch(self._job
+        wait_for_launch(self._job)
 
         # 2) Tail logs until the job has finished
         tail_logs(
-            prefix=
+            prefix=_make_prefix(),
             stdout_tail=stdout_tail,
             stderr_tail=stderr_tail,
             echo=echo,

@@ -439,7 +470,6 @@ class Kubernetes(object):
             # exists prior to calling S3Tail and note the user about
             # truncated logs if it doesn't.
             # TODO : For hard crashes, we can fetch logs from the pod.
-
         if self._job.has_failed:
             exit_code, reason = self._job.reason
             msg = next(
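launch_job now branches: a positive num_parallel builds a single job spec and hands it to a JobSet (the multi-node path backed by the new kubernetes_jobsets.py), while everything else keeps creating a single Job as before. A hedged sketch of a flow that would exercise the JobSet path; the step names, resources, and node count are illustrative:

    from metaflow import FlowSpec, step, kubernetes, parallel, current

    class MultiNodeFlow(FlowSpec):
        @step
        def start(self):
            # num_parallel here is what ultimately reaches kwargs["num_parallel"].
            self.next(self.train, num_parallel=4)

        @kubernetes(cpu=2, memory=4096)
        @parallel
        @step
        def train(self):
            print("node", current.parallel.node_index, "of", current.parallel.num_nodes)
            self.next(self.join)

        @step
        def join(self, inputs):
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        MultiNodeFlow()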
metaflow/plugins/kubernetes/kubernetes_cli.py
CHANGED
@@ -7,11 +7,17 @@ from metaflow import JSONTypeClass, util
 from metaflow._vendor import click
 from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
 from metaflow.mflog import TASK_LOG_SOURCE
 import metaflow.tracing as tracing
 
-from .kubernetes import
+from .kubernetes import (
+    Kubernetes,
+    KubernetesKilledException,
+    parse_kube_keyvalue_list,
+    KubernetesException,
+)
 from .kubernetes_decorator import KubernetesDecorator
 
 

@@ -107,27 +113,17 @@ def kubernetes():
     type=JSONTypeClass(),
     multiple=False,
 )
+@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
+@click.option("--port", default=None, help="Port number to expose from the container")
 @click.option(
-    "--
-    default=None,
-    type=JSONTypeClass(),
-    multiple=False,
-)
-@click.option(
-    "--annotations",
-    default=None,
-    type=JSONTypeClass(),
-    multiple=False,
+    "--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK])
 )
-@click.option("--ubf-context", default=None, type=click.Choice([None, "ubf_control"]))
 @click.option(
     "--num-parallel",
-    default=
+    default=None,
     type=int,
     help="Number of parallel nodes to run as a multi-node job.",
 )
-@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
-@click.option("--port", default=None, help="Port number to expose from the container")
 @click.pass_context
 def step(
     ctx,

@@ -153,11 +149,9 @@ def step(
     run_time_limit=None,
     persistent_volume_claims=None,
     tolerations=None,
-    labels=None,
-    annotations=None,
-    num_parallel=None,
     shared_memory=None,
     port=None,
+    num_parallel=None,
     **kwargs
 ):
     def echo(msg, stream="stderr", job_id=None, **kwargs):

@@ -189,6 +183,12 @@ def step(
         kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
         env.update(split_vars)
 
+    if num_parallel is not None and num_parallel <= 1:
+        raise KubernetesException(
+            "Using @parallel with `num_parallel` <= 1 is not supported with Kubernetes. "
+            "Please set the value of `num_parallel` to be greater than 1."
+        )
+
     # Set retry policy.
     retry_count = int(kwargs.get("retry_count", 0))
     retry_deco = [deco for deco in node.decorators if deco.name == "retry"]

@@ -203,17 +203,11 @@ def step(
         )
         time.sleep(minutes_between_retries * 60)
 
-    step_args = " ".join(util.dict_to_cli_options(kwargs))
-    num_parallel = num_parallel or 0
-    if num_parallel and num_parallel > 1:
-        # For multinode, we need to add a placeholder that can be mutated by the caller
-        step_args += " [multinode-args]"
-
     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
         entrypoint="%s -u %s" % (executable, os.path.basename(sys.argv[0])),
         top_args=" ".join(util.dict_to_cli_options(ctx.parent.parent.params)),
         step=step_name,
-        step_args=
+        step_args=" ".join(util.dict_to_cli_options(kwargs)),
     )
 
     # Set log tailing.

@@ -239,10 +233,6 @@ def step(
         ),
     )
 
-    attrs = {
-        "metaflow.task_id": kwargs["task_id"],
-        "requires_passwordless_ssh": any([getattr(deco, "requires_passwordless_ssh", False) for deco in node.decorators]),
-    }
     try:
         kubernetes = Kubernetes(
             datastore=ctx.obj.flow_datastore,

@@ -281,12 +271,9 @@ def step(
             env=env,
             persistent_volume_claims=persistent_volume_claims,
             tolerations=tolerations,
-            labels=labels,
-            annotations=annotations,
-            num_parallel=num_parallel,
             shared_memory=shared_memory,
             port=port,
-
+            num_parallel=num_parallel,
         )
     except Exception as e:
         traceback.print_exc(chain=False)
metaflow/plugins/kubernetes/kubernetes_client.py
CHANGED
@@ -4,7 +4,7 @@ import time
 
 from metaflow.exception import MetaflowException
 
-from .kubernetes_job import KubernetesJob
+from .kubernetes_job import KubernetesJob, KubernetesJobSet
 
 
 CLIENT_REFRESH_INTERVAL_SECONDS = 300

@@ -61,5 +61,8 @@ class KubernetesClient(object):
 
         return self._client
 
+    def jobset(self, **kwargs):
+        return KubernetesJobSet(self, **kwargs)
+
     def job(self, **kwargs):
         return KubernetesJob(self, **kwargs)