metaflow 2.12.8__py2.py3-none-any.whl → 2.12.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. metaflow/__init__.py +2 -0
  2. metaflow/cli.py +12 -4
  3. metaflow/extension_support/plugins.py +1 -0
  4. metaflow/flowspec.py +8 -1
  5. metaflow/lint.py +13 -0
  6. metaflow/metaflow_current.py +0 -8
  7. metaflow/plugins/__init__.py +12 -0
  8. metaflow/plugins/argo/argo_workflows.py +616 -46
  9. metaflow/plugins/argo/argo_workflows_cli.py +70 -3
  10. metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
  11. metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
  12. metaflow/plugins/argo/daemon.py +59 -0
  13. metaflow/plugins/argo/jobset_input_paths.py +16 -0
  14. metaflow/plugins/aws/batch/batch_decorator.py +16 -13
  15. metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
  16. metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
  17. metaflow/plugins/cards/card_cli.py +1 -1
  18. metaflow/plugins/kubernetes/kubernetes.py +279 -52
  19. metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
  20. metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
  21. metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
  22. metaflow/plugins/kubernetes/kubernetes_job.py +7 -6
  23. metaflow/plugins/kubernetes/kubernetes_jobsets.py +511 -272
  24. metaflow/plugins/parallel_decorator.py +108 -8
  25. metaflow/plugins/secrets/secrets_decorator.py +12 -3
  26. metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
  27. metaflow/runner/deployer.py +386 -0
  28. metaflow/runner/metaflow_runner.py +1 -20
  29. metaflow/runner/nbdeploy.py +130 -0
  30. metaflow/runner/nbrun.py +4 -28
  31. metaflow/runner/utils.py +49 -0
  32. metaflow/runtime.py +246 -134
  33. metaflow/version.py +1 -1
  34. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/METADATA +2 -2
  35. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/RECORD +39 -32
  36. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/WHEEL +1 -1
  37. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/LICENSE +0 -0
  38. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/entry_points.txt +0 -0
  39. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/top_level.txt +0 -0
--- a/metaflow/plugins/kubernetes/kubernetes.py
+++ b/metaflow/plugins/kubernetes/kubernetes.py
@@ -1,9 +1,9 @@
+import copy
 import json
 import math
 import os
 import re
 import shlex
-import copy
 import time
 from typing import Dict, List, Optional
 from uuid import uuid4
@@ -14,10 +14,11 @@ from metaflow.metaflow_config import (
     ARGO_EVENTS_EVENT,
     ARGO_EVENTS_EVENT_BUS,
     ARGO_EVENTS_EVENT_SOURCE,
-    ARGO_EVENTS_SERVICE_ACCOUNT,
     ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
-    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    ARGO_EVENTS_SERVICE_ACCOUNT,
     ARGO_EVENTS_WEBHOOK_AUTH,
+    AWS_SECRETS_MANAGER_DEFAULT_REGION,
+    AZURE_KEY_VAULT_PREFIX,
     AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
     CARD_AZUREROOT,
     CARD_GSROOT,
@@ -31,18 +32,18 @@ from metaflow.metaflow_config import (
     DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
     GCP_SECRET_MANAGER_PREFIX,
-    AZURE_KEY_VAULT_PREFIX,
     KUBERNETES_FETCH_EC2_METADATA,
     KUBERNETES_LABELS,
     KUBERNETES_SANDBOX_INIT_SCRIPT,
+    OTEL_ENDPOINT,
     S3_ENDPOINT_URL,
+    S3_SERVER_SIDE_ENCRYPTION,
     SERVICE_HEADERS,
+    KUBERNETES_SECRETS,
     SERVICE_INTERNAL_URL,
-    S3_SERVER_SIDE_ENCRYPTION,
-    OTEL_ENDPOINT,
 )
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.metaflow_config_funcs import config_values
-
 from metaflow.mflog import (
     BASH_SAVE_LOGS,
     bash_capture_logs,
@@ -60,6 +61,10 @@ STDERR_FILE = "mflog_stderr"
 STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
 STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
 
+METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE = (
+    "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}"
+)
+
 
 class KubernetesException(MetaflowException):
     headline = "Kubernetes error"
@@ -69,12 +74,6 @@ class KubernetesKilledException(MetaflowException):
     headline = "Kubernetes Batch job killed"
 
 
-def _extract_labels_and_annotations_from_job_spec(job_spec):
-    annotations = job_spec.template.metadata.annotations
-    labels = job_spec.template.metadata.labels
-    return copy.copy(annotations), copy.copy(labels)
-
-
 class Kubernetes(object):
     def __init__(
         self,
@@ -154,57 +153,287 @@ class Kubernetes(object):
             and kwargs["num_parallel"]
             and int(kwargs["num_parallel"]) > 0
         ):
-            job = self.create_job_object(**kwargs)
-            spec = job.create_job_spec()
-            # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
-            # This will be modified by the KubernetesJobSet object
-            annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
-            self._job = self.create_jobset(
-                job_spec=spec,
-                run_id=kwargs["run_id"],
-                step_name=kwargs["step_name"],
-                task_id=kwargs["task_id"],
-                namespace=kwargs["namespace"],
-                env=kwargs["env"],
-                num_parallel=kwargs["num_parallel"],
-                port=kwargs["port"],
-                annotations=annotations,
-                labels=labels,
-            ).execute()
+            self._job = self.create_jobset(**kwargs).execute()
         else:
+            kwargs.pop("num_parallel", None)
             kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
             self._job = self.create_job_object(**kwargs).create().execute()
 
     def create_jobset(
         self,
-        job_spec=None,
-        run_id=None,
-        step_name=None,
-        task_id=None,
+        flow_name,
+        run_id,
+        step_name,
+        task_id,
+        attempt,
+        user,
+        code_package_sha,
+        code_package_url,
+        code_package_ds,
+        docker_image,
+        docker_image_pull_policy,
+        step_cli=None,
+        service_account=None,
+        secrets=None,
+        node_selector=None,
         namespace=None,
+        cpu=None,
+        gpu=None,
+        gpu_vendor=None,
+        disk=None,
+        memory=None,
+        use_tmpfs=None,
+        tmpfs_tempdir=None,
+        tmpfs_size=None,
+        tmpfs_path=None,
+        run_time_limit=None,
         env=None,
-        num_parallel=None,
-        port=None,
-        annotations=None,
+        persistent_volume_claims=None,
+        tolerations=None,
         labels=None,
+        shared_memory=None,
+        port=None,
+        num_parallel=None,
     ):
-        if env is None:
-            env = {}
+        name = "js-%s" % str(uuid4())[:6]
+        jobset = (
+            KubernetesClient()
+            .jobset(
+                name=name,
+                namespace=namespace,
+                service_account=service_account,
+                node_selector=node_selector,
+                image=docker_image,
+                image_pull_policy=docker_image_pull_policy,
+                cpu=cpu,
+                memory=memory,
+                disk=disk,
+                gpu=gpu,
+                gpu_vendor=gpu_vendor,
+                timeout_in_seconds=run_time_limit,
+                # Retries are handled by Metaflow runtime
+                retries=0,
+                step_name=step_name,
+                # We set the jobset name as the subdomain.
+                # todo: [final-refactor] ask @shri what was the motive when we did initial implementation
+                subdomain=name,
+                tolerations=tolerations,
+                use_tmpfs=use_tmpfs,
+                tmpfs_tempdir=tmpfs_tempdir,
+                tmpfs_size=tmpfs_size,
+                tmpfs_path=tmpfs_path,
+                persistent_volume_claims=persistent_volume_claims,
+                shared_memory=shared_memory,
+                port=port,
+                num_parallel=num_parallel,
+            )
+            .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
+            .environment_variable("METAFLOW_CODE_URL", code_package_url)
+            .environment_variable("METAFLOW_CODE_DS", code_package_ds)
+            .environment_variable("METAFLOW_USER", user)
+            .environment_variable("METAFLOW_SERVICE_URL", SERVICE_INTERNAL_URL)
+            .environment_variable(
+                "METAFLOW_SERVICE_HEADERS",
+                json.dumps(SERVICE_HEADERS),
+            )
+            .environment_variable("METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3)
+            .environment_variable("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT)
+            .environment_variable("METAFLOW_DEFAULT_DATASTORE", self._datastore.TYPE)
+            .environment_variable("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA)
+            .environment_variable("METAFLOW_KUBERNETES_WORKLOAD", 1)
+            .environment_variable(
+                "METAFLOW_KUBERNETES_FETCH_EC2_METADATA", KUBERNETES_FETCH_EC2_METADATA
+            )
+            .environment_variable("METAFLOW_RUNTIME_ENVIRONMENT", "kubernetes")
+            .environment_variable(
+                "METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE", DEFAULT_SECRETS_BACKEND_TYPE
+            )
+            .environment_variable("METAFLOW_CARD_S3ROOT", CARD_S3ROOT)
+            .environment_variable(
+                "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER", DEFAULT_AWS_CLIENT_PROVIDER
+            )
+            .environment_variable(
+                "METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER", DEFAULT_GCP_CLIENT_PROVIDER
+            )
+            .environment_variable(
+                "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
+                AWS_SECRETS_MANAGER_DEFAULT_REGION,
+            )
+            .environment_variable(
+                "METAFLOW_GCP_SECRET_MANAGER_PREFIX", GCP_SECRET_MANAGER_PREFIX
+            )
+            .environment_variable(
+                "METAFLOW_AZURE_KEY_VAULT_PREFIX", AZURE_KEY_VAULT_PREFIX
+            )
+            .environment_variable("METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL)
+            .environment_variable(
+                "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT",
+                AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
+            )
+            .environment_variable(
+                "METAFLOW_DATASTORE_SYSROOT_AZURE", DATASTORE_SYSROOT_AZURE
+            )
+            .environment_variable("METAFLOW_CARD_AZUREROOT", CARD_AZUREROOT)
+            .environment_variable("METAFLOW_DATASTORE_SYSROOT_GS", DATASTORE_SYSROOT_GS)
+            .environment_variable("METAFLOW_CARD_GSROOT", CARD_GSROOT)
+            # support Metaflow sandboxes
+            .environment_variable(
+                "METAFLOW_INIT_SCRIPT", KUBERNETES_SANDBOX_INIT_SCRIPT
+            )
+            .environment_variable("METAFLOW_OTEL_ENDPOINT", OTEL_ENDPOINT)
+            # Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync
+            # between the local user instance and the remote Kubernetes pod
+            # assumes metadata is stored in DATASTORE_LOCAL_DIR on the Kubernetes
+            # pod; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL is NOT set (
+            # see get_datastore_root_from_config in datastore/local.py).
+        )
 
-        _prefix = str(uuid4())[:6]
-        js = KubernetesClient().jobset(
-            name="js-%s" % _prefix,
+        _labels = self._get_labels(labels)
+        for k, v in _labels.items():
+            jobset.label(k, v)
+
+        for k in list(
+            [] if not secrets else [secrets] if isinstance(secrets, str) else secrets
+        ) + KUBERNETES_SECRETS.split(","):
+            jobset.secret(k)
+
+        jobset.environment_variables_from_selectors(
+            {
+                "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+                "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+                "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+                "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+                "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+            }
+        )
+
+        # Temporary passing of *some* environment variables. Do not rely on this
+        # mechanism as it will be removed in the near future
+        for k, v in config_values():
+            if k.startswith("METAFLOW_CONDA_") or k.startswith("METAFLOW_DEBUG_"):
+                jobset.environment_variable(k, v)
+
+        if S3_SERVER_SIDE_ENCRYPTION is not None:
+            jobset.environment_variable(
+                "METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
+            )
+
+        # Set environment variables to support metaflow.integrations.ArgoEvent
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_WEBHOOK_URL", ARGO_EVENTS_INTERNAL_WEBHOOK_URL
+        )
+        jobset.environment_variable("METAFLOW_ARGO_EVENTS_EVENT", ARGO_EVENTS_EVENT)
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_EVENT_BUS", ARGO_EVENTS_EVENT_BUS
+        )
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_EVENT_SOURCE", ARGO_EVENTS_EVENT_SOURCE
+        )
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_SERVICE_ACCOUNT", ARGO_EVENTS_SERVICE_ACCOUNT
+        )
+        jobset.environment_variable(
+            "METAFLOW_ARGO_EVENTS_WEBHOOK_AUTH",
+            ARGO_EVENTS_WEBHOOK_AUTH,
+        )
+
+        ## -----Jobset specific env vars START here-----
+        jobset.environment_variable("MF_MASTER_ADDR", jobset.jobset_control_addr)
+        jobset.environment_variable("MF_MASTER_PORT", str(port))
+        jobset.environment_variable("MF_WORLD_SIZE", str(num_parallel))
+        jobset.environment_variable_from_selector(
+            "JOBSET_RESTART_ATTEMPT",
+            "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
+        )
+        jobset.environment_variable_from_selector(
+            "METAFLOW_KUBERNETES_JOBSET_NAME",
+            "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
+        )
+        jobset.environment_variable_from_selector(
+            "MF_WORKER_REPLICA_INDEX",
+            "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+        )
+        ## -----Jobset specific env vars END here-----
+
+        tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
+        if tmpfs_enabled and tmpfs_tempdir:
+            jobset.environment_variable("METAFLOW_TEMPDIR", tmpfs_path)
+
+        for name, value in env.items():
+            jobset.environment_variable(name, value)
+
+        annotations = {
+            "metaflow/user": user,
+            "metaflow/flow_name": flow_name,
+            "metaflow/control-task-id": task_id,
+        }
+        if current.get("project_name"):
+            annotations.update(
+                {
+                    "metaflow/project_name": current.project_name,
+                    "metaflow/branch_name": current.branch_name,
+                    "metaflow/project_flow_name": current.project_flow_name,
+                }
+            )
+
+        for name, value in annotations.items():
+            jobset.annotation(name, value)
+
+        (
+            jobset.annotation("metaflow/run_id", run_id)
+            .annotation("metaflow/step_name", step_name)
+            .annotation("metaflow/attempt", attempt)
+            .label("app.kubernetes.io/name", "metaflow-task")
+            .label("app.kubernetes.io/part-of", "metaflow")
+        )
+
+        ## ----------- control/worker specific values START here -----------
+        # We will now set the appropriate command for the control/worker job
+        _get_command = lambda index, _tskid: self._command(
+            flow_name=flow_name,
             run_id=run_id,
-            task_id=task_id,
             step_name=step_name,
-            namespace=namespace,
-            labels=self._get_labels(labels),
-            annotations=annotations,
-            num_parallel=num_parallel,
-            job_spec=job_spec,
-            port=port,
+            task_id=_tskid,
+            attempt=attempt,
+            code_package_url=code_package_url,
+            step_cmds=[
+                step_cli.replace(
+                    METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE,
+                    "--ubf-context $UBF_CONTEXT --split-index %s --task-id %s"
+                    % (index, _tskid),
+                )
+            ],
+        )
+        jobset.control.replicas(1)
+        jobset.worker.replicas(num_parallel - 1)
+
+        # We set the appropriate command for the control/worker job
+        # and also set the task-id/split-index for the control/worker job
+        # appropriately.
+        jobset.control.command(_get_command("0", str(task_id)))
+        jobset.worker.command(
+            _get_command(
+                "`expr $[MF_WORKER_REPLICA_INDEX] + 1`",
+                "-".join(
+                    [
+                        str(task_id),
+                        "worker",
+                        "$MF_WORKER_REPLICA_INDEX",
+                    ]
+                ),
+            )
         )
-        return js
+
+        jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
+        jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
+        # Every control job requires an environment variable of MF_CONTROL_INDEX
+        # set to 0 so that we can derive the MF_PARALLEL_NODE_INDEX correctly.
+        # Since only the control job has MF_CONTROL_INDEX set to 0, all worker nodes
+        # will use MF_WORKER_REPLICA_INDEX
+        jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
+        ## ----------- control/worker specific values END here -----------
+
+        return jobset
 
     def create_job_object(
         self,
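One detail of the new `create_jobset` worth calling out: the `jobset.secret(k)` loop normalizes `secrets` (which may be `None`, a single string, or a list) before appending the comma-separated cluster-wide `KUBERNETES_SECRETS` from config. A standalone sketch of that normalization, with `KUBERNETES_SECRETS` stubbed as a plain string (note that an empty config value would contribute one empty-string entry, which the builder presumably ignores):

```python
# Sketch of the secrets normalization used above; KUBERNETES_SECRETS is stubbed here.
KUBERNETES_SECRETS = "cluster-secret-a,cluster-secret-b"

def normalize_secrets(secrets):
    # None -> [], "s" -> ["s"], ["a", "b"] -> ["a", "b"]; then append the
    # comma-separated cluster-wide secrets from config.
    return list(
        [] if not secrets else [secrets] if isinstance(secrets, str) else secrets
    ) + KUBERNETES_SECRETS.split(",")

print(normalize_secrets(None))          # ['cluster-secret-a', 'cluster-secret-b']
print(normalize_secrets("my-secret"))   # ['my-secret', 'cluster-secret-a', 'cluster-secret-b']
print(normalize_secrets(["s1", "s2"]))  # ['s1', 's2', 'cluster-secret-a', 'cluster-secret-b']
```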
@@ -241,7 +470,6 @@
         shared_memory=None,
         port=None,
         name_pattern=None,
-        num_parallel=None,
     ):
         if env is None:
             env = {}
@@ -282,7 +510,6 @@
                 persistent_volume_claims=persistent_volume_claims,
                 shared_memory=shared_memory,
                 port=port,
-                num_parallel=num_parallel,
             )
             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
             .environment_variable("METAFLOW_CODE_URL", code_package_url)
--- a/metaflow/plugins/kubernetes/kubernetes_cli.py
+++ b/metaflow/plugins/kubernetes/kubernetes_cli.py
@@ -3,20 +3,20 @@ import sys
 import time
 import traceback
 
+import metaflow.tracing as tracing
 from metaflow import JSONTypeClass, util
 from metaflow._vendor import click
 from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
 from metaflow.metadata.util import sync_local_metadata_from_datastore
-from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
 from metaflow.mflog import TASK_LOG_SOURCE
-import metaflow.tracing as tracing
+from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 
 from .kubernetes import (
     Kubernetes,
+    KubernetesException,
     KubernetesKilledException,
     parse_kube_keyvalue_list,
-    KubernetesException,
 )
 from .kubernetes_decorator import KubernetesDecorator
 
@@ -185,8 +185,8 @@ def step(
 
     if num_parallel is not None and num_parallel <= 1:
         raise KubernetesException(
-            "Using @parallel with `num_parallel` <= 1 is not supported with Kubernetes. "
-            "Please set the value of `num_parallel` to be greater than 1."
+            "Using @parallel with `num_parallel` <= 1 is not supported with "
+            "@kubernetes. Please set the value of `num_parallel` to be greater than 1."
         )
 
     # Set retry policy.
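For context, the guard above fires when a flow requests a degenerate parallel cluster. A hypothetical flow that exercises this jobset code path (flow and step names are illustrative):

```python
# Hypothetical flow exercising this code path; flow/step names are illustrative.
from metaflow import FlowSpec, kubernetes, parallel, step


class MultiNodeFlow(FlowSpec):
    @step
    def start(self):
        # Fan out to a 4-node jobset; num_parallel=1 would trigger the error above.
        self.next(self.train, num_parallel=4)

    @kubernetes(cpu=2, memory=8000)
    @parallel
    @step
    def train(self):
        # Runs simultaneously on all 4 nodes of the jobset.
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    MultiNodeFlow()
```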
@@ -203,19 +203,37 @@
         )
         time.sleep(minutes_between_retries * 60)
 
+    # Explicitly remove `ubf_context` from `kwargs` so that it is not passed as a command-line
+    # option. If the underlying step command executes a vanilla Kubernetes job, it never needs
+    # to know about the UBF context.
+    # If it is a jobset executing a multi-node job, the UBF context is set based on the
+    # `ubf_context` parameter passed to the jobset.
+    kwargs.pop("ubf_context", None)
+    # `task_id` also needs to be removed from `kwargs` since it is set dynamically
+    # in the downstream code if num_parallel is > 1.
+    task_id = kwargs["task_id"]
+    if num_parallel:
+        kwargs.pop("task_id")
+
     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
         entrypoint="%s -u %s" % (executable, os.path.basename(sys.argv[0])),
         top_args=" ".join(util.dict_to_cli_options(ctx.parent.parent.params)),
         step=step_name,
         step_args=" ".join(util.dict_to_cli_options(kwargs)),
     )
+    # Since this is a parallel step, some parts of the step CLI need to be modified
+    # based on the type of worker in the JobSet; we add a placeholder string to the
+    # template, to be replaced based on the type of worker.
+
+    if num_parallel:
+        step_cli = "%s {METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}" % step_cli
 
     # Set log tailing.
     ds = ctx.obj.flow_datastore.get_task_datastore(
         mode="w",
         run_id=kwargs["run_id"],
         step_name=step_name,
-        task_id=kwargs["task_id"],
+        task_id=task_id,
         attempt=int(retry_count),
     )
     stdout_location = ds.get_log_location(TASK_LOG_SOURCE, "stdout")
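Popping `ubf_context` and `task_id` before building `step_cli` matters because every remaining kwarg becomes a literal CLI flag via `util.dict_to_cli_options`. A simplified stand-in for that helper (not Metaflow's actual implementation) shows the effect:

```python
# Simplified stand-in for util.dict_to_cli_options (illustrative, not the real helper).
def dict_to_cli_options(params):
    for key, value in params.items():
        if value is not None:
            yield "--%s" % key.replace("_", "-")
            yield str(value)

kwargs = {"run_id": "7", "ubf_context": "ubf_control", "task_id": "42"}
kwargs.pop("ubf_context", None)   # never a CLI option for jobset-backed steps
task_id = kwargs["task_id"]
kwargs.pop("task_id")             # re-injected per replica via the placeholder

print(" ".join(dict_to_cli_options(kwargs)))  # --run-id 7
```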
@@ -229,7 +247,7 @@
             sync_local_metadata_from_datastore(
                 DATASTORE_LOCAL_DIR,
                 ctx.obj.flow_datastore.get_task_datastore(
-                    kwargs["run_id"], step_name, kwargs["task_id"]
+                    kwargs["run_id"], step_name, task_id
                 ),
             )
 
@@ -245,7 +263,7 @@
             flow_name=ctx.obj.flow.name,
             run_id=kwargs["run_id"],
             step_name=step_name,
-            task_id=kwargs["task_id"],
+            task_id=task_id,
             attempt=str(retry_count),
             user=util.get_username(),
             code_package_sha=code_package_sha,
--- a/metaflow/plugins/kubernetes/kubernetes_client.py
+++ b/metaflow/plugins/kubernetes/kubernetes_client.py
@@ -6,7 +6,6 @@ from metaflow.exception import MetaflowException
 
 from .kubernetes_job import KubernetesJob, KubernetesJobSet
 
-
 CLIENT_REFRESH_INTERVAL_SECONDS = 300
 
 
--- a/metaflow/plugins/kubernetes/kubernetes_decorator.py
+++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py
@@ -12,28 +12,27 @@ from metaflow.metaflow_config import (
     DATASTORE_LOCAL_DIR,
     KUBERNETES_CONTAINER_IMAGE,
     KUBERNETES_CONTAINER_REGISTRY,
+    KUBERNETES_CPU,
+    KUBERNETES_DISK,
     KUBERNETES_FETCH_EC2_METADATA,
-    KUBERNETES_IMAGE_PULL_POLICY,
     KUBERNETES_GPU_VENDOR,
+    KUBERNETES_IMAGE_PULL_POLICY,
+    KUBERNETES_MEMORY,
     KUBERNETES_NAMESPACE,
     KUBERNETES_NODE_SELECTOR,
     KUBERNETES_PERSISTENT_VOLUME_CLAIMS,
-    KUBERNETES_TOLERATIONS,
+    KUBERNETES_PORT,
     KUBERNETES_SERVICE_ACCOUNT,
     KUBERNETES_SHARED_MEMORY,
-    KUBERNETES_PORT,
-    KUBERNETES_CPU,
-    KUBERNETES_MEMORY,
-    KUBERNETES_DISK,
+    KUBERNETES_TOLERATIONS,
 )
 from metaflow.plugins.resources_decorator import ResourcesDecorator
 from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
 from metaflow.sidecar import Sidecar
+from metaflow.unbounded_foreach import UBF_CONTROL
 
 from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
 from .kubernetes import KubernetesException, parse_kube_keyvalue_list
-from metaflow.unbounded_foreach import UBF_CONTROL
-from .kubernetes_jobsets import TaskIdConstructor
 
 try:
     unicode
@@ -416,8 +415,8 @@ class KubernetesDecorator(StepDecorator):
         # check for the existence of METAFLOW_KUBERNETES_WORKLOAD environment
         # variable.
 
+        meta = {}
         if "METAFLOW_KUBERNETES_WORKLOAD" in os.environ:
-            meta = {}
             meta["kubernetes-pod-name"] = os.environ["METAFLOW_KUBERNETES_POD_NAME"]
             meta["kubernetes-pod-namespace"] = os.environ[
                 "METAFLOW_KUBERNETES_POD_NAMESPACE"
@@ -427,15 +426,15 @@
                 "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME"
             ]
             meta["kubernetes-node-ip"] = os.environ["METAFLOW_KUBERNETES_NODE_IP"]
-            if os.environ.get("METAFLOW_KUBERNETES_JOBSET_NAME"):
-                meta["kubernetes-jobset-name"] = os.environ[
-                    "METAFLOW_KUBERNETES_JOBSET_NAME"
-                ]
+
+            meta["kubernetes-jobset-name"] = os.environ.get(
+                "METAFLOW_KUBERNETES_JOBSET_NAME"
+            )
 
             # TODO (savin): Introduce equivalent support for Microsoft Azure and
             # Google Cloud Platform
-            # TODO: Introduce a way to detect Cloud Provider, so unnecessary requests (and delays)
-            # can be avoided by not having to try out all providers.
+            # TODO: Introduce a way to detect Cloud Provider, so unnecessary requests
+            # (and delays) can be avoided by not having to try out all providers.
             if KUBERNETES_FETCH_EC2_METADATA:
                 instance_meta = get_ec2_instance_metadata()
                 meta.update(instance_meta)
@@ -451,14 +450,6 @@
             #     "METAFLOW_KUBERNETES_POD_NAME"
             # ].rpartition("-")[0]
 
-            entries = [
-                MetaDatum(field=k, value=v, type=k, tags=[])
-                for k, v in meta.items()
-                if v is not None
-            ]
-            # Register book-keeping metadata for debugging.
-            metadata.register_metadata(run_id, step_name, task_id, entries)
-
         # Start MFLog sidecar to collect task logs.
         self._save_logs_sidecar = Sidecar("save_logs_periodically")
         self._save_logs_sidecar.start()
@@ -467,19 +458,34 @@
         if hasattr(flow, "_parallel_ubf_iter"):
             num_parallel = flow._parallel_ubf_iter.num_parallel
 
-        if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
-            control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(
-                num_parallel
-            )
-            mapper_task_ids = [control_task_id] + worker_task_ids
-            flow._control_mapper_tasks = [
-                "%s/%s/%s" % (run_id, step_name, mapper_task_id)
-                for mapper_task_id in mapper_task_ids
-            ]
-            flow._control_task_is_mapper_zero = True
-
         if num_parallel and num_parallel > 1:
             _setup_multinode_environment()
+            # current.parallel.node_index is correctly available at this point.
+            meta.update({"parallel-node-index": current.parallel.node_index})
+            if ubf_context == UBF_CONTROL:
+                flow._control_mapper_tasks = [
+                    "{}/{}/{}".format(run_id, step_name, task_id)
+                    for task_id in [task_id]
+                    + [
+                        "%s-worker-%d" % (task_id, idx)
+                        for idx in range(num_parallel - 1)
+                    ]
+                ]
+                flow._control_task_is_mapper_zero = True
+
+        if len(meta) > 0:
+            entries = [
+                MetaDatum(
+                    field=k,
+                    value=v,
+                    type=k,
+                    tags=["attempt_id:{0}".format(retry_count)],
+                )
+                for k, v in meta.items()
+                if v is not None
+            ]
+            # Register book-keeping metadata for debugging.
+            metadata.register_metadata(run_id, step_name, task_id, entries)
 
     def task_finished(
         self, step_name, flow, graph, is_task_ok, retry_count, max_retries
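The control task now enumerates its mapper pathspecs directly from the deterministic `{task_id}-worker-{idx}` naming scheme instead of consulting the removed `TaskIdConstructor`. Concretely, with illustrative ids:

```python
# Sketch: mapper pathspecs produced by the list comprehension above (ids illustrative).
run_id, step_name, task_id, num_parallel = "181", "train", "42", 3

control_mapper_tasks = [
    "{}/{}/{}".format(run_id, step_name, t)
    for t in [task_id]
    + ["%s-worker-%d" % (task_id, idx) for idx in range(num_parallel - 1)]
]
print(control_mapper_tasks)
# ['181/train/42', '181/train/42-worker-0', '181/train/42-worker-1']
```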
@@ -516,18 +522,24 @@
         )[0]
 
 
+# TODO: Unify this method with the multi-node setup in @batch
 def _setup_multinode_environment():
+    # FIXME: what about MF_MASTER_PORT
     import socket
 
-    os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
-    os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
-    if os.environ.get("CONTROL_INDEX") is not None:
-        os.environ["MF_PARALLEL_NODE_INDEX"] = str(0)
-    elif os.environ.get("WORKER_REPLICA_INDEX") is not None:
-        os.environ["MF_PARALLEL_NODE_INDEX"] = str(
-            int(os.environ["WORKER_REPLICA_INDEX"]) + 1
+    try:
+        os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(
+            os.environ["MF_MASTER_ADDR"]
         )
-    else:
-        raise MetaflowException(
-            "Jobset related ENV vars called $CONTROL_INDEX or $WORKER_REPLICA_INDEX not found"
+        os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["MF_WORLD_SIZE"]
+        os.environ["MF_PARALLEL_NODE_INDEX"] = (
+            str(0)
+            if "MF_CONTROL_INDEX" in os.environ
+            else str(int(os.environ["MF_WORKER_REPLICA_INDEX"]) + 1)
         )
+    except KeyError as e:
+        raise MetaflowException("Environment variable {} is missing.".format(e))
+    except socket.gaierror:
+        raise MetaflowException("Failed to get host by name for MF_MASTER_ADDR.")
+    except ValueError:
+        raise MetaflowException("Invalid value for MF_WORKER_REPLICA_INDEX.")