metaflow-2.12.8-py2.py3-none-any.whl → metaflow-2.12.10-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metaflow/__init__.py +2 -0
- metaflow/cli.py +12 -4
- metaflow/extension_support/plugins.py +1 -0
- metaflow/flowspec.py +8 -1
- metaflow/lint.py +13 -0
- metaflow/metaflow_current.py +0 -8
- metaflow/plugins/__init__.py +12 -0
- metaflow/plugins/argo/argo_workflows.py +616 -46
- metaflow/plugins/argo/argo_workflows_cli.py +70 -3
- metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
- metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
- metaflow/plugins/argo/daemon.py +59 -0
- metaflow/plugins/argo/jobset_input_paths.py +16 -0
- metaflow/plugins/aws/batch/batch_decorator.py +16 -13
- metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
- metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
- metaflow/plugins/cards/card_cli.py +1 -1
- metaflow/plugins/kubernetes/kubernetes.py +279 -52
- metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
- metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
- metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
- metaflow/plugins/kubernetes/kubernetes_job.py +7 -6
- metaflow/plugins/kubernetes/kubernetes_jobsets.py +511 -272
- metaflow/plugins/parallel_decorator.py +108 -8
- metaflow/plugins/secrets/secrets_decorator.py +12 -3
- metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
- metaflow/runner/deployer.py +386 -0
- metaflow/runner/metaflow_runner.py +1 -20
- metaflow/runner/nbdeploy.py +130 -0
- metaflow/runner/nbrun.py +4 -28
- metaflow/runner/utils.py +49 -0
- metaflow/runtime.py +246 -134
- metaflow/version.py +1 -1
- {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/METADATA +2 -2
- {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/RECORD +39 -32
- {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/WHEEL +1 -1
- {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/LICENSE +0 -0
- {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/entry_points.txt +0 -0
- {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/top_level.txt +0 -0
metaflow/plugins/kubernetes/kubernetes.py

@@ -1,9 +1,9 @@
+import copy
 import json
 import math
 import os
 import re
 import shlex
-import copy
 import time
 from typing import Dict, List, Optional
 from uuid import uuid4
@@ -14,10 +14,11 @@ from metaflow.metaflow_config import (
     ARGO_EVENTS_EVENT,
     ARGO_EVENTS_EVENT_BUS,
     ARGO_EVENTS_EVENT_SOURCE,
-    ARGO_EVENTS_SERVICE_ACCOUNT,
     ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
-
+    ARGO_EVENTS_SERVICE_ACCOUNT,
     ARGO_EVENTS_WEBHOOK_AUTH,
+    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    AZURE_KEY_VAULT_PREFIX,
     AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
     CARD_AZUREROOT,
     CARD_GSROOT,
@@ -31,18 +32,18 @@ from metaflow.metaflow_config import (
     DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
     GCP_SECRET_MANAGER_PREFIX,
-    AZURE_KEY_VAULT_PREFIX,
     KUBERNETES_FETCH_EC2_METADATA,
     KUBERNETES_LABELS,
     KUBERNETES_SANDBOX_INIT_SCRIPT,
+    OTEL_ENDPOINT,
     S3_ENDPOINT_URL,
+    S3_SERVER_SIDE_ENCRYPTION,
     SERVICE_HEADERS,
+    KUBERNETES_SECRETS,
     SERVICE_INTERNAL_URL,
-    S3_SERVER_SIDE_ENCRYPTION,
-    OTEL_ENDPOINT,
 )
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.metaflow_config_funcs import config_values
-
 from metaflow.mflog import (
     BASH_SAVE_LOGS,
     bash_capture_logs,
@@ -60,6 +61,10 @@ STDERR_FILE = "mflog_stderr"
 STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
 STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)

+METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE = (
+    "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}"
+)
+

 class KubernetesException(MetaflowException):
     headline = "Kubernetes error"
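A note on the constant added above: its value is the literal string "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}", used as a placeholder token. The CLI (see the kubernetes_cli.py hunks below) appends it to `step_cli` for parallel steps, and `create_jobset` later replaces it with per-replica options. A minimal sketch of that round trip; the command string is illustrative, not from the diff:

```python
# Placeholder value exactly as defined above.
METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE = (
    "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}"
)

# kubernetes_cli.py appends the token for parallel steps:
step_cli = "python flow.py step train"
step_cli = "%s {METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}" % step_cli

# create_jobset later swaps the token for per-replica options:
index, task_id = "0", "42"
command = step_cli.replace(
    METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE,
    "--ubf-context $UBF_CONTEXT --split-index %s --task-id %s" % (index, task_id),
)
print(command)
# python flow.py step train --ubf-context $UBF_CONTEXT --split-index 0 --task-id 42
```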
@@ -69,12 +74,6 @@ class KubernetesKilledException(MetaflowException):
     headline = "Kubernetes Batch job killed"


-def _extract_labels_and_annotations_from_job_spec(job_spec):
-    annotations = job_spec.template.metadata.annotations
-    labels = job_spec.template.metadata.labels
-    return copy.copy(annotations), copy.copy(labels)
-
-
 class Kubernetes(object):
     def __init__(
         self,
@@ -154,57 +153,287 @@ class Kubernetes(object):
             and kwargs["num_parallel"]
             and int(kwargs["num_parallel"]) > 0
         ):
-
-            spec = job.create_job_spec()
-            # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
-            # This will be modified by the KubernetesJobSet object
-            annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
-            self._job = self.create_jobset(
-                job_spec=spec,
-                run_id=kwargs["run_id"],
-                step_name=kwargs["step_name"],
-                task_id=kwargs["task_id"],
-                namespace=kwargs["namespace"],
-                env=kwargs["env"],
-                num_parallel=kwargs["num_parallel"],
-                port=kwargs["port"],
-                annotations=annotations,
-                labels=labels,
-            ).execute()
+            self._job = self.create_jobset(**kwargs).execute()
         else:
+            kwargs.pop("num_parallel", None)
             kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
             self._job = self.create_job_object(**kwargs).create().execute()

     def create_jobset(
         self,
-
-        run_id
-        step_name
-        task_id
+        flow_name,
+        run_id,
+        step_name,
+        task_id,
+        attempt,
+        user,
+        code_package_sha,
+        code_package_url,
+        code_package_ds,
+        docker_image,
+        docker_image_pull_policy,
+        step_cli=None,
+        service_account=None,
+        secrets=None,
+        node_selector=None,
         namespace=None,
+        cpu=None,
+        gpu=None,
+        gpu_vendor=None,
+        disk=None,
+        memory=None,
+        use_tmpfs=None,
+        tmpfs_tempdir=None,
+        tmpfs_size=None,
+        tmpfs_path=None,
+        run_time_limit=None,
         env=None,
-
-
-        annotations=None,
+        persistent_volume_claims=None,
+        tolerations=None,
         labels=None,
+        shared_memory=None,
+        port=None,
+        num_parallel=None,
     ):
-
-
+        name = "js-%s" % str(uuid4())[:6]
+        jobset = (
+            KubernetesClient()
+            .jobset(
+                name=name,
+                namespace=namespace,
+                service_account=service_account,
+                node_selector=node_selector,
+                image=docker_image,
+                image_pull_policy=docker_image_pull_policy,
+                cpu=cpu,
+                memory=memory,
+                disk=disk,
+                gpu=gpu,
+                gpu_vendor=gpu_vendor,
+                timeout_in_seconds=run_time_limit,
+                # Retries are handled by Metaflow runtime
+                retries=0,
+                step_name=step_name,
+                # We set the jobset name as the subdomain.
+                # todo: [final-refactor] ask @shri what was the motive when we did initial implementation
+                subdomain=name,
+                tolerations=tolerations,
+                use_tmpfs=use_tmpfs,
+                tmpfs_tempdir=tmpfs_tempdir,
+                tmpfs_size=tmpfs_size,
+                tmpfs_path=tmpfs_path,
+                persistent_volume_claims=persistent_volume_claims,
+                shared_memory=shared_memory,
+                port=port,
+                num_parallel=num_parallel,
+            )
+            .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
+            .environment_variable("METAFLOW_CODE_URL", code_package_url)
+            .environment_variable("METAFLOW_CODE_DS", code_package_ds)
+            .environment_variable("METAFLOW_USER", user)
+            .environment_variable("METAFLOW_SERVICE_URL", SERVICE_INTERNAL_URL)
+            .environment_variable(
+                "METAFLOW_SERVICE_HEADERS",
+                json.dumps(SERVICE_HEADERS),
+            )
+            .environment_variable("METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3)
+            .environment_variable("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT)
+            .environment_variable("METAFLOW_DEFAULT_DATASTORE", self._datastore.TYPE)
+            .environment_variable("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA)
+            .environment_variable("METAFLOW_KUBERNETES_WORKLOAD", 1)
+            .environment_variable(
+                "METAFLOW_KUBERNETES_FETCH_EC2_METADATA", KUBERNETES_FETCH_EC2_METADATA
+            )
+            .environment_variable("METAFLOW_RUNTIME_ENVIRONMENT", "kubernetes")
+            .environment_variable(
+                "METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE", DEFAULT_SECRETS_BACKEND_TYPE
+            )
+            .environment_variable("METAFLOW_CARD_S3ROOT", CARD_S3ROOT)
+            .environment_variable(
+                "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER", DEFAULT_AWS_CLIENT_PROVIDER
+            )
+            .environment_variable(
+                "METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER", DEFAULT_GCP_CLIENT_PROVIDER
+            )
+            .environment_variable(
+                "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
+                AWS_SECRETS_MANAGER_DEFAULT_REGION,
+            )
+            .environment_variable(
+                "METAFLOW_GCP_SECRET_MANAGER_PREFIX", GCP_SECRET_MANAGER_PREFIX
+            )
+            .environment_variable(
+                "METAFLOW_AZURE_KEY_VAULT_PREFIX", AZURE_KEY_VAULT_PREFIX
+            )
+            .environment_variable("METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL)
+            .environment_variable(
+                "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT",
+                AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
+            )
+            .environment_variable(
+                "METAFLOW_DATASTORE_SYSROOT_AZURE", DATASTORE_SYSROOT_AZURE
+            )
+            .environment_variable("METAFLOW_CARD_AZUREROOT", CARD_AZUREROOT)
+            .environment_variable("METAFLOW_DATASTORE_SYSROOT_GS", DATASTORE_SYSROOT_GS)
+            .environment_variable("METAFLOW_CARD_GSROOT", CARD_GSROOT)
+            # support Metaflow sandboxes
+            .environment_variable(
+                "METAFLOW_INIT_SCRIPT", KUBERNETES_SANDBOX_INIT_SCRIPT
+            )
+            .environment_variable("METAFLOW_OTEL_ENDPOINT", OTEL_ENDPOINT)
+            # Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync
+            # between the local user instance and the remote Kubernetes pod
+            # assumes metadata is stored in DATASTORE_LOCAL_DIR on the Kubernetes
+            # pod; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL is NOT set (
+            # see get_datastore_root_from_config in datastore/local.py).
+        )

-
-
-
+        _labels = self._get_labels(labels)
+        for k, v in _labels.items():
+            jobset.label(k, v)
+
+        for k in list(
+            [] if not secrets else [secrets] if isinstance(secrets, str) else secrets
+        ) + KUBERNETES_SECRETS.split(","):
+            jobset.secret(k)
+
+        jobset.environment_variables_from_selectors(
+            {
+                "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+                "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+                "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+                "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+                "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+            }
+        )
+
+        # Temporary passing of *some* environment variables. Do not rely on this
+        # mechanism as it will be removed in the near future
+        for k, v in config_values():
+            if k.startswith("METAFLOW_CONDA_") or k.startswith("METAFLOW_DEBUG_"):
+                jobset.environment_variable(k, v)
+
+        if S3_SERVER_SIDE_ENCRYPTION is not None:
+            jobset.environment_variable(
+                "METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
+            )
+
+        # Set environment variables to support metaflow.integrations.ArgoEvent
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_WEBHOOK_URL", ARGO_EVENTS_INTERNAL_WEBHOOK_URL
+        )
+        jobset.environment_variable("METAFLOW_ARGO_EVENTS_EVENT", ARGO_EVENTS_EVENT)
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_EVENT_BUS", ARGO_EVENTS_EVENT_BUS
+        )
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_EVENT_SOURCE", ARGO_EVENTS_EVENT_SOURCE
+        )
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_SERVICE_ACCOUNT", ARGO_EVENTS_SERVICE_ACCOUNT
+        )
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_WEBHOOK_AUTH",
+            ARGO_EVENTS_WEBHOOK_AUTH,
+        )
+
+        ## -----Jobset specific env vars START here-----
+        jobset.environment_variable("MF_MASTER_ADDR", jobset.jobset_control_addr)
+        jobset.environment_variable("MF_MASTER_PORT", str(port))
+        jobset.environment_variable("MF_WORLD_SIZE", str(num_parallel))
+        jobset.environment_variable_from_selector(
+            "JOBSET_RESTART_ATTEMPT",
+            "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
+        )
+        jobset.environment_variable_from_selector(
+            "METAFLOW_KUBERNETES_JOBSET_NAME",
+            "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
+        )
+        jobset.environment_variable_from_selector(
+            "MF_WORKER_REPLICA_INDEX",
+            "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+        )
+        ## -----Jobset specific env vars END here-----
+
+        tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
+        if tmpfs_enabled and tmpfs_tempdir:
+            jobset.environment_variable("METAFLOW_TEMPDIR", tmpfs_path)
+
+        for name, value in env.items():
+            jobset.environment_variable(name, value)
+
+        annotations = {
+            "metaflow/user": user,
+            "metaflow/flow_name": flow_name,
+            "metaflow/control-task-id": task_id,
+        }
+        if current.get("project_name"):
+            annotations.update(
+                {
+                    "metaflow/project_name": current.project_name,
+                    "metaflow/branch_name": current.branch_name,
+                    "metaflow/project_flow_name": current.project_flow_name,
+                }
+            )
+
+        for name, value in annotations.items():
+            jobset.annotation(name, value)
+
+        (
+            jobset.annotation("metaflow/run_id", run_id)
+            .annotation("metaflow/step_name", step_name)
+            .annotation("metaflow/attempt", attempt)
+            .label("app.kubernetes.io/name", "metaflow-task")
+            .label("app.kubernetes.io/part-of", "metaflow")
+        )
+
+        ## ----------- control/worker specific values START here -----------
+        # We will now set the appropriate command for the control/worker job
+        _get_command = lambda index, _tskid: self._command(
+            flow_name=flow_name,
             run_id=run_id,
-            task_id=task_id,
             step_name=step_name,
-
-
-
-
-
-
+            task_id=_tskid,
+            attempt=attempt,
+            code_package_url=code_package_url,
+            step_cmds=[
+                step_cli.replace(
+                    METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE,
+                    "--ubf-context $UBF_CONTEXT --split-index %s --task-id %s"
+                    % (index, _tskid),
+                )
+            ],
+        )
+        jobset.control.replicas(1)
+        jobset.worker.replicas(num_parallel - 1)
+
+        # We set the appropriate command for the control/worker job
+        # and also set the task-id/spit-index for the control/worker job
+        # appropirately.
+        jobset.control.command(_get_command("0", str(task_id)))
+        jobset.worker.command(
+            _get_command(
+                "`expr $[MF_WORKER_REPLICA_INDEX] + 1`",
+                "-".join(
+                    [
+                        str(task_id),
+                        "worker",
+                        "$MF_WORKER_REPLICA_INDEX",
+                    ]
+                ),
+            )
         )
-
+
+        jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
+        jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
+        # Every control job requires an environment variable of MF_CONTROL_INDEX
+        # set to 0 so that we can derive the MF_PARALLEL_NODE_INDEX correctly.
+        # Since only the control job has MF_CONTROL_INDE set to 0, all worker nodes
+        # will use MF_WORKER_REPLICA_INDEX
+        jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
+        ## ----------- control/worker specific values END here -----------
+
+        return jobset

     def create_job_object(
         self,
@@ -241,7 +470,6 @@ class Kubernetes(object):
         shared_memory=None,
         port=None,
         name_pattern=None,
-        num_parallel=None,
     ):
         if env is None:
             env = {}
@@ -282,7 +510,6 @@ class Kubernetes(object):
             persistent_volume_claims=persistent_volume_claims,
             shared_memory=shared_memory,
             port=port,
-            num_parallel=num_parallel,
         )
         .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
         .environment_variable("METAFLOW_CODE_URL", code_package_url)
metaflow/plugins/kubernetes/kubernetes_cli.py

@@ -3,20 +3,20 @@ import sys
 import time
 import traceback

+import metaflow.tracing as tracing
 from metaflow import JSONTypeClass, util
 from metaflow._vendor import click
 from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
-from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
 from metaflow.mflog import TASK_LOG_SOURCE
-
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK

 from .kubernetes import (
     Kubernetes,
+    KubernetesException,
     KubernetesKilledException,
     parse_kube_keyvalue_list,
-    KubernetesException,
 )
 from .kubernetes_decorator import KubernetesDecorator

@@ -185,8 +185,8 @@ def step(

     if num_parallel is not None and num_parallel <= 1:
         raise KubernetesException(
-            "Using @parallel with `num_parallel` <= 1 is not supported with"
-            "Please set the value of `num_parallel` to be greater than 1."
+            "Using @parallel with `num_parallel` <= 1 is not supported with "
+            "@kubernetes. Please set the value of `num_parallel` to be greater than 1."
         )

     # Set retry policy.
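For context on the guard above: @parallel steps are launched by passing num_parallel to self.next, and on @kubernetes that value must exceed 1. A minimal flow sketch; the flow name, step names, and resource values are illustrative, not from the diff:

```python
from metaflow import FlowSpec, kubernetes, parallel, step


class MultiNodeFlow(FlowSpec):
    @step
    def start(self):
        # num_parallel must be > 1 when the step runs on @kubernetes
        self.next(self.train, num_parallel=4)

    @kubernetes(cpu=2, memory=8192)
    @parallel
    @step
    def train(self):
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    MultiNodeFlow()
```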
@@ -203,19 +203,37 @@ def step(
         )
         time.sleep(minutes_between_retries * 60)

+    # Explicitly Remove `ubf_context` from `kwargs` so that it's not passed as a commandline option
+    # If an underlying step command is executing a vanilla Kubernetes job, then it should never need
+    # to know about the UBF context.
+    # If it is a jobset which is executing a multi-node job, then the UBF context is set based on the
+    # `ubf_context` parameter passed to the jobset.
+    kwargs.pop("ubf_context", None)
+    # `task_id` is also need to be removed from `kwargs` as it needs to be dynamically
+    # set in the downstream code IF num_parallel is > 1
+    task_id = kwargs["task_id"]
+    if num_parallel:
+        kwargs.pop("task_id")
+
     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
         entrypoint="%s -u %s" % (executable, os.path.basename(sys.argv[0])),
         top_args=" ".join(util.dict_to_cli_options(ctx.parent.parent.params)),
         step=step_name,
         step_args=" ".join(util.dict_to_cli_options(kwargs)),
     )
+    # Since it is a parallel step there are some parts of the step_cli that need to be modified
+    # based on the type of worker in the JobSet. This is why we will create a placeholder string
+    # in the template which will be replaced based on the type of worker.
+
+    if num_parallel:
+        step_cli = "%s {METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}" % step_cli

     # Set log tailing.
     ds = ctx.obj.flow_datastore.get_task_datastore(
         mode="w",
         run_id=kwargs["run_id"],
         step_name=step_name,
-        task_id=kwargs["task_id"],
+        task_id=task_id,
         attempt=int(retry_count),
     )
     stdout_location = ds.get_log_location(TASK_LOG_SOURCE, "stdout")
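Why the pops matter: every key left in kwargs becomes a CLI flag via util.dict_to_cli_options, so ubf_context and (for parallel steps) task_id must be removed before the command is rendered and re-injected per replica through the template. A simplified illustration; the stand-in below only mimics the real helper for plain string values:

```python
# Simplified stand-in for util.dict_to_cli_options; the real helper also
# handles booleans, lists, and quoting.
def dict_to_cli_options(params):
    for k, v in params.items():
        if v is not None:
            yield "--%s %s" % (k.replace("_", "-"), v)


kwargs = {"run_id": "7", "task_id": "42", "ubf_context": "ubf_control"}
kwargs.pop("ubf_context", None)  # never a CLI option for plain Kubernetes jobs
task_id = kwargs.pop("task_id")  # re-added per replica through the template

print(" ".join(dict_to_cli_options(kwargs)))  # --run-id 7
```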
@@ -229,7 +247,7 @@ def step(
             sync_local_metadata_from_datastore(
                 DATASTORE_LOCAL_DIR,
                 ctx.obj.flow_datastore.get_task_datastore(
-                    kwargs["run_id"], step_name, kwargs["task_id"]
+                    kwargs["run_id"], step_name, task_id
                 ),
             )

@@ -245,7 +263,7 @@ def step(
             flow_name=ctx.obj.flow.name,
             run_id=kwargs["run_id"],
             step_name=step_name,
-            task_id=kwargs["task_id"],
+            task_id=task_id,
             attempt=str(retry_count),
             user=util.get_username(),
             code_package_sha=code_package_sha,
metaflow/plugins/kubernetes/kubernetes_decorator.py

@@ -12,28 +12,27 @@ from metaflow.metaflow_config import (
     DATASTORE_LOCAL_DIR,
     KUBERNETES_CONTAINER_IMAGE,
     KUBERNETES_CONTAINER_REGISTRY,
+    KUBERNETES_CPU,
+    KUBERNETES_DISK,
     KUBERNETES_FETCH_EC2_METADATA,
-    KUBERNETES_IMAGE_PULL_POLICY,
     KUBERNETES_GPU_VENDOR,
+    KUBERNETES_IMAGE_PULL_POLICY,
+    KUBERNETES_MEMORY,
     KUBERNETES_NAMESPACE,
     KUBERNETES_NODE_SELECTOR,
     KUBERNETES_PERSISTENT_VOLUME_CLAIMS,
-    KUBERNETES_TOLERATIONS,
+    KUBERNETES_PORT,
     KUBERNETES_SERVICE_ACCOUNT,
     KUBERNETES_SHARED_MEMORY,
-    KUBERNETES_PORT,
-    KUBERNETES_CPU,
-    KUBERNETES_MEMORY,
-    KUBERNETES_DISK,
+    KUBERNETES_TOLERATIONS,
 )
 from metaflow.plugins.resources_decorator import ResourcesDecorator
 from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.sidecar import Sidecar
+from metaflow.unbounded_foreach import UBF_CONTROL

 from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
 from .kubernetes import KubernetesException, parse_kube_keyvalue_list
-from metaflow.unbounded_foreach import UBF_CONTROL
-from .kubernetes_jobsets import TaskIdConstructor

 try:
     unicode
@@ -416,8 +415,8 @@ class KubernetesDecorator(StepDecorator):
         # check for the existence of METAFLOW_KUBERNETES_WORKLOAD environment
         # variable.

+        meta = {}
         if "METAFLOW_KUBERNETES_WORKLOAD" in os.environ:
-            meta = {}
             meta["kubernetes-pod-name"] = os.environ["METAFLOW_KUBERNETES_POD_NAME"]
             meta["kubernetes-pod-namespace"] = os.environ[
                 "METAFLOW_KUBERNETES_POD_NAMESPACE"
@@ -427,15 +426,15 @@
                 "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME"
             ]
             meta["kubernetes-node-ip"] = os.environ["METAFLOW_KUBERNETES_NODE_IP"]
-
-
-
-
+
+            meta["kubernetes-jobset-name"] = os.environ.get(
+                "METAFLOW_KUBERNETES_JOBSET_NAME"
+            )

         # TODO (savin): Introduce equivalent support for Microsoft Azure and
         # Google Cloud Platform
-        # TODO: Introduce a way to detect Cloud Provider, so unnecessary requests
-        # can be avoided by not having to try out all providers.
+        # TODO: Introduce a way to detect Cloud Provider, so unnecessary requests
+        # (and delays) can be avoided by not having to try out all providers.
         if KUBERNETES_FETCH_EC2_METADATA:
             instance_meta = get_ec2_instance_metadata()
             meta.update(instance_meta)
@@ -451,14 +450,6 @@
             # "METAFLOW_KUBERNETES_POD_NAME"
             # ].rpartition("-")[0]

-        entries = [
-            MetaDatum(field=k, value=v, type=k, tags=[])
-            for k, v in meta.items()
-            if v is not None
-        ]
-        # Register book-keeping metadata for debugging.
-        metadata.register_metadata(run_id, step_name, task_id, entries)
-
         # Start MFLog sidecar to collect task logs.
         self._save_logs_sidecar = Sidecar("save_logs_periodically")
         self._save_logs_sidecar.start()
@@ -467,19 +458,34 @@
         if hasattr(flow, "_parallel_ubf_iter"):
             num_parallel = flow._parallel_ubf_iter.num_parallel

-            if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
-                control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(
-                    num_parallel
-                )
-                mapper_task_ids = [control_task_id] + worker_task_ids
-                flow._control_mapper_tasks = [
-                    "%s/%s/%s" % (run_id, step_name, mapper_task_id)
-                    for mapper_task_id in mapper_task_ids
-                ]
-                flow._control_task_is_mapper_zero = True
-
         if num_parallel and num_parallel > 1:
             _setup_multinode_environment()
+            # current.parallel.node_index will be correctly available over here.
+            meta.update({"parallel-node-index": current.parallel.node_index})
+            if ubf_context == UBF_CONTROL:
+                flow._control_mapper_tasks = [
+                    "{}/{}/{}".format(run_id, step_name, task_id)
+                    for task_id in [task_id]
+                    + [
+                        "%s-worker-%d" % (task_id, idx)
+                        for idx in range(num_parallel - 1)
+                    ]
+                ]
+                flow._control_task_is_mapper_zero = True
+
+        if len(meta) > 0:
+            entries = [
+                MetaDatum(
+                    field=k,
+                    value=v,
+                    type=k,
+                    tags=["attempt_id:{0}".format(retry_count)],
+                )
+                for k, v in meta.items()
+                if v is not None
+            ]
+            # Register book-keeping metadata for debugging.
+            metadata.register_metadata(run_id, step_name, task_id, entries)

     def task_finished(
         self, step_name, flow, graph, is_task_ok, retry_count, max_retries
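The registration block above now runs once for both plain and multi-node tasks, and tags every datum with the attempt number. A sketch of the filtering and tagging it performs, with plain dicts standing in for MetaDatum and illustrative values:

```python
# Illustrative `meta` contents; None values are dropped by the filter.
meta = {
    "kubernetes-pod-name": "train-abc12",
    "kubernetes-jobset-name": None,  # unset outside JobSet runs
    "parallel-node-index": "3",
}
retry_count = 1

entries = [
    {"field": k, "value": v, "type": k, "tags": ["attempt_id:{0}".format(retry_count)]}
    for k, v in meta.items()
    if v is not None
]
print(len(entries))  # 2 -- the None-valued datum is skipped
```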
@@ -516,18 +522,24 @@
         )[0]


+# TODO: Unify this method with the multi-node setup in @batch
 def _setup_multinode_environment():
+    # FIXME: what about MF_MASTER_PORT
     import socket

-
-
-
-        os.environ["MF_PARALLEL_NODE_INDEX"] = str(0)
-    elif os.environ.get("WORKER_REPLICA_INDEX") is not None:
-        os.environ["MF_PARALLEL_NODE_INDEX"] = str(
-            int(os.environ["WORKER_REPLICA_INDEX"]) + 1
+    try:
+        os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(
+            os.environ["MF_MASTER_ADDR"]
         )
-
-
-
+        os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["MF_WORLD_SIZE"]
+        os.environ["MF_PARALLEL_NODE_INDEX"] = (
+            str(0)
+            if "MF_CONTROL_INDEX" in os.environ
+            else str(int(os.environ["MF_WORKER_REPLICA_INDEX"]) + 1)
         )
+    except KeyError as e:
+        raise MetaflowException("Environment variable {} is missing.".format(e))
+    except socket.gaierror:
+        raise MetaflowException("Failed to get host by name for MF_MASTER_ADDR.")
+    except ValueError:
+        raise MetaflowException("Invalid value for MF_WORKER_REPLICA_INDEX.")
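The rewritten helper derives the node index from the JobSet-injected variables: the control pod carries MF_CONTROL_INDEX=0 and becomes node 0, while worker replica k becomes node k + 1. A sketch of that rule in isolation, with illustrative environment values:

```python
import os

# Values as the JobSet would inject them into worker replica 2 of 4 nodes.
os.environ.update({"MF_WORLD_SIZE": "4", "MF_WORKER_REPLICA_INDEX": "2"})

node_index = (
    str(0)
    if "MF_CONTROL_INDEX" in os.environ
    else str(int(os.environ["MF_WORKER_REPLICA_INDEX"]) + 1)
)
print(node_index)  # "3": worker replica 2 is node 3; the control pod is node 0
```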
|